make test server multithreaded so tests will pass

This commit is contained in:
Noah Levitt 2018-04-05 17:59:10 -07:00
parent 385014c322
commit 38e2a87f31

View File

@ -50,6 +50,7 @@ import io
import gzip import gzip
import mock import mock
import email.message import email.message
import socketserver
try: try:
import http.server as http_server import http.server as http_server
@ -323,17 +324,20 @@ def cert(request):
finally: finally:
f.close() f.close()
class UhhhServer(http_server.HTTPServer): # We need this test server to accept multiple simultaneous connections in order
def get_request(self): # to avoid mysterious looking test failures like these:
try: # https://travis-ci.org/internetarchive/warcprox/builds/362892231
return self.socket.accept() # This is because we can't guarantee (without jumping through hoops) that
except: # MitmProxyHandler._proxy_request() returns the connection to the pool before
logging.error('socket.accept() raised exception', exc_info=True) # the next request tries to get a connection from the pool in
raise # MitmProxyHandler._connect_to_remote_server(). (Unless we run warcprox
# single-threaded for these tests, which maybe we should consider?)
class ThreadedHTTPServer(socketserver.ThreadingMixIn, http_server.HTTPServer):
pass
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def http_daemon(request): def http_daemon(request):
http_daemon = UhhhServer( http_daemon = ThreadedHTTPServer(
('localhost', 0), RequestHandlerClass=_TestHttpRequestHandler) ('localhost', 0), RequestHandlerClass=_TestHttpRequestHandler)
logging.info('starting http://{}:{}'.format(http_daemon.server_address[0], http_daemon.server_address[1])) logging.info('starting http://{}:{}'.format(http_daemon.server_address[0], http_daemon.server_address[1]))
http_daemon_thread = threading.Thread(name='HttpDaemonThread', http_daemon_thread = threading.Thread(name='HttpDaemonThread',
@ -352,7 +356,7 @@ def http_daemon(request):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def https_daemon(request, cert): def https_daemon(request, cert):
# http://www.piware.de/2011/01/creating-an-https-server-in-python/ # http://www.piware.de/2011/01/creating-an-https-server-in-python/
https_daemon = http_server.HTTPServer(('localhost', 0), https_daemon = ThreadedHTTPServer(('localhost', 0),
RequestHandlerClass=_TestHttpRequestHandler) RequestHandlerClass=_TestHttpRequestHandler)
https_daemon.socket = ssl.wrap_socket(https_daemon.socket, certfile=cert, server_side=True) https_daemon.socket = ssl.wrap_socket(https_daemon.socket, certfile=cert, server_side=True)
logging.info('starting https://{}:{}'.format(https_daemon.server_address[0], https_daemon.server_address[1])) logging.info('starting https://{}:{}'.format(https_daemon.server_address[0], https_daemon.server_address[1]))