diff --git a/warcprox/dedup.py b/warcprox/dedup.py
index cd3b397..cb65408 100644
--- a/warcprox/dedup.py
+++ b/warcprox/dedup.py
@@ -33,6 +33,7 @@ import datetime
 import urllib3
 from urllib3.exceptions import HTTPError
 import collections
+from concurrent import futures
 
 urllib3.disable_warnings()
 
@@ -289,8 +290,26 @@ class BatchTroughStorer(warcprox.BaseBatchPostfetchProcessor):
 
     def _process_batch(self, batch):
         buckets = self._filter_and_bucketize(batch)
-        for bucket in buckets:
-            self.trough_dedup_db.batch_save(buckets[bucket], bucket)
+        if not buckets:
+            return
+        fs = {}
+        with futures.ThreadPoolExecutor(max_workers=len(buckets)) as pool:
+            # send off requests in parallel
+            for bucket in buckets:
+                future = pool.submit(
+                        self.trough_dedup_db.batch_save,
+                        buckets[bucket], bucket)
+                fs[future] = bucket
+
+            # wait for results
+            try:
+                for future in futures.as_completed(fs, timeout=20):
+                    pass
+            except futures.TimeoutError as e:
+                # the remaining threads actually keep running in this case,
+                # there's no way to stop them, but that should be harmless
+                logging.warn(
+                        'timed out saving dedup info to trough', exc_info=True)
 
 class BatchTroughLoader(warcprox.BaseBatchPostfetchProcessor):
     def __init__(self, trough_dedup_db, options=warcprox.Options()):
@@ -320,7 +339,13 @@
 
     def _build_key_index(self, batch):
         '''
-        Returns `{digest_key: [recorded_url, ...]}`.
+        Builds index of RecordedUrl by digest key.
+
+        Args:
+            batch(list): list of RecordedUrl
+
+        Returns:
+            dict `{digest_key: [recorded_url, ...]}`
         '''
         key_index = collections.defaultdict(list)
         for recorded_url in batch:
@@ -331,13 +356,37 @@
 
     def _process_batch(self, batch):
         buckets = self._filter_and_bucketize(batch)
-        for bucket in buckets:
-            key_index = self._build_key_index(buckets[bucket])
-            results = self.trough_dedup_db.batch_lookup(
-                    key_index.keys(), bucket)
-            for result in results:
-                for recorded_url in key_index[result['digest_key']]:
-                    recorded_url.dedup_info = result
+        if not buckets:
+            return
+        fs = {}
+        with futures.ThreadPoolExecutor(max_workers=len(buckets)) as pool:
+            # send off the trough requests in parallel
+            for bucket in buckets:
+                key_index = self._build_key_index(buckets[bucket])
+                future = pool.submit(
+                        self.trough_dedup_db.batch_lookup,
+                        key_index.keys(), bucket)
+                fs[future] = bucket
+
+            # process results as they come back
+            try:
+                for future in futures.as_completed(fs, timeout=20):
+                    bucket = fs[future]
+                    try:
+                        for entry in future.result():
+                            for recorded_url in key_index[entry['digest_key']]:
+                                recorded_url.dedup_info = entry
+                    except Exception as e:
+                        # batch_lookup raised exception or something
+                        logging.warn(
+                                'problem looking up dedup info for %s urls '
+                                'in bucket %s', len(buckets[bucket]), bucket,
+                                exc_info=True)
+            except futures.TimeoutError as e:
+                # the remaining threads actually keep running in this case,
+                # there's no way to stop them, but that should be harmless
+                logging.warn(
+                        'timed out loading dedup info from trough', exc_info=True)
 
 class TroughDedupDb(DedupDb):
     '''
@@ -409,6 +458,7 @@
         return None
 
     def batch_lookup(self, digest_keys, bucket='__unspecified__'):
+        '''Returns [{'digest_key': ..., 'url': ..., 'date': ...}, ...]'''
         sql_tmpl = 'select * from dedup where digest_key in (%s)' % (
                 ','.join('%s' for i in range(len(digest_keys))))
         results = self._trough_cli.read(bucket, sql_tmpl, digest_keys)
@@ -419,7 +469,7 @@ class TroughDedupDb(DedupDb):
                 len(digest_keys), len(results))
         assert len(results) >= 0 and len(results) <= len(digest_keys)
         for result in results:
-            result['id'] = result['id'].encode('ascii')
+            result['id'] = result.get('id') and result['id'].encode('ascii')
             result['url'] = result['url'].encode('ascii')
             result['date'] = result['date'].encode('ascii')
             result['digest_key'] = result['digest_key'].encode('ascii')
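For context (not part of the patch): below is a minimal, self-contained sketch of the fan-out pattern both `_process_batch` implementations adopt above — one worker thread per bucket, results harvested with `futures.as_completed`, and a bounded wait so one slow trough segment cannot stall the whole batch. The `fetch_bucket` helper and the sample buckets are invented for illustration only; they are not warcprox or trough APIs.

```python
# Sketch of the per-bucket fan-out used in the patch (illustrative only).
import logging
import time
from concurrent import futures

def fetch_bucket(bucket, keys):
    # stand-in for a slow per-bucket RPC such as a dedup lookup
    time.sleep(0.1)
    return [{'digest_key': k, 'bucket': bucket} for k in keys]

def process_batch(buckets, timeout=20):
    if not buckets:
        return {}
    results = {}
    fs = {}
    with futures.ThreadPoolExecutor(max_workers=len(buckets)) as pool:
        # send off one request per bucket in parallel
        for bucket, keys in buckets.items():
            fs[pool.submit(fetch_bucket, bucket, keys)] = bucket
        try:
            # harvest results as they complete, up to the deadline
            for future in futures.as_completed(fs, timeout=timeout):
                bucket = fs[future]
                try:
                    results[bucket] = future.result()
                except Exception:
                    logging.warning(
                            'lookup failed for bucket %s', bucket,
                            exc_info=True)
        except futures.TimeoutError:
            # as in the patch: workers past the deadline keep running,
            # their results are simply not used
            logging.warning('timed out waiting for bucket lookups')
    return results

if __name__ == '__main__':
    print(process_batch({'a': ['k1', 'k2'], 'b': ['k3']}))
```

As in the patch, a `TimeoutError` only abandons the waiting loop; the threads that missed the deadline are left to finish on their own, which is treated as harmless.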