Cache bad target hostname:port to avoid reconnection attempts

If a connection to a hostname:port fails, add it to a `TTLCache` with
a 60-second expiration time. Subsequent requests to the same
hostname:port fail fast, because we check the cache first and skip
attempting a new network connection.

The short expiration time guarantees that if a host becomes OK again,
we'll be able to connect to it within at most 60 seconds.

Adding the `cachetools` dependency was necessary because the stdlib
does not provide an expiring in-memory cache. The library has no
dependencies of its own, has good test coverage, and appears to be
maintained. It also supports Python 3.7.
This commit is contained in:
Vangelis Banos 2019-05-09 10:03:16 +00:00
parent 41d7f0be53
commit 89d987a181
3 changed files with 37 additions and 2 deletions

View File

@ -34,6 +34,7 @@ deps = [
'cryptography>=2.3',
'idna>=2.5',
'PyYAML>=5.1',
'cachetools',
]
try:
import concurrent.futures

View File

@ -77,6 +77,7 @@ import time
import collections
import cProfile
from urllib3.util import is_connection_dropped
from urllib3.exceptions import NewConnectionError
import doublethink
class ProxyingRecorder(object):
@ -252,6 +253,9 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
query=u.query, fragment=u.fragment))
self.hostname = urlcanon.normalize_host(host).decode('ascii')
def _hostname_port_cache_key(self):
    '''
    Build the `hostname:port` string for this request's target.

    Reads `self.hostname` and `self.port` and joins them with a colon.
    The result is used as the key into the server's
    `bad_hostnames_ports` cache.
    '''
    host, port = self.hostname, self.port
    return '%s:%s' % (host, port)
def _connect_to_remote_server(self):
'''
Connect to destination.
@ -380,7 +384,15 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
else:
self._determine_host_port()
assert self.url
# Check if target hostname:port is in `bad_hostnames_ports` cache
# to avoid retrying to connect.
with self.server.bad_hostnames_ports_lock:
hostname_port = self._hostname_port_cache_key()
if hostname_port in self.server.bad_hostnames_ports:
self.logger.info('Cannot connect to %s (cache)',
hostname_port)
self.send_error(502, 'message timed out')
return
# Connect to destination
self._connect_to_remote_server()
except warcprox.RequestBlockedByRule as e:
@ -388,6 +400,15 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
self.logger.info("%r: %r", self.requestline, e)
return
except Exception as e:
# If connection fails, add hostname:port to cache to avoid slow
# subsequent reconnection attempts. `NewConnectionError` can be
# caused by many types of errors which are handled by urllib3.
if type(e) in (socket.timeout, NewConnectionError):
with self.server.bad_hostnames_ports_lock:
host_port = self._hostname_port_cache_key()
self.server.bad_hostnames_ports[host_port] = 1
self.logger.info('bad_hostnames_ports cache size: %d',
len(self.server.bad_hostnames_ports))
self.logger.error(
"problem processing request %r: %r",
self.requestline, e, exc_info=True)
@ -527,7 +548,13 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
# put it back in the pool to reuse it later.
if not is_connection_dropped(self._remote_server_conn):
self._conn_pool._put_conn(self._remote_server_conn)
except:
except Exception as e:
if type(e) in (socket.timeout, NewConnectionError):
with self.server.bad_hostnames_ports_lock:
hostname_port = self._hostname_port_cache_key()
self.server.bad_hostnames_ports[hostname_port] = 1
self.logger.info('bad_hostnames_ports cache size: %d',
len(self.server.bad_hostnames_ports))
self._remote_server_conn.sock.shutdown(socket.SHUT_RDWR)
self._remote_server_conn.sock.close()
raise

View File

@ -48,6 +48,8 @@ import tempfile
import hashlib
import doublethink
import re
from threading import RLock
from cachetools import TTLCache
class WarcProxyHandler(warcprox.mitmproxy.MitmProxyHandler):
'''
@ -431,6 +433,11 @@ class SingleThreadedWarcProxy(http_server.HTTPServer, object):
self.status_callback = status_callback
self.stats_db = stats_db
self.options = options
# TTLCache is not thread-safe. Access to the shared cache from multiple
# threads must be properly synchronized with an RLock according to ref:
# https://cachetools.readthedocs.io/en/latest/
self.bad_hostnames_ports = TTLCache(maxsize=1024, ttl=60)
self.bad_hostnames_ports_lock = RLock()
self.remote_connection_pool = PoolManager(
num_pools=max(round(options.max_threads / 6), 200) if options.max_threads else 200)
server_address = (