From 555517ab78b5ebf22a244bf27ea105441d4a751a Mon Sep 17 00:00:00 2001 From: Noah Levitt Date: Tue, 19 Nov 2013 17:12:58 -0800 Subject: [PATCH] WarcproxController to ease use of warcprox as a module --- setup.py | 5 +- warcprox/tests/test_warcproxy.py | 57 +++++++++ warcprox/warcprox.py | 204 ++++++++++++++++++++----------- 3 files changed, 196 insertions(+), 70 deletions(-) mode change 100644 => 100755 setup.py create mode 100644 warcprox/tests/test_warcproxy.py diff --git a/setup.py b/setup.py old mode 100644 new mode 100755 index 63cef59..4e37145 --- a/setup.py +++ b/setup.py @@ -12,7 +12,8 @@ setuptools.setup(name='warcprox', long_description=open('README.md').read(), license='GPL', packages=['warcprox'], - install_requires=['pyopenssl', 'gdbm', 'warctools'], + install_requires=['pyopenssl', 'warctools'], # gdbm/dbhash? scripts=['bin/dump-anydbm', 'bin/warcprox'], - zip_safe=False) + zip_safe=False, + test_suite='warcprox.tests') diff --git a/warcprox/tests/test_warcproxy.py b/warcprox/tests/test_warcproxy.py new file mode 100644 index 0000000..a879e52 --- /dev/null +++ b/warcprox/tests/test_warcproxy.py @@ -0,0 +1,57 @@ +# vim: set sw=4 et: + +import unittest +import BaseHTTPServer +import threading +import time +from warcprox import warcprox +import logging +import sys + +class WarcproxTest(unittest.TestCase): + logger = logging.getLogger('WarcproxTest') + + def setUp(self): + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format='%(asctime)s %(process)d %(threadName)s %(levelname)s %(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s') + + self.httpd = BaseHTTPServer.HTTPServer(('localhost', 0), + RequestHandlerClass=BaseHTTPServer.BaseHTTPRequestHandler) + self.logger.info('starting httpd on {}:{}'.format(self.httpd.server_address[0], self.httpd.server_address[1])) + self.httpd_thread = threading.Thread(name='HttpdThread', + target=self.httpd.serve_forever) + self.httpd_thread.start() + + self.warcprox = warcprox.WarcproxController() + self.logger.info('starting warcprox') + self.warcprox_thread = threading.Thread(name='WarcproxThread', + target=self.warcprox.run_until_shutdown) + self.warcprox_thread.start() + + def tearDown(self): + self.logger.info('stopping warcprox') + self.warcprox.stop.set() + + self.logger.info('stopping httpd') + self.httpd.shutdown() + self.httpd.server_close() + + # Have to wait for threads to finish or the threads will try to use + # variables that have been deleted, resulting in errors like this: + # File "/usr/lib/python2.7/SocketServer.py", line 235, in serve_forever + # r, w, e = _eintr_retry(select.select, [self], [], [], + # AttributeError: 'NoneType' object has no attribute 'select' + self.httpd_thread.join() + self.warcprox_thread.join() + + def test_something(self): + self.logger.info('sleeping for 5 seconds...') + try: + time.sleep(5) + except: + self.logger.info('interrupted') + self.logger.info('finished sleeping') + +if __name__ == '__main__': + unittest.main() + diff --git a/warcprox/warcprox.py b/warcprox/warcprox.py index 3719e2e..0ef0d19 100644 --- a/warcprox/warcprox.py +++ b/warcprox/warcprox.py @@ -34,6 +34,7 @@ import gdbm from StringIO import StringIO class CertificateAuthority(object): + logger = logging.getLogger('warcprox.CertificateAuthority') def __init__(self, ca_file='warcprox-ca.pem', certs_dir='./warcprox-ca'): self.ca_file = ca_file @@ -45,7 +46,7 @@ class CertificateAuthority(object): self._read_ca(ca_file) if not os.path.exists(certs_dir): - logging.info("directory for generated certs {} doesn't exist, creating it".format(certs_dir)) + self.logger.info("directory for generated certs {} doesn't exist, creating it".format(certs_dir)) os.mkdir(certs_dir) @@ -75,13 +76,13 @@ class CertificateAuthority(object): f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.SSL.FILETYPE_PEM, self.key)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.SSL.FILETYPE_PEM, self.cert)) - logging.info('generated CA key+cert and wrote to {}'.format(self.ca_file)) + self.logger.info('generated CA key+cert and wrote to {}'.format(self.ca_file)) def _read_ca(self, filename): self.cert = OpenSSL.crypto.load_certificate(OpenSSL.SSL.FILETYPE_PEM, open(filename).read()) self.key = OpenSSL.crypto.load_privatekey(OpenSSL.SSL.FILETYPE_PEM, open(filename).read()) - logging.info('read CA key+cert from {}'.format(self.ca_file)) + self.logger.info('read CA key+cert from {}'.format(self.ca_file)) def __getitem__(self, cn): cnp = os.path.sep.join([self.certs_dir, '%s.pem' % cn]) @@ -110,7 +111,7 @@ class CertificateAuthority(object): f.write(OpenSSL.crypto.dump_privatekey(OpenSSL.SSL.FILETYPE_PEM, key)) f.write(OpenSSL.crypto.dump_certificate(OpenSSL.SSL.FILETYPE_PEM, cert)) - logging.info('wrote generated key+cert to {}'.format(cnp)) + self.logger.info('wrote generated key+cert to {}'.format(cnp)) return cnp @@ -121,6 +122,8 @@ class ProxyingRecorder(object): calculating digests, and sending them on to the proxy client. """ + logger = logging.getLogger('warcprox.ProxyingRecordingHTTPResponse') + def __init__(self, fp, proxy_dest, digest_algorithm='sha1'): self.fp = fp # "The file has no name, and will cease to exist when it is closed." @@ -174,8 +177,8 @@ class ProxyingRecorder(object): self.proxy_dest.sendall(hunk) except BaseException as e: self._proxy_dest_conn_open = False - logging.warn('{} sending data to proxy client'.format(e)) - logging.info('will continue downloading from remote server without sending to client') + self.logger.warn('{} sending data to proxy client'.format(e)) + self.logger.info('will continue downloading from remote server without sending to client') self.len += len(hunk) @@ -217,6 +220,7 @@ class ProxyingRecordingHTTPResponse(httplib.HTTPResponse): class MitmProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler): + logger = logging.getLogger('warcprox.MitmProxyHandler') def __init__(self, request, client_address, server): self.is_connect = False @@ -326,16 +330,18 @@ class MitmProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler): return self.do_COMMAND def log_error(self, fmt, *args): - logging.error("{0} - - [{1}] {2}".format(self.address_string(), + self.logger.error("{0} - - [{1}] {2}".format(self.address_string(), self.log_date_time_string(), fmt % args)) def log_message(self, fmt, *args): - logging.info("{} {} - - [{}] {}".format(self.__class__.__name__, + self.logger.info("{} {} - - [{}] {}".format(self.__class__.__name__, self.address_string(), self.log_date_time_string(), fmt % args)) class WarcProxyHandler(MitmProxyHandler): + logger = logging.getLogger('warcprox.WarcProxyHandler') + def _proxy_request(self): # Build request req = '%s %s %s\r\n' % (self.command, self.path, self.request_version) @@ -390,25 +396,36 @@ class RecordedUrl(object): class WarcProxy(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): + logger = logging.getLogger('warcprox.WarcProxy') - def __init__(self, server_address, req_handler_class=WarcProxyHandler, - bind_and_activate=True, ca=None, recorded_url_q=None, - digest_algorithm='sha1'): + def __init__(self, server_address=('localhost', 8000), + req_handler_class=WarcProxyHandler, bind_and_activate=True, + ca=None, recorded_url_q=None, digest_algorithm='sha1'): BaseHTTPServer.HTTPServer.__init__(self, server_address, req_handler_class, bind_and_activate) - self.ca = ca - self.recorded_url_q = recorded_url_q + self.digest_algorithm = digest_algorithm + if ca is not None: + self.ca = ca + else: + self.ca = CertificateAuthority() + + if recorded_url_q is not None: + self.recorded_url_q = recorded_url_q + else: + self.recorded_url_q = Queue.Queue() + def server_activate(self): BaseHTTPServer.HTTPServer.server_activate(self) - logging.info('WarcProxy listening on {0}:{1}'.format(self.server_address[0], self.server_address[1])) + self.logger.info('WarcProxy listening on {0}:{1}'.format(self.server_address[0], self.server_address[1])) def server_close(self): - logging.info('WarcProxy shutting down') + self.logger.info('WarcProxy shutting down') BaseHTTPServer.HTTPServer.server_close(self) class PlaybackProxyHandler(MitmProxyHandler): + logger = logging.getLogger('warcprox.PlaybackProxyHandler') # @Override def _connect_to_host(self): @@ -419,7 +436,7 @@ class PlaybackProxyHandler(MitmProxyHandler): # @Override def _proxy_request(self): date, location = self.server.playback_index_db.lookup_latest(self.url) - logging.debug('lookup_latest returned {}:{}'.format(date, location)) + self.logger.debug('lookup_latest returned {}:{}'.format(date, location)) status = None if location is not None: @@ -427,7 +444,7 @@ class PlaybackProxyHandler(MitmProxyHandler): status, sz = self._send_response_from_warc(location[b'f'], location[b'o']) except: status = 500 - logging.error('PlaybackProxyHandler problem playing back {}'.format(self.url), exc_info=1) + self.logger.error('PlaybackProxyHandler problem playing back {}'.format(self.url), exc_info=1) payload = '500 Warcprox Error\n\n{}\n'.format(traceback.format_exc()) headers = ('HTTP/1.1 500 Internal Server Error\r\n' + 'Content-Type: text/plain\r\n' @@ -452,7 +469,7 @@ class PlaybackProxyHandler(MitmProxyHandler): def _open_warc_at_offset(self, warcfilename, offset): - logging.debug('opening {} at offset {}'.format(warcfilename, offset)) + self.logger.debug('opening {} at offset {}'.format(warcfilename, offset)) warcpath = None for p in (os.path.sep.join([self.server.warcs_dir, warcfilename]), @@ -486,7 +503,7 @@ class PlaybackProxyHandler(MitmProxyHandler): def _send_headers_and_refd_payload(self, headers, refers_to_target_uri, refers_to_date): location = self.server.playback_index_db.lookup_exact(refers_to_target_uri, refers_to_date) - logging.debug('loading http payload from {}'.format(location)) + self.logger.debug('loading http payload from {}'.format(location)) fh = self._open_warc_at_offset(location['f'], location['o']) try: @@ -543,7 +560,7 @@ class PlaybackProxyHandler(MitmProxyHandler): refers_to_target_uri = record.get_header(warctools.WarcRecord.REFERS_TO_TARGET_URI) refers_to_date = record.get_header(warctools.WarcRecord.REFERS_TO_DATE) - logging.debug('revisit record references {} capture of {}'.format(refers_to_date, refers_to_target_uri)) + self.logger.debug('revisit record references {} capture of {}'.format(refers_to_date, refers_to_target_uri)) return self._send_headers_and_refd_payload(record.content[1], refers_to_target_uri, refers_to_date) else: @@ -556,6 +573,7 @@ class PlaybackProxyHandler(MitmProxyHandler): class PlaybackProxy(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): + logger = logging.getLogger('warcprox.PlaybackProxy') def __init__(self, server_address, req_handler_class=PlaybackProxyHandler, bind_and_activate=True, ca=None, playback_index_db=None, @@ -567,20 +585,21 @@ class PlaybackProxy(SocketServer.ThreadingMixIn, BaseHTTPServer.HTTPServer): def server_activate(self): BaseHTTPServer.HTTPServer.server_activate(self) - logging.info('PlaybackProxy listening on {0}:{1}'.format(self.server_address[0], self.server_address[1])) + self.logger.info('PlaybackProxy listening on {0}:{1}'.format(self.server_address[0], self.server_address[1])) def server_close(self): - logging.info('PlaybackProxy shutting down') + self.logger.info('PlaybackProxy shutting down') BaseHTTPServer.HTTPServer.server_close(self) class DedupDb(object): + logger = logging.getLogger('warcprox.DedupDb') def __init__(self, dbm_file='./warcprox-dedup.db'): if os.path.exists(dbm_file): - logging.info('opening existing deduplication database {}'.format(dbm_file)) + self.logger.info('opening existing deduplication database {}'.format(dbm_file)) else: - logging.info('creating new deduplication database {}'.format(dbm_file)) + self.logger.info('creating new deduplication database {}'.format(dbm_file)) self.db = gdbm.open(dbm_file, 'c') @@ -599,7 +618,7 @@ class DedupDb(object): json_value = json.dumps(py_value, separators=(',',':')) self.db[key] = json_value - logging.debug('dedup db saved {}:{}'.format(key, json_value)) + self.logger.debug('dedup db saved {}:{}'.format(key, json_value)) def lookup(self, key): @@ -612,12 +631,14 @@ class DedupDb(object): class WarcWriterThread(threading.Thread): + logger = logging.getLogger('warcprox.WarcWriterThread') # port is only used for warc filename - def __init__(self, recorded_url_q, directory, rollover_size=1000000000, - rollover_idle_time=None, gzip=False, prefix='WARCPROX', port=0, - digest_algorithm='sha1', base32=False, dedup_db=None, - playback_index_db=None): + def __init__(self, recorded_url_q=None, directory='./warcs', + rollover_size=1000000000, rollover_idle_time=None, gzip=False, + prefix='WARCPROX', port=0, digest_algorithm='sha1', base32=False, + dedup_db=None, playback_index_db=None): + threading.Thread.__init__(self, name='WarcWriterThread') self.recorded_url_q = recorded_url_q @@ -642,12 +663,11 @@ class WarcWriterThread(threading.Thread): self._serial = 0 if not os.path.exists(directory): - logging.info("warc destination directory {} doesn't exist, creating it".format(directory)) + self.logger.info("warc destination directory {} doesn't exist, creating it".format(directory)) os.mkdir(directory) self.stop = threading.Event() - self.listeners = [] # returns a tuple (principal_record, request_record) where principal_record is either a response or revisit record def build_warc_records(self, recorded_url): @@ -760,7 +780,7 @@ class WarcWriterThread(threading.Thread): def _close_writer(self): if self._fpath: - logging.info('closing {0}'.format(self._f_finalname)) + self.logger.info('closing {0}'.format(self._f_finalname)) self._f.close() finalpath = os.path.sep.join([self.directory, self._f_finalname]) os.rename(self._fpath, finalpath) @@ -828,7 +848,7 @@ class WarcWriterThread(threading.Thread): recorded_url.response_recorder.tempfile.close() def run(self): - logging.info('WarcWriterThread starting, directory={} gzip={} rollover_size={} rollover_idle_time={} prefix={} port={}'.format( + self.logger.info('WarcWriterThread starting, directory={} gzip={} rollover_size={} rollover_idle_time={} prefix={} port={}'.format( os.path.abspath(self.directory), self.gzip, self.rollover_size, self.rollover_idle_time, self.prefix, self.port)) @@ -848,7 +868,7 @@ class WarcWriterThread(threading.Thread): for record in recordset: offset = writer.tell() record.write_to(writer, gzip=self.gzip) - logging.debug('wrote warc record: warc_type={} content_length={} url={} warc={} offset={}'.format( + self.logger.debug('wrote warc record: warc_type={} content_length={} url={} warc={} offset={}'.format( record.get_header(warctools.WarcRecord.TYPE), record.get_header(warctools.WarcRecord.CONTENT_LENGTH), record.get_header(warctools.WarcRecord.URL), @@ -863,7 +883,7 @@ class WarcWriterThread(threading.Thread): and self.rollover_idle_time is not None and self.rollover_idle_time > 0 and time.time() - self._last_activity > self.rollover_idle_time): - logging.debug('rolling over warc file after {} seconds idle'.format(time.time() - self._last_activity)) + self.logger.debug('rolling over warc file after {} seconds idle'.format(time.time() - self._last_activity)) self._close_writer() if time.time() - self._last_sync > 60: @@ -873,17 +893,18 @@ class WarcWriterThread(threading.Thread): self.playback_index_db.sync() self._last_sync = time.time() - logging.info('WarcWriterThread shutting down') + self.logger.info('WarcWriterThread shutting down') self._close_writer(); class PlaybackIndexDb(object): + logger = logging.getLogger('warcprox.PlaybackIndexDb') def __init__(self, dbm_file='./warcprox-playback-index.db'): if os.path.exists(dbm_file): - logging.info('opening existing playback index database {}'.format(dbm_file)) + self.logger.info('opening existing playback index database {}'.format(dbm_file)) else: - logging.info('creating new playback index database {}'.format(dbm_file)) + self.logger.info('creating new playback index database {}'.format(dbm_file)) self.db = gdbm.open(dbm_file, 'c') @@ -913,7 +934,7 @@ class PlaybackIndexDb(object): self.db[url] = json_value - logging.debug('playback index saved: {}:{}'.format(url, json_value)) + self.logger.debug('playback index saved: {}:{}'.format(url, json_value)) def lookup_latest(self, url): @@ -940,32 +961,82 @@ class PlaybackIndexDb(object): return None -def run_until_shutdown(proxy, warc_writer, dedup_db, playback_proxy, playback_index_db): - stop = threading.Event() - signal.signal(signal.SIGTERM, stop.set) +class WarcproxController(object): + logger = logging.getLogger('warcprox.WarcproxController') - try: - while not stop.is_set(): - time.sleep(0.5) - except: - pass - finally: - warc_writer.stop.set() - proxy.shutdown() - proxy.server_close() + def __init__(self, proxy=None, warc_writer=None, playback_proxy=None): + """ + Create warcprox controller. + + If supplied, proxy should be an instance of WarcProxy, and warc_writer + should be an instance of WarcWriterThread. If not supplied, they are + created with default values. + + If supplied, playback_proxy should be an instance of PlaybackProxy. If not + supplied, no playback proxy will run. + """ + if proxy is not None: + self.proxy = proxy + else: + self.proxy = WarcProxy() - if playback_proxy is not None: - playback_proxy.shutdown() - playback_proxy.server_close() + if warc_writer is not None: + self.warc_writer = warc_writer + else: + self.warc_writer = WarcWriterThread(recorded_url_q=self.proxy.recorded_url_q) - if dedup_db is not None: - dedup_db.close() + self.playback_proxy = playback_proxy - if playback_index_db is not None: - playback_index_db.close() + + def run_until_shutdown(self): + """Start warcprox and run until shut down. + + If running in the main thread, SIGTERM initiates a graceful shutdown. + Otherwise, call warcprox_controller.stop.set(). + """ + proxy_thread = threading.Thread(target=self.proxy.serve_forever, name='ProxyThread') + proxy_thread.start() + self.warc_writer.start() + + if self.playback_proxy is not None: + playback_proxy_thread = threading.Thread(target=self.playback_proxy.serve_forever, name='PlaybackProxyThread') + playback_proxy_thread.start() + + self.stop = threading.Event() + + try: + signal.signal(signal.SIGTERM, self.stop.set) + self.logger.info('SIGTERM will initiate graceful shutdown') + except ValueError: + pass + + try: + while not self.stop.is_set(): + time.sleep(0.5) + except: + pass + finally: + self.warc_writer.stop.set() + self.proxy.shutdown() + self.proxy.server_close() + + if self.warc_writer.dedup_db is not None: + self.warc_writer.dedup_db.close() + + if self.playback_proxy is not None: + self.playback_proxy.shutdown() + self.playback_proxy.server_close() + if self.playback_proxy.playback_index_db is not None: + self.playback_proxy.playback_index_db.close() + + # wait for threads to finish + self.warc_writer.join() + proxy_thread.join() + if self.playback_proxy is not None: + playback_proxy_thread.join() -def _build_arg_parser(prog=sys.argv[0]): +def _build_arg_parser(prog=os.path.basename(sys.argv[0])): arg_parser = argparse.ArgumentParser(prog=prog, description='warcprox - WARC writing MITM HTTP/S proxy', formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -1013,7 +1084,7 @@ def _build_arg_parser(prog=sys.argv[0]): def main(argv=sys.argv): - arg_parser = _build_arg_parser(prog=argv[0]) + arg_parser = _build_arg_parser(prog=os.path.basename(argv[0])) args = arg_parser.parse_args(args=argv[1:]) if args.verbose: @@ -1024,7 +1095,7 @@ def main(argv=sys.argv): loglevel = logging.INFO logging.basicConfig(stream=sys.stdout, level=loglevel, - format='%(asctime)s %(process)d %(threadName)s %(levelname)s %(funcName)s(%(filename)s:%(lineno)d) %(message)s') + format='%(asctime)s %(process)d %(threadName)s %(levelname)s %(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s') try: hashlib.new(args.digest_algorithm) @@ -1052,8 +1123,6 @@ def main(argv=sys.argv): playback_proxy = PlaybackProxy(server_address=playback_server_address, ca=ca, playback_index_db=playback_index_db, warcs_dir=args.directory) - playback_proxy_thread = threading.Thread(target=playback_proxy.serve_forever, name='PlaybackProxyThread') - playback_proxy_thread.start() else: playback_index_db = None playback_proxy = None @@ -1066,12 +1135,11 @@ def main(argv=sys.argv): digest_algorithm=args.digest_algorithm, playback_index_db=playback_index_db) - proxy_thread = threading.Thread(target=proxy.serve_forever, name='ProxyThread') - proxy_thread.start() - warc_writer.start() - - run_until_shutdown(proxy, warc_writer, dedup_db, playback_proxy, playback_index_db) + # run_warcprox(proxy, warc_writer, playback_proxy) + warcprox = WarcproxController(proxy, warc_writer, playback_proxy) + warcprox.run_until_shutdown() if __name__ == '__main__': main() +