diff --git a/tests/test_warcprox.py b/tests/test_warcprox.py index 5587d8f..e8c140b 100755 --- a/tests/test_warcprox.py +++ b/tests/test_warcprox.py @@ -1721,6 +1721,24 @@ def test_payload_digest(warcprox_, http_daemon): req, prox_rec_res = mitm.do_GET() assert warcprox.digest_str(prox_rec_res.payload_digest) == GZIP_GZIP_SHA1 +def test_trough_segment_promotion(warcprox_): + if not warcprox_.options.rethinkdb_trough_db_url: + return + cli = warcprox.trough.TroughClient( + warcprox_.options.rethinkdb_trough_db_url, 3) + promoted = [] + def mock(segment_id): + promoted.append(segment_id) + cli.promote = mock + cli.register_schema('default', 'create table foo (bar varchar(100))') + cli.write('my_seg', 'insert into foo (bar) values ("boof")') + assert promoted == [] + time.sleep(3) + assert promoted == ['my_seg'] + promoted = [] + time.sleep(3) + assert promoted == [] + if __name__ == '__main__': pytest.main() diff --git a/warcprox/dedup.py b/warcprox/dedup.py index 2364d41..d1e456d 100644 --- a/warcprox/dedup.py +++ b/warcprox/dedup.py @@ -262,7 +262,7 @@ class TroughDedupDb(object): def __init__(self, options=warcprox.Options()): self.options = options self._trough_cli = warcprox.trough.TroughClient( - options.rethinkdb_trough_db_url) + options.rethinkdb_trough_db_url, promotion_interval=60*60) def start(self): self._trough_cli.register_schema(self.SCHEMA_ID, self.SCHEMA_SQL) diff --git a/warcprox/trough.py b/warcprox/trough.py index ec3a032..6cbe1dd 100644 --- a/warcprox/trough.py +++ b/warcprox/trough.py @@ -1,7 +1,7 @@ ''' warcprox/trough.py - trough client code -Copyright (C) 2013-2017 Internet Archive +Copyright (C) 2017 Internet Archive This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License @@ -28,17 +28,69 @@ import requests import doublethink import rethinkdb as r import datetime +import threading +import time class TroughClient(object): logger = logging.getLogger("warcprox.trough.TroughClient") - def __init__(self, rethinkdb_trough_db_url): + def __init__(self, rethinkdb_trough_db_url, promotion_interval=None): + ''' + TroughClient constructor + + Args: + rethinkdb_trough_db_url: url with schema rethinkdb:// pointing to + trough configuration database + promotion_interval: if specified, `TroughClient` will spawn a + thread that "promotes" (pushed to hdfs) "dirty" trough segments + (segments that have received writes) periodically, sleeping for + `promotion_interval` seconds between cycles (default None) + ''' parsed = doublethink.parse_rethinkdb_url(rethinkdb_trough_db_url) self.rr = doublethink.Rethinker( servers=parsed.hosts, db=parsed.database) self.svcreg = doublethink.ServiceRegistry(self.rr) self._write_url_cache = {} self._read_url_cache = {} + self._dirty_segments = set() + self._dirty_segments_lock = threading.RLock() + + self.promotion_interval = promotion_interval + self._promoter_thread = None + if promotion_interval: + self._promoter_thread = threading.Thread( + target=self._promotrix, name='TroughClient-promoter', + daemon=True) + self._promoter_thread.start() + + def _promotrix(self): + while True: + time.sleep(self.promotion_interval) + try: + with self._dirty_segments_lock: + dirty_segments = list(self._dirty_segments) + self._dirty_segments.clear() + logging.info('promoting %s trough segments') + for segment in dirty_segments: + try: + self.promote(segment) + except: + logging.error( + 'problem promoting segment %s', exc_info=True) + except: + logging.error( + 'caught exception doing segment promotion', + exc_info=True) + + def promote(self, segment_id): + url = os.path.join(self.segment_manager_url(), 'promote') + payload_dict = {'segment': segment_id} + response = requests.post(url, json=payload_dict) + if response.status_code != 200: + raise Exception( + 'Received %s: %r in response to POST %s with data %s' % ( + response.status_code, response.text, url, + json.dumps(payload_dict))) @staticmethod def sql_value(x): @@ -116,12 +168,15 @@ class TroughClient(object): self._read_url_cache[segment_id]) return self._read_url_cache[segment_id] - def write(self, segment_id, sql_tmpl, values, schema_id='default'): + def write(self, segment_id, sql_tmpl, values=(), schema_id='default'): write_url = self.write_url(segment_id, schema_id) sql = sql_tmpl % tuple(self.sql_value(v) for v in values) try: response = requests.post(write_url, sql) + if segment_id not in self._dirty_segments: + with self._dirty_segments_lock: + self._dirty_segments.add(segment_id) except: del self._write_url_cache[segment_id] self.logger.error( @@ -137,7 +192,7 @@ class TroughClient(object): return self.logger.debug('posted %r to %s', sql, write_url) - def read(self, segment_id, sql_tmpl, values): + def read(self, segment_id, sql_tmpl, values=()): read_url = self.read_url(segment_id) if not read_url: return None @@ -173,7 +228,8 @@ class TroughClient(object): response.raise_for_status() def register_schema(self, schema_id, sql): - url = '%s/schema/%s/sql' % (self.segment_manager_url(), schema_id) + url = os.path.join( + self.segment_manager_url(), 'schema', schema_id, 'sql') response = requests.put(url, sql) if response.status_code not in (201, 204): raise Exception(