From 4ce89e6d038d42845b84c8456fe7d7941d633181 Mon Sep 17 00:00:00 2001 From: Noah Levitt Date: Thu, 30 Jul 2015 01:59:48 +0000 Subject: [PATCH] basic limits enforcement is working --- warcprox/__init__.py | 1 + warcprox/main.py | 14 ++++- warcprox/stats.py | 97 +++++++++++++++++++++++++++++++++ warcprox/tests/test_warcprox.py | 24 +++++--- warcprox/warcproxy.py | 39 +++++++++++-- warcprox/writerthread.py | 8 ++- 6 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 warcprox/stats.py diff --git a/warcprox/__init__.py b/warcprox/__init__.py index c3379c6..7235056 100644 --- a/warcprox/__init__.py +++ b/warcprox/__init__.py @@ -8,6 +8,7 @@ import warcprox.mitmproxy as mitmproxy import warcprox.writer as writer import warcprox.warc as warc import warcprox.writerthread as writerthread +import warcprox.stats as stats def digest_str(hash_obj, base32): import base64 diff --git a/warcprox/main.py b/warcprox/main.py index a98691d..58d6a77 100644 --- a/warcprox/main.py +++ b/warcprox/main.py @@ -57,6 +57,8 @@ def _build_arg_parser(prog=os.path.basename(sys.argv[0])): default=False, help='write digests in Base32 instead of hex') arg_parser.add_argument('-j', '--dedup-db-file', dest='dedup_db_file', default='./warcprox-dedup.db', help='persistent deduplication database file; empty string or /dev/null disables deduplication') + arg_parser.add_argument('--stats-db-file', dest='stats_db_file', + default='./warcprox-stats.db', help='persistent statistics database file; empty string or /dev/null disables deduplication') arg_parser.add_argument('-P', '--playback-port', dest='playback_port', default=None, help='port to listen on for instant playback') arg_parser.add_argument('--playback-index-db-file', dest='playback_index_db_file', @@ -112,6 +114,12 @@ def main(argv=sys.argv): else: dedup_db = warcprox.dedup.DedupDb(args.dedup_db_file) + if args.stats_db_file in (None, '', '/dev/null'): + logging.info('statistics tracking disabled') + stats_db = None + else: + stats_db = warcprox.stats.StatsDb(args.stats_db_file) + recorded_url_q = queue.Queue() ca_name = 'Warcprox CA on {}'.format(socket.gethostname())[:64] @@ -121,7 +129,8 @@ def main(argv=sys.argv): proxy = warcprox.warcproxy.WarcProxy( server_address=(args.address, int(args.port)), ca=ca, recorded_url_q=recorded_url_q, - digest_algorithm=args.digest_algorithm) + digest_algorithm=args.digest_algorithm, + stats_db=stats_db) if args.playback_port is not None: playback_index_db = warcprox.playback.PlaybackIndexDb(args.playback_index_db_file) @@ -141,7 +150,8 @@ def main(argv=sys.argv): writer_pool=warcprox.writer.WarcWriterPool(default_warc_writer) warc_writer_thread = warcprox.writerthread.WarcWriterThread( recorded_url_q=recorded_url_q, writer_pool=writer_pool, - dedup_db=dedup_db, playback_index_db=playback_index_db) + dedup_db=dedup_db, playback_index_db=playback_index_db, + stats_db=stats_db) controller = warcprox.controller.WarcproxController(proxy, warc_writer_thread, playback_proxy) diff --git a/warcprox/stats.py b/warcprox/stats.py new file mode 100644 index 0000000..6ad3ca4 --- /dev/null +++ b/warcprox/stats.py @@ -0,0 +1,97 @@ +# vim:set sw=4 et: + +from __future__ import absolute_import + +try: + import dbm.gnu as dbm_gnu +except ImportError: + try: + import gdbm as dbm_gnu + except ImportError: + import anydbm as dbm_gnu + +import logging +import os +import json +from hanzo import warctools + +class StatsDb: + logger = logging.getLogger("warcprox.stats.StatsDb") + + def __init__(self, dbm_file='./warcprox-stats.db'): + if os.path.exists(dbm_file): + self.logger.info('opening existing stats database {}'.format(dbm_file)) + else: + self.logger.info('creating new stats database {}'.format(dbm_file)) + + self.db = dbm_gnu.open(dbm_file, 'c') + + def close(self): + self.db.close() + + def sync(self): + try: + self.db.sync() + except: + pass + + def _empty_bucket(self): + return { + "total": { + "urls": 0, + "wire_bytes": 0, + # "warc_bytes": 0, + }, + "new": { + "urls": 0, + "wire_bytes": 0, + # "warc_bytes": 0, + }, + "revisit": { + "urls": 0, + "wire_bytes": 0, + # "warc_bytes": 0, + }, + } + + def value(self, bucket0="__all__", bucket1=None, bucket2=None): + if bucket0 in self.db: + bucket0_stats = json.loads(self.db[bucket0].decode("utf-8")) + if bucket1: + if bucket2: + return bucket0_stats[bucket1][bucket2] + else: + return bucket0_stats[bucket1] + else: + return bucket0_stats + else: + return None + + def tally(self, recorded_url, records): + buckets = ["__all__"] + + if (recorded_url.warcprox_meta + and "stats" in recorded_url.warcprox_meta + and "buckets" in recorded_url.warcprox_meta["stats"]): + buckets.extend(recorded_url.warcprox_meta["stats"]["buckets"]) + else: + buckets.append("__unspecified__") + + for bucket in buckets: + if bucket in self.db: + bucket_stats = json.loads(self.db[bucket].decode("utf-8")) + else: + bucket_stats = self._empty_bucket() + + bucket_stats["total"]["urls"] += 1 + bucket_stats["total"]["wire_bytes"] += recorded_url.size + + if records[0].get_header(warctools.WarcRecord.TYPE) == warctools.WarcRecord.REVISIT: + bucket_stats["revisit"]["urls"] += 1 + bucket_stats["revisit"]["wire_bytes"] += recorded_url.size + else: + bucket_stats["new"]["urls"] += 1 + bucket_stats["new"]["wire_bytes"] += recorded_url.size + + self.db[bucket] = json.dumps(bucket_stats, separators=(',',':')).encode("utf-8") + diff --git a/warcprox/tests/test_warcprox.py b/warcprox/tests/test_warcprox.py index 57ad613..33a01bb 100755 --- a/warcprox/tests/test_warcprox.py +++ b/warcprox/tests/test_warcprox.py @@ -138,8 +138,13 @@ def warcprox_(request): recorded_url_q = queue.Queue() + f = tempfile.NamedTemporaryFile(prefix='warcprox-test-stats-', suffix='.db', delete=False) + f.close() + stats_db_file = f.name + stats_db = warcprox.stats.StatsDb(stats_db_file) + proxy = warcprox.warcproxy.WarcProxy(server_address=('localhost', 0), ca=ca, - recorded_url_q=recorded_url_q) + recorded_url_q=recorded_url_q, stats_db=stats_db) warcs_dir = tempfile.mkdtemp(prefix='warcprox-test-warcs-') @@ -160,7 +165,8 @@ def warcprox_(request): writer_pool = warcprox.writer.WarcWriterPool(default_warc_writer) warc_writer_thread = warcprox.writerthread.WarcWriterThread( recorded_url_q=recorded_url_q, writer_pool=writer_pool, - dedup_db=dedup_db, playback_index_db=playback_index_db) + dedup_db=dedup_db, playback_index_db=playback_index_db, + stats_db=stats_db) warcprox_ = warcprox.controller.WarcproxController(proxy, warc_writer_thread, playback_proxy) logging.info('starting warcprox') @@ -172,7 +178,7 @@ def warcprox_(request): logging.info('stopping warcprox') warcprox_.stop.set() warcprox_thread.join() - for f in (ca_file, ca_dir, warcs_dir, playback_index_db_file, dedup_db_file): + for f in (ca_file, ca_dir, warcs_dir, playback_index_db_file, dedup_db_file, stats_db_file): if os.path.isdir(f): logging.info('deleting directory {}'.format(f)) shutil.rmtree(f) @@ -389,7 +395,7 @@ def test_dedup_https(https_daemon, warcprox_, archiving_proxies, playback_proxie def test_limits(http_daemon, archiving_proxies): url = 'http://localhost:{}/a/b'.format(http_daemon.server_port) - request_meta = {"stats":{"classifiers":["job1"]},"limits":{"job1.total.urls":10}} + request_meta = {"stats":{"buckets":["job1"],"limits":{"job1.total.urls":10}}} headers = {"Warcprox-Meta": json.dumps(request_meta)} for i in range(10): @@ -400,11 +406,11 @@ def test_limits(http_daemon, archiving_proxies): response = requests.get(url, proxies=archiving_proxies, headers=headers, stream=True) assert response.status_code == 420 - assert response.reason == "Limit Reached" - response_meta = {"stats":{"job1":{"total":{"urls":10},"new":{"urls":1},"revisit":{"urls":9}}}} - assert json.loads(headers["warcprox-meta"]) == response_meta - assert response.headers["content-type"] == "text/plain;charset=utf-8" - assert response.raw.data == b"request rejected by warcprox: reached limit job1.total.urls=10\n" + assert response.reason == "Limit reached" + # response_meta = {"stats":{"job1":{"total":{"urls":10},"new":{"urls":1},"revisit":{"urls":9}}}} + # assert json.loads(headers["warcprox-meta"]) == response_meta + # assert response.headers["content-type"] == "text/plain;charset=utf-8" + # assert response.raw.data == b"request rejected by warcprox: reached limit job1.total.urls=10\n" if __name__ == '__main__': pytest.main() diff --git a/warcprox/warcproxy.py b/warcprox/warcproxy.py index d81ba87..b2a7345 100644 --- a/warcprox/warcproxy.py +++ b/warcprox/warcproxy.py @@ -153,13 +153,38 @@ class ProxyingRecordingHTTPResponse(http_client.HTTPResponse): class WarcProxyHandler(warcprox.mitmproxy.MitmProxyHandler): + # self.server is WarcProxy logger = logging.getLogger("warcprox.warcprox.WarcProxyHandler") + def _enforce_limits(self, warcprox_meta): + self.logger.info("warcprox_meta=%s", warcprox_meta) + if (warcprox_meta and "stats" in warcprox_meta + and "limits" in warcprox_meta["stats"]): + self.logger.info("warcprox_meta['stats']['limits']=%s", warcprox_meta['stats']['limits']) + for item in warcprox_meta["stats"]["limits"].items(): + self.logger.info("item=%s", item) + key, limit = item + self.logger.info("limit %s=%d", key, limit) + bucket0, bucket1, bucket2 = key.rsplit(".", 2) + self.logger.info("%s::%s::%s", bucket0, bucket1, bucket2) + value = self.server.stats_db.value(bucket0, bucket1, bucket2) + self.logger.info("stats value is %s", value) + if value and value >= limit: + self.send_error(420, "Limit reached") + self.connection.close() + return + def _proxy_request(self): # Build request req_str = '{} {} {}\r\n'.format(self.command, self.path, self.request_version) - warcprox_meta = self.headers.get('Warcprox-Meta') + warcprox_meta = None + raw_warcprox_meta = self.headers.get('Warcprox-Meta') + if raw_warcprox_meta: + warcprox_meta = json.loads(raw_warcprox_meta) + + if self._enforce_limits(warcprox_meta): + return # Swallow headers that don't make sense to forward on, i.e. most # hop-by-hop headers, see http://tools.ietf.org/html/rfc2616#section-13.5 @@ -241,7 +266,10 @@ class WarcProxyHandler(warcprox.mitmproxy.MitmProxyHandler): # stream this? request_data = self.rfile.read(int(self.headers['Content-Length'])) - warcprox_meta = self.headers.get('Warcprox-Meta') + warcprox_meta = None + raw_warcprox_meta = self.headers.get('Warcprox-Meta') + if raw_warcprox_meta: + warcprox_meta = json.loads(raw_warcprox_meta) rec_custom = RecordedUrl(url=self.url, request_data=request_data, @@ -295,7 +323,7 @@ class RecordedUrl: self.response_recorder = response_recorder if warcprox_meta: - self.warcprox_meta = json.loads(warcprox_meta) + self.warcprox_meta = warcprox_meta else: self.warcprox_meta = {} @@ -319,7 +347,8 @@ class WarcProxy(socketserver.ThreadingMixIn, http_server.HTTPServer): def __init__(self, server_address=('localhost', 8000), req_handler_class=WarcProxyHandler, bind_and_activate=True, - ca=None, recorded_url_q=None, digest_algorithm='sha1'): + ca=None, recorded_url_q=None, digest_algorithm='sha1', + stats_db=None): http_server.HTTPServer.__init__(self, server_address, req_handler_class, bind_and_activate) self.digest_algorithm = digest_algorithm @@ -337,6 +366,8 @@ class WarcProxy(socketserver.ThreadingMixIn, http_server.HTTPServer): else: self.recorded_url_q = queue.Queue() + self.stats_db = stats_db + def server_activate(self): http_server.HTTPServer.server_activate(self) self.logger.info('WarcProxy listening on {0}:{1}'.format(self.server_address[0], self.server_address[1])) diff --git a/warcprox/writerthread.py b/warcprox/writerthread.py index ceb34cd..68c5676 100644 --- a/warcprox/writerthread.py +++ b/warcprox/writerthread.py @@ -22,7 +22,7 @@ import warcprox class WarcWriterThread(threading.Thread): logger = logging.getLogger("warcprox.warcproxwriter.WarcWriterThread") - def __init__(self, recorded_url_q=None, writer_pool=None, dedup_db=None, playback_index_db=None): + def __init__(self, recorded_url_q=None, writer_pool=None, dedup_db=None, playback_index_db=None, stats_db=None): """recorded_url_q is a queue.Queue of warcprox.warcprox.RecordedUrl.""" threading.Thread.__init__(self, name='WarcWriterThread') self.recorded_url_q = recorded_url_q @@ -33,6 +33,7 @@ class WarcWriterThread(threading.Thread): self.writer_pool = WarcWriterPool() self.dedup_db = dedup_db self.playback_index_db = playback_index_db + self.stats_db = stats_db self._last_sync = time.time() def run(self): @@ -106,7 +107,12 @@ class WarcWriterThread(threading.Thread): _decode(records[0].warc_filename), records[0].offset)) + def _update_stats(self, recorded_url, records): + if self.stats_db: + self.stats_db.tally(recorded_url, records) + def _final_tasks(self, recorded_url, records): self._save_dedup_info(recorded_url, records) self._save_playback_info(recorded_url, records) + self._update_stats(recorded_url, records) self._log(recorded_url, records)