improve batching, make tests pass

This commit is contained in:
Noah Levitt 2018-01-16 15:18:53 -08:00
parent d4bbaf10b7
commit 6ff9030e67
3 changed files with 33 additions and 15 deletions

View File

@ -339,6 +339,9 @@ def warcprox_(request):
logging.info('changing to working directory %r', work_dir) logging.info('changing to working directory %r', work_dir)
os.chdir(work_dir) os.chdir(work_dir)
# we can't wait around all day in the tests
warcprox.BaseBatchPostfetchProcessor.MAX_BATCH_SEC = 0.5
argv = ['warcprox', argv = ['warcprox',
'--method-filter=GET', '--method-filter=GET',
'--method-filter=POST', '--method-filter=POST',
@ -362,6 +365,7 @@ def warcprox_(request):
warcprox_ = warcprox.controller.WarcproxController(options) warcprox_ = warcprox.controller.WarcproxController(options)
logging.info('starting warcprox') logging.info('starting warcprox')
warcprox_.start()
warcprox_thread = threading.Thread( warcprox_thread = threading.Thread(
name='WarcproxThread', target=warcprox_.run_until_shutdown) name='WarcproxThread', target=warcprox_.run_until_shutdown)
warcprox_thread.start() warcprox_thread.start()
@ -1372,7 +1376,7 @@ def test_controller_with_defaults():
wwt = controller.warc_writer_thread wwt = controller.warc_writer_thread
assert wwt assert wwt
assert wwt.inq assert wwt.inq
assert not wwt.outq assert wwt.outq
assert wwt.writer_pool assert wwt.writer_pool
assert wwt.writer_pool.default_warc_writer assert wwt.writer_pool.default_warc_writer
assert wwt.writer_pool.default_warc_writer.directory == './warcs' assert wwt.writer_pool.default_warc_writer.directory == './warcs'
@ -1396,6 +1400,7 @@ def test_choose_a_port_for_me(warcprox_):
'127.0.0.1', controller.proxy.server_port) '127.0.0.1', controller.proxy.server_port)
th = threading.Thread(target=controller.run_until_shutdown) th = threading.Thread(target=controller.run_until_shutdown)
controller.start()
th.start() th.start()
try: try:
# check that the status api lists the correct port # check that the status api lists the correct port

View File

@ -167,16 +167,27 @@ class BaseStandardPostfetchProcessor(BasePostfetchProcessor):
class BaseBatchPostfetchProcessor(BasePostfetchProcessor): class BaseBatchPostfetchProcessor(BasePostfetchProcessor):
MAX_BATCH_SIZE = 500 MAX_BATCH_SIZE = 500
MAX_BATCH_SEC = 10
def _get_process_put(self): def _get_process_put(self):
batch = [] batch = []
batch.append(self.inq.get(block=True, timeout=0.5)) start = time.time()
try:
while len(batch) < self.MAX_BATCH_SIZE:
batch.append(self.inq.get(block=False))
except queue.Empty:
pass
while (len(batch) < self.MAX_BATCH_SIZE
and time.time() - start < self.MAX_BATCH_SEC):
try:
batch.append(self.inq.get(block=True, timeout=0.5))
except queue.Empty:
if self.stop.is_set():
break
# else keep adding to the batch
if not batch:
raise queue.Empty
self.logger.info(
'gathered batch of %s in %0.1f sec',
len(batch), time.time() - start)
self._process_batch(batch) self._process_batch(batch)
if self.outq: if self.outq:
@ -187,8 +198,8 @@ class BaseBatchPostfetchProcessor(BasePostfetchProcessor):
raise Exception('not implemented') raise Exception('not implemented')
class ListenerPostfetchProcessor(BaseStandardPostfetchProcessor): class ListenerPostfetchProcessor(BaseStandardPostfetchProcessor):
def __init__(self, listener, inq, outq, profile=False): def __init__(self, listener, inq, outq, options=Options()):
BaseStandardPostfetchProcessor.__init__(self, inq, outq, profile) BaseStandardPostfetchProcessor.__init__(self, inq, outq, options)
self.listener = listener self.listener = listener
self.name = listener.__class__.__name__ self.name = listener.__class__.__name__

View File

@ -39,7 +39,7 @@ urllib3.disable_warnings()
class DedupLoader(warcprox.BaseStandardPostfetchProcessor): class DedupLoader(warcprox.BaseStandardPostfetchProcessor):
def __init__(self, dedup_db, inq, outq, options=warcprox.Options()): def __init__(self, dedup_db, inq, outq, options=warcprox.Options()):
warcprox.BaseStandardPostfetchProcessor.__init__( warcprox.BaseStandardPostfetchProcessor.__init__(
self, inq, outq, profile) self, inq, outq, options)
self.dedup_db = dedup_db self.dedup_db = dedup_db
def _process_url(self, recorded_url): def _process_url(self, recorded_url):
decorate_with_dedup_info( decorate_with_dedup_info(
@ -71,11 +71,12 @@ class DedupDb(object):
conn.commit() conn.commit()
conn.close() conn.close()
def loader(self, inq, outq, profile=False): def loader(self, inq, outq, *args, **kwargs):
return DedupLoader(self, inq, outq, self.options.base32, profile) return DedupLoader(self, inq, outq, self.options)
def storer(self, inq, outq, profile=False): def storer(self, inq, outq, *args, **kwargs):
return warcprox.ListenerPostfetchProcessor(self, inq, outq, profile) return warcprox.ListenerPostfetchProcessor(
self, inq, outq, self.options)
def save(self, digest_key, response_record, bucket=""): def save(self, digest_key, response_record, bucket=""):
record_id = response_record.get_header(warctools.WarcRecord.ID).decode('latin1') record_id = response_record.get_header(warctools.WarcRecord.ID).decode('latin1')
@ -297,7 +298,8 @@ class BatchTroughLoader(warcprox.BaseBatchPostfetchProcessor):
buckets = self._filter_and_bucketize(batch) buckets = self._filter_and_bucketize(batch)
for bucket in buckets: for bucket in buckets:
key_index = self._build_key_index(buckets[bucket]) key_index = self._build_key_index(buckets[bucket])
results = self.trough_dedup_db.batch_lookup(key_index.keys(), bucket) results = self.trough_dedup_db.batch_lookup(
key_index.keys(), bucket)
for result in results: for result in results:
for recorded_url in key_index[result['digest_key']]: for recorded_url in key_index[result['digest_key']]:
recorded_url.dedup_info = result recorded_url.dedup_info = result