From 6ff9030e6728a699030c613840e035a1163bbcdd Mon Sep 17 00:00:00 2001
From: Noah Levitt
Date: Tue, 16 Jan 2018 15:18:53 -0800
Subject: [PATCH] improve batching, make tests pass

---
 tests/test_warcprox.py |  7 ++++++-
 warcprox/__init__.py   | 27 +++++++++++++++++++--------
 warcprox/dedup.py      | 14 ++++++++------
 3 files changed, 33 insertions(+), 15 deletions(-)

diff --git a/tests/test_warcprox.py b/tests/test_warcprox.py
index 36b7cea..95532d4 100755
--- a/tests/test_warcprox.py
+++ b/tests/test_warcprox.py
@@ -339,6 +339,9 @@ def warcprox_(request):
     logging.info('changing to working directory %r', work_dir)
     os.chdir(work_dir)
 
+    # we can't wait around all day in the tests
+    warcprox.BaseBatchPostfetchProcessor.MAX_BATCH_SEC = 0.5
+
     argv = ['warcprox',
             '--method-filter=GET',
             '--method-filter=POST',
@@ -362,6 +365,7 @@ def warcprox_(request):
     warcprox_ = warcprox.controller.WarcproxController(options)
     logging.info('starting warcprox')
+    warcprox_.start()
     warcprox_thread = threading.Thread(
             name='WarcproxThread', target=warcprox_.run_until_shutdown)
     warcprox_thread.start()
@@ -1372,7 +1376,7 @@ def test_controller_with_defaults():
     wwt = controller.warc_writer_thread
     assert wwt
     assert wwt.inq
-    assert not wwt.outq
+    assert wwt.outq
     assert wwt.writer_pool
     assert wwt.writer_pool.default_warc_writer
     assert wwt.writer_pool.default_warc_writer.directory == './warcs'
@@ -1396,6 +1400,7 @@ def test_choose_a_port_for_me(warcprox_):
             '127.0.0.1', controller.proxy.server_port)
     th = threading.Thread(target=controller.run_until_shutdown)
+    controller.start()
     th.start()
     try:
         # check that the status api lists the correct port
diff --git a/warcprox/__init__.py b/warcprox/__init__.py
index ed934b4..60ca2ef 100644
--- a/warcprox/__init__.py
+++ b/warcprox/__init__.py
@@ -167,16 +167,27 @@ class BaseStandardPostfetchProcessor(BasePostfetchProcessor):
 
 class BaseBatchPostfetchProcessor(BasePostfetchProcessor):
     MAX_BATCH_SIZE = 500
+    MAX_BATCH_SEC = 10
 
     def _get_process_put(self):
         batch = []
-        batch.append(self.inq.get(block=True, timeout=0.5))
-        try:
-            while len(batch) < self.MAX_BATCH_SIZE:
-                batch.append(self.inq.get(block=False))
-        except queue.Empty:
-            pass
+        start = time.time()
+        while (len(batch) < self.MAX_BATCH_SIZE
+                and time.time() - start < self.MAX_BATCH_SEC):
+            try:
+                batch.append(self.inq.get(block=True, timeout=0.5))
+            except queue.Empty:
+                if self.stop.is_set():
+                    break
+                # else keep adding to the batch
+
+        if not batch:
+            raise queue.Empty
+
+        self.logger.info(
+                'gathered batch of %s in %0.1f sec',
+                len(batch), time.time() - start)
 
         self._process_batch(batch)
 
         if self.outq:
@@ -187,8 +198,8 @@ class BaseBatchPostfetchProcessor(BasePostfetchProcessor):
         raise Exception('not implemented')
 
 class ListenerPostfetchProcessor(BaseStandardPostfetchProcessor):
-    def __init__(self, listener, inq, outq, profile=False):
-        BaseStandardPostfetchProcessor.__init__(self, inq, outq, profile)
+    def __init__(self, listener, inq, outq, options=Options()):
+        BaseStandardPostfetchProcessor.__init__(self, inq, outq, options)
         self.listener = listener
         self.name = listener.__class__.__name__
diff --git a/warcprox/dedup.py b/warcprox/dedup.py
index 962ec3a..b9e136e 100644
--- a/warcprox/dedup.py
+++ b/warcprox/dedup.py
@@ -39,7 +39,7 @@ urllib3.disable_warnings()
 class DedupLoader(warcprox.BaseStandardPostfetchProcessor):
     def __init__(self, dedup_db, inq, outq, options=warcprox.Options()):
         warcprox.BaseStandardPostfetchProcessor.__init__(
-                self, inq, outq, profile)
+                self, inq, outq, options)
         self.dedup_db = dedup_db
     def _process_url(self, recorded_url):
         decorate_with_dedup_info(
@@ -71,11 +71,12 @@ class DedupDb(object):
             conn.commit()
         conn.close()
 
-    def loader(self, inq, outq, profile=False):
-        return DedupLoader(self, inq, outq, self.options.base32, profile)
+    def loader(self, inq, outq, *args, **kwargs):
+        return DedupLoader(self, inq, outq, self.options)
 
-    def storer(self, inq, outq, profile=False):
-        return warcprox.ListenerPostfetchProcessor(self, inq, outq, profile)
+    def storer(self, inq, outq, *args, **kwargs):
+        return warcprox.ListenerPostfetchProcessor(
+                self, inq, outq, self.options)
 
     def save(self, digest_key, response_record, bucket=""):
         record_id = response_record.get_header(warctools.WarcRecord.ID).decode('latin1')
@@ -297,7 +298,8 @@ class BatchTroughLoader(warcprox.BaseBatchPostfetchProcessor):
         buckets = self._filter_and_bucketize(batch)
         for bucket in buckets:
             key_index = self._build_key_index(buckets[bucket])
-            results = self.trough_dedup_db.batch_lookup(key_index.keys(), bucket)
+            results = self.trough_dedup_db.batch_lookup(
+                    key_index.keys(), bucket)
             for result in results:
                 for recorded_url in key_index[result['digest_key']]:
                     recorded_url.dedup_info = result