From fe19bb268ff9cc03f6059fab7a09c3bd67ed2d8a Mon Sep 17 00:00:00 2001 From: Noah Levitt Date: Tue, 19 Nov 2019 11:45:14 -0800 Subject: [PATCH] use trough.client instead of warcprox.trough less redundant code! trough.client was based off of warcprox.trough but has been improved since then --- setup.py | 1 + tests/test_warcprox.py | 18 --- warcprox/dedup.py | 41 +++++-- warcprox/trough.py | 246 ----------------------------------------- 4 files changed, 32 insertions(+), 274 deletions(-) delete mode 100644 warcprox/trough.py diff --git a/setup.py b/setup.py index 63d8488..7c7185f 100755 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ deps = [ 'idna>=2.5', 'PyYAML>=5.1', 'cachetools', + 'trough>=0.1.2', ] try: import concurrent.futures diff --git a/tests/test_warcprox.py b/tests/test_warcprox.py index d95f51c..447d2b8 100755 --- a/tests/test_warcprox.py +++ b/tests/test_warcprox.py @@ -2132,24 +2132,6 @@ def test_payload_digest(warcprox_, http_daemon): req, prox_rec_res = mitm.do_GET() assert warcprox.digest_str(prox_rec_res.payload_digest) == GZIP_GZIP_SHA1 -def test_trough_segment_promotion(warcprox_): - if not warcprox_.options.rethinkdb_trough_db_url: - return - cli = warcprox.trough.TroughClient( - warcprox_.options.rethinkdb_trough_db_url, 3) - promoted = [] - def mock(segment_id): - promoted.append(segment_id) - cli.promote = mock - cli.register_schema('default', 'create table foo (bar varchar(100))') - cli.write('my_seg', 'insert into foo (bar) values ("boof")') - assert promoted == [] - time.sleep(3) - assert promoted == ['my_seg'] - promoted = [] - time.sleep(3) - assert promoted == [] - def test_dedup_min_text_size(http_daemon, warcprox_, archiving_proxies): """We use options --dedup-min-text-size=3 --dedup-min-binary-size=5 and we try to download content smaller than these limits to make sure that it is diff --git a/warcprox/dedup.py b/warcprox/dedup.py index 0e09239..17b332b 100644 --- a/warcprox/dedup.py +++ b/warcprox/dedup.py @@ -26,7 +26,7 @@ import os import json from hanzo import warctools import warcprox -import warcprox.trough +import trough.client import sqlite3 import doublethink import datetime @@ -509,7 +509,7 @@ class TroughDedupDb(DedupDb, DedupableMixin): def __init__(self, options=warcprox.Options()): DedupableMixin.__init__(self, options) self.options = options - self._trough_cli = warcprox.trough.TroughClient( + self._trough_cli = trough.client.TroughClient( options.rethinkdb_trough_db_url, promotion_interval=60*60) def loader(self, *args, **kwargs): @@ -531,9 +531,13 @@ class TroughDedupDb(DedupDb, DedupableMixin): record_id = response_record.get_header(warctools.WarcRecord.ID) url = response_record.get_header(warctools.WarcRecord.URL) warc_date = response_record.get_header(warctools.WarcRecord.DATE) - self._trough_cli.write( - bucket, self.WRITE_SQL_TMPL, - (digest_key, url, warc_date, record_id), self.SCHEMA_ID) + try: + self._trough_cli.write( + bucket, self.WRITE_SQL_TMPL, + (digest_key, url, warc_date, record_id), self.SCHEMA_ID) + except: + self.logger.warning( + 'problem posting dedup data to trough', exc_info=True) def batch_save(self, batch, bucket='__unspecified__'): sql_tmpl = ('insert or ignore into dedup\n' @@ -548,12 +552,22 @@ class TroughDedupDb(DedupDb, DedupableMixin): recorded_url.url, recorded_url.warc_records[0].date, recorded_url.warc_records[0].id,]) - self._trough_cli.write(bucket, sql_tmpl, values, self.SCHEMA_ID) + try: + self._trough_cli.write(bucket, sql_tmpl, values, self.SCHEMA_ID) + except: + self.logger.warning( + 'problem posting dedup data to trough', exc_info=True) def lookup(self, digest_key, bucket='__unspecified__', url=None): - results = self._trough_cli.read( - bucket, 'select * from dedup where digest_key=%s;', - (digest_key,)) + try: + results = self._trough_cli.read( + bucket, 'select * from dedup where digest_key=%s;', + (digest_key,)) + except: + self.logger.warning( + 'problem reading dedup data from trough', exc_info=True) + return None + if results: assert len(results) == 1 # sanity check (digest_key is primary key) result = results[0] @@ -570,7 +584,14 @@ class TroughDedupDb(DedupDb, DedupableMixin): '''Returns [{'digest_key': ..., 'url': ..., 'date': ...}, ...]''' sql_tmpl = 'select * from dedup where digest_key in (%s)' % ( ','.join('%s' for i in range(len(digest_keys)))) - results = self._trough_cli.read(bucket, sql_tmpl, digest_keys) + + try: + results = self._trough_cli.read(bucket, sql_tmpl, digest_keys) + except: + self.logger.warning( + 'problem reading dedup data from trough', exc_info=True) + results = None + if results is None: return [] self.logger.debug( diff --git a/warcprox/trough.py b/warcprox/trough.py deleted file mode 100644 index d0839d1..0000000 --- a/warcprox/trough.py +++ /dev/null @@ -1,246 +0,0 @@ -''' -warcprox/trough.py - trough client code - -Copyright (C) 2017 Internet Archive - -This program is free software; you can redistribute it and/or -modify it under the terms of the GNU General Public License -as published by the Free Software Foundation; either version 2 -of the License, or (at your option) any later version. - -This program is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU General Public License for more details. - -You should have received a copy of the GNU General Public License -along with this program; if not, write to the Free Software -Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, -USA. -''' - -from __future__ import absolute_import - -import logging -import os -import json -import requests -import doublethink -import rethinkdb as r -import datetime -import threading -import time - -class TroughClient(object): - logger = logging.getLogger("warcprox.trough.TroughClient") - - def __init__(self, rethinkdb_trough_db_url, promotion_interval=None): - ''' - TroughClient constructor - - Args: - rethinkdb_trough_db_url: url with schema rethinkdb:// pointing to - trough configuration database - promotion_interval: if specified, `TroughClient` will spawn a - thread that "promotes" (pushed to hdfs) "dirty" trough segments - (segments that have received writes) periodically, sleeping for - `promotion_interval` seconds between cycles (default None) - ''' - parsed = doublethink.parse_rethinkdb_url(rethinkdb_trough_db_url) - self.rr = doublethink.Rethinker( - servers=parsed.hosts, db=parsed.database) - self.svcreg = doublethink.ServiceRegistry(self.rr) - self._write_url_cache = {} - self._read_url_cache = {} - self._dirty_segments = set() - self._dirty_segments_lock = threading.RLock() - - self.promotion_interval = promotion_interval - self._promoter_thread = None - if promotion_interval: - self._promoter_thread = threading.Thread( - target=self._promotrix, name='TroughClient-promoter') - self._promoter_thread.setDaemon(True) - self._promoter_thread.start() - - def _promotrix(self): - while True: - time.sleep(self.promotion_interval) - try: - with self._dirty_segments_lock: - dirty_segments = list(self._dirty_segments) - self._dirty_segments.clear() - logging.info( - 'promoting %s trough segments', len(dirty_segments)) - for segment_id in dirty_segments: - try: - self.promote(segment_id) - except: - logging.error( - 'problem promoting segment %s', segment_id, - exc_info=True) - except: - logging.error( - 'caught exception doing segment promotion', - exc_info=True) - - def promote(self, segment_id): - url = os.path.join(self.segment_manager_url(), 'promote') - payload_dict = {'segment': segment_id} - response = requests.post(url, json=payload_dict, timeout=21600) - if response.status_code != 200: - raise Exception( - 'Received %s: %r in response to POST %s with data %s' % ( - response.status_code, response.text, url, - json.dumps(payload_dict))) - - @staticmethod - def sql_value(x): - if x is None: - return 'null' - elif isinstance(x, datetime.datetime): - return 'datetime(%r)' % x.isoformat() - elif isinstance(x, bool): - return int(x) - elif isinstance(x, str) or isinstance(x, bytes): - # the only character that needs escaped in sqlite string literals - # is single-quote, which is escaped as two single-quotes - if isinstance(x, bytes): - s = x.decode('utf-8') - else: - s = x - return "'" + s.replace("'", "''") + "'" - elif isinstance(x, (int, float)): - return x - else: - raise Exception( - "don't know how to make an sql value from %r (%r)" % ( - x, type(x))) - - def segment_manager_url(self): - master_node = self.svcreg.unique_service('trough-sync-master') - assert master_node - return master_node['url'] - - def write_url_nocache(self, segment_id, schema_id='default'): - provision_url = os.path.join(self.segment_manager_url(), 'provision') - payload_dict = {'segment': segment_id, 'schema': schema_id} - response = requests.post(provision_url, json=payload_dict, timeout=600) - if response.status_code != 200: - raise Exception( - 'Received %s: %r in response to POST %s with data %s' % ( - response.status_code, response.text, provision_url, - json.dumps(payload_dict))) - result_dict = response.json() - # assert result_dict['schema'] == schema_id # previously provisioned? - return result_dict['write_url'] - - def read_url_nocache(self, segment_id): - reql = self.rr.table('services').get_all( - segment_id, index='segment').filter( - {'role':'trough-read'}).filter( - lambda svc: r.now().sub( - svc['last_heartbeat']).lt(svc['ttl']) - ).order_by('load') - self.logger.debug('querying rethinkdb: %r', reql) - results = reql.run() - if results: - return results[0]['url'] - else: - return None - - def write_url(self, segment_id, schema_id='default'): - if not segment_id in self._write_url_cache: - self._write_url_cache[segment_id] = self.write_url_nocache( - segment_id, schema_id) - self.logger.info( - 'segment %r write url is %r', segment_id, - self._write_url_cache[segment_id]) - return self._write_url_cache[segment_id] - - def read_url(self, segment_id): - if not self._read_url_cache.get(segment_id): - self._read_url_cache[segment_id] = self.read_url_nocache(segment_id) - self.logger.info( - 'segment %r read url is %r', segment_id, - self._read_url_cache[segment_id]) - return self._read_url_cache[segment_id] - - def write(self, segment_id, sql_tmpl, values=(), schema_id='default'): - write_url = self.write_url(segment_id, schema_id) - sql = sql_tmpl % tuple(self.sql_value(v) for v in values) - sql_bytes = sql.encode('utf-8') - - try: - response = requests.post( - write_url, sql_bytes, timeout=600, - headers={'content-type': 'application/sql;charset=utf-8'}) - if response.status_code != 200: - raise Exception( - 'Received %s: %r in response to POST %s with data %r' % ( - response.status_code, response.text, write_url, sql)) - if segment_id not in self._dirty_segments: - with self._dirty_segments_lock: - self._dirty_segments.add(segment_id) - except: - self._write_url_cache.pop(segment_id, None) - self.logger.error( - 'problem with trough write url %r', write_url, - exc_info=True) - return - if response.status_code != 200: - self._write_url_cache.pop(segment_id, None) - self.logger.warning( - 'unexpected response %r %r %r from %r to sql=%r', - response.status_code, response.reason, response.text, - write_url, sql) - return - self.logger.debug('posted to %s: %r', write_url, sql) - - def read(self, segment_id, sql_tmpl, values=()): - read_url = self.read_url(segment_id) - if not read_url: - return None - sql = sql_tmpl % tuple(self.sql_value(v) for v in values) - sql_bytes = sql.encode('utf-8') - try: - response = requests.post( - read_url, sql_bytes, timeout=600, - headers={'content-type': 'application/sql;charset=utf-8'}) - except: - self._read_url_cache.pop(segment_id, None) - self.logger.error( - 'problem with trough read url %r', read_url, exc_info=True) - return None - if response.status_code != 200: - self._read_url_cache.pop(segment_id, None) - self.logger.warn( - 'unexpected response %r %r %r from %r to sql=%r', - response.status_code, response.reason, response.text, - read_url, sql) - return None - self.logger.trace( - 'got %r from posting query %r to %r', response.text, sql, - read_url) - results = json.loads(response.text) - return results - - def schema_exists(self, schema_id): - url = os.path.join(self.segment_manager_url(), 'schema', schema_id) - response = requests.get(url, timeout=60) - if response.status_code == 200: - return True - elif response.status_code == 404: - return False - else: - response.raise_for_status() - - def register_schema(self, schema_id, sql): - url = os.path.join( - self.segment_manager_url(), 'schema', schema_id, 'sql') - response = requests.put(url, sql, timeout=600) - if response.status_code not in (201, 204): - raise Exception( - 'Received %s: %r in response to PUT %r with data %r' % ( - response.status_code, response.text, sql, url)) -