diff --git a/setup.py b/setup.py index 2f3aeda..61f474e 100755 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ deps = [ 'cryptography>=2.3', 'idna>=2.5', 'PyYAML>=5.1', + 'cachetools', ] try: import concurrent.futures diff --git a/tests/test_warcprox.py b/tests/test_warcprox.py index 6c49f0a..4323a6c 100755 --- a/tests/test_warcprox.py +++ b/tests/test_warcprox.py @@ -1986,6 +1986,10 @@ def test_socket_timeout_response( def test_empty_response( warcprox_, http_daemon, https_daemon, archiving_proxies, playback_proxies): + # localhost:server_port was added to the `bad_hostnames_ports` cache by + # previous tests and this causes subsequent tests to fail. We clear it. + warcprox_.proxy.bad_hostnames_ports.clear() + url = 'http://localhost:%s/empty-response' % http_daemon.server_port response = requests.get(url, proxies=archiving_proxies, verify=False) assert response.status_code == 502 @@ -2001,6 +2005,10 @@ def test_payload_digest(warcprox_, http_daemon): Tests that digest is of RFC2616 "entity body" (transfer-decoded but not content-decoded) ''' + # localhost:server_port was added to the `bad_hostnames_ports` cache by + # previous tests and this causes subsequent tests to fail. We clear it. 
+ warcprox_.proxy.bad_hostnames_ports.clear() + class HalfMockedMitm(warcprox.mitmproxy.MitmProxyHandler): def __init__(self, url): self.path = url diff --git a/warcprox/mitmproxy.py b/warcprox/mitmproxy.py index b29dcaf..b158162 100644 --- a/warcprox/mitmproxy.py +++ b/warcprox/mitmproxy.py @@ -77,6 +77,7 @@ import time import collections import cProfile from urllib3.util import is_connection_dropped +from urllib3.exceptions import NewConnectionError import doublethink class ProxyingRecorder(object): @@ -252,6 +253,9 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler): query=u.query, fragment=u.fragment)) self.hostname = urlcanon.normalize_host(host).decode('ascii') + def _hostname_port_cache_key(self): + return '%s:%s' % (self.hostname, self.port) + def _connect_to_remote_server(self): ''' Connect to destination. @@ -380,7 +384,17 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler): else: self._determine_host_port() assert self.url - + # Check if target hostname:port is in `bad_hostnames_ports` cache + # to avoid retrying to connect. cached is a tuple containing + # (status_code, error message) + cached = None + hostname_port = self._hostname_port_cache_key() + with self.server.bad_hostnames_ports_lock: + cached = self.server.bad_hostnames_ports.get(hostname_port) + if cached: + self.logger.info('Cannot connect to %s (cache)', hostname_port) + self.send_error(cached[0], cached[1]) + return # Connect to destination self._connect_to_remote_server() except warcprox.RequestBlockedByRule as e: @@ -388,6 +402,15 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler): self.logger.info("%r: %r", self.requestline, e) return except Exception as e: + # If connection fails, add hostname:port to cache to avoid slow + # subsequent reconnection attempts. `NewConnectionError` can be + # caused by many types of errors which are handled by urllib3. 
+ if isinstance(e, (socket.timeout, NewConnectionError)): + host_port = self._hostname_port_cache_key() + with self.server.bad_hostnames_ports_lock: + self.server.bad_hostnames_ports[host_port] = (500, str(e)) + self.logger.info('bad_hostnames_ports cache size: %d', + len(self.server.bad_hostnames_ports)) self.logger.error( "problem processing request %r: %r", self.requestline, e, exc_info=True) @@ -527,7 +550,19 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler): # put it back in the pool to reuse it later. if not is_connection_dropped(self._remote_server_conn): self._conn_pool._put_conn(self._remote_server_conn) - except: + except Exception as e: + # A common error is to connect to the remote server successfully + # but raise a `RemoteDisconnected` exception when trying to begin + # downloading. It's caused by prox_rec_res.begin(...) which calls + # http_client._read_status(). In that case, the host is also bad + # and we must add it to `bad_hostnames_ports` cache. + if isinstance(e, http_client.RemoteDisconnected): + host_port = self._hostname_port_cache_key() + with self.server.bad_hostnames_ports_lock: + self.server.bad_hostnames_ports[host_port] = (502, str(e)) + self.logger.info('bad_hostnames_ports cache size: %d', + len(self.server.bad_hostnames_ports)) + self._remote_server_conn.sock.shutdown(socket.SHUT_RDWR) self._remote_server_conn.sock.close() raise diff --git a/warcprox/playback.py b/warcprox/playback.py index 91f86aa..8bfa42f 100644 --- a/warcprox/playback.py +++ b/warcprox/playback.py @@ -42,6 +42,7 @@ from warcprox.mitmproxy import MitmProxyHandler import warcprox import sqlite3 import threading +from cachetools import TTLCache class PlaybackProxyHandler(MitmProxyHandler): logger = logging.getLogger("warcprox.playback.PlaybackProxyHandler") @@ -219,6 +220,8 @@ class PlaybackProxy(socketserver.ThreadingMixIn, http_server.HTTPServer): self.playback_index_db = playback_index_db self.warcs_dir = options.directory self.options = options + 
self.bad_hostnames_ports = TTLCache(maxsize=1024, ttl=60) + self.bad_hostnames_ports_lock = threading.RLock() def server_activate(self): http_server.HTTPServer.server_activate(self) diff --git a/warcprox/warcproxy.py b/warcprox/warcproxy.py index 8898898..2d072b9 100644 --- a/warcprox/warcproxy.py +++ b/warcprox/warcproxy.py @@ -48,6 +48,8 @@ import tempfile import hashlib import doublethink import re +from threading import RLock +from cachetools import TTLCache class WarcProxyHandler(warcprox.mitmproxy.MitmProxyHandler): ''' @@ -431,6 +433,11 @@ class SingleThreadedWarcProxy(http_server.HTTPServer, object): self.status_callback = status_callback self.stats_db = stats_db self.options = options + # TTLCache is not thread-safe. Access to the shared cache from multiple + # threads must be properly synchronized with an RLock according to ref: + # https://cachetools.readthedocs.io/en/latest/ + self.bad_hostnames_ports = TTLCache(maxsize=1024, ttl=60) + self.bad_hostnames_ports_lock = RLock() self.remote_connection_pool = PoolManager( num_pools=max(round(options.max_threads / 6), 200) if options.max_threads else 200) server_address = (