mirror of https://github.com/internetarchive/warcprox.git
use trough.client instead of warcprox.trough
less redundant code! trough.client was based on warcprox.trough but has been improved since then
parent f77c152037
commit fe19bb268f
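
For context, the core of the migration is just an import and constructor swap, plus the new setup.py dependency; the rest of the commit wraps trough calls in error handling and deletes the now-redundant module. A minimal before/after sketch (the rethinkdb URL is a made-up example; the constructor signature is the one visible in the diff below):

    # before this commit: the client class bundled with warcprox
    import warcprox.trough
    cli = warcprox.trough.TroughClient(
            'rethinkdb://db0,db1,db2/trough_configuration',  # example URL
            promotion_interval=60*60)

    # after this commit: the same client, maintained in the trough package
    import trough.client
    cli = trough.client.TroughClient(
            'rethinkdb://db0,db1,db2/trough_configuration',  # example URL
            promotion_interval=60*60)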
--- a/setup.py
+++ b/setup.py
@@ -35,6 +35,7 @@ deps = [
     'idna>=2.5',
     'PyYAML>=5.1',
     'cachetools',
+    'trough>=0.1.2',
 ]
 try:
     import concurrent.futures
--- a/tests/test_warcprox.py
+++ b/tests/test_warcprox.py
@@ -2132,24 +2132,6 @@ def test_payload_digest(warcprox_, http_daemon):
     req, prox_rec_res = mitm.do_GET()
     assert warcprox.digest_str(prox_rec_res.payload_digest) == GZIP_GZIP_SHA1
 
-def test_trough_segment_promotion(warcprox_):
-    if not warcprox_.options.rethinkdb_trough_db_url:
-        return
-    cli = warcprox.trough.TroughClient(
-            warcprox_.options.rethinkdb_trough_db_url, 3)
-    promoted = []
-    def mock(segment_id):
-        promoted.append(segment_id)
-    cli.promote = mock
-    cli.register_schema('default', 'create table foo (bar varchar(100))')
-    cli.write('my_seg', 'insert into foo (bar) values ("boof")')
-    assert promoted == []
-    time.sleep(3)
-    assert promoted == ['my_seg']
-    promoted = []
-    time.sleep(3)
-    assert promoted == []
-
 def test_dedup_min_text_size(http_daemon, warcprox_, archiving_proxies):
     """We use options --dedup-min-text-size=3 --dedup-min-binary-size=5 and we
     try to download content smaller than these limits to make sure that it is
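
(The test removed above drove TroughClient's segment-promotion machinery directly; with the client now maintained in the trough package, that coverage presumably belongs in trough's own test suite.)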
--- a/warcprox/dedup.py
+++ b/warcprox/dedup.py
@@ -26,7 +26,7 @@ import os
 import json
 from hanzo import warctools
 import warcprox
-import warcprox.trough
+import trough.client
 import sqlite3
 import doublethink
 import datetime
@@ -509,7 +509,7 @@ class TroughDedupDb(DedupDb, DedupableMixin):
     def __init__(self, options=warcprox.Options()):
         DedupableMixin.__init__(self, options)
         self.options = options
-        self._trough_cli = warcprox.trough.TroughClient(
+        self._trough_cli = trough.client.TroughClient(
                 options.rethinkdb_trough_db_url, promotion_interval=60*60)
 
     def loader(self, *args, **kwargs):
@@ -531,9 +531,13 @@ class TroughDedupDb(DedupDb, DedupableMixin):
         record_id = response_record.get_header(warctools.WarcRecord.ID)
         url = response_record.get_header(warctools.WarcRecord.URL)
         warc_date = response_record.get_header(warctools.WarcRecord.DATE)
-        self._trough_cli.write(
-                bucket, self.WRITE_SQL_TMPL,
-                (digest_key, url, warc_date, record_id), self.SCHEMA_ID)
+        try:
+            self._trough_cli.write(
+                    bucket, self.WRITE_SQL_TMPL,
+                    (digest_key, url, warc_date, record_id), self.SCHEMA_ID)
+        except:
+            self.logger.warning(
+                    'problem posting dedup data to trough', exc_info=True)
 
     def batch_save(self, batch, bucket='__unspecified__'):
         sql_tmpl = ('insert or ignore into dedup\n'
@@ -548,12 +552,22 @@ class TroughDedupDb(DedupDb, DedupableMixin):
                 recorded_url.url,
                 recorded_url.warc_records[0].date,
                 recorded_url.warc_records[0].id,])
-        self._trough_cli.write(bucket, sql_tmpl, values, self.SCHEMA_ID)
+        try:
+            self._trough_cli.write(bucket, sql_tmpl, values, self.SCHEMA_ID)
+        except:
+            self.logger.warning(
+                    'problem posting dedup data to trough', exc_info=True)
 
     def lookup(self, digest_key, bucket='__unspecified__', url=None):
-        results = self._trough_cli.read(
-                bucket, 'select * from dedup where digest_key=%s;',
-                (digest_key,))
+        try:
+            results = self._trough_cli.read(
+                    bucket, 'select * from dedup where digest_key=%s;',
+                    (digest_key,))
+        except:
+            self.logger.warning(
+                    'problem reading dedup data from trough', exc_info=True)
+            return None
+
         if results:
             assert len(results) == 1 # sanity check (digest_key is primary key)
             result = results[0]
@@ -570,7 +584,14 @@ class TroughDedupDb(DedupDb, DedupableMixin):
         '''Returns [{'digest_key': ..., 'url': ..., 'date': ...}, ...]'''
         sql_tmpl = 'select * from dedup where digest_key in (%s)' % (
                 ','.join('%s' for i in range(len(digest_keys))))
-        results = self._trough_cli.read(bucket, sql_tmpl, digest_keys)
+
+        try:
+            results = self._trough_cli.read(bucket, sql_tmpl, digest_keys)
+        except:
+            self.logger.warning(
+                    'problem reading dedup data from trough', exc_info=True)
+            results = None
+
         if results is None:
             return []
         self.logger.debug(
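
(Net effect of the dedup.py changes above: trough reads and writes become best-effort. A trough hiccup is logged as a warning and deduplication is skipped for that record, with lookup() returning None and batch_lookup() returning [], rather than an exception propagating out of the dedup code.)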
--- a/warcprox/trough.py
+++ /dev/null
@@ -1,246 +0,0 @@
-'''
-warcprox/trough.py - trough client code
-
-Copyright (C) 2017 Internet Archive
-
-This program is free software; you can redistribute it and/or
-modify it under the terms of the GNU General Public License
-as published by the Free Software Foundation; either version 2
-of the License, or (at your option) any later version.
-
-This program is distributed in the hope that it will be useful,
-but WITHOUT ANY WARRANTY; without even the implied warranty of
-MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-GNU General Public License for more details.
-
-You should have received a copy of the GNU General Public License
-along with this program; if not, write to the Free Software
-Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
-USA.
-'''
-
-from __future__ import absolute_import
-
-import logging
-import os
-import json
-import requests
-import doublethink
-import rethinkdb as r
-import datetime
-import threading
-import time
-
-class TroughClient(object):
-    logger = logging.getLogger("warcprox.trough.TroughClient")
-
-    def __init__(self, rethinkdb_trough_db_url, promotion_interval=None):
-        '''
-        TroughClient constructor
-
-        Args:
-            rethinkdb_trough_db_url: url with schema rethinkdb:// pointing to
-                trough configuration database
-            promotion_interval: if specified, `TroughClient` will spawn a
-                thread that "promotes" (pushed to hdfs) "dirty" trough segments
-                (segments that have received writes) periodically, sleeping for
-                `promotion_interval` seconds between cycles (default None)
-        '''
-        parsed = doublethink.parse_rethinkdb_url(rethinkdb_trough_db_url)
-        self.rr = doublethink.Rethinker(
-                servers=parsed.hosts, db=parsed.database)
-        self.svcreg = doublethink.ServiceRegistry(self.rr)
-        self._write_url_cache = {}
-        self._read_url_cache = {}
-        self._dirty_segments = set()
-        self._dirty_segments_lock = threading.RLock()
-
-        self.promotion_interval = promotion_interval
-        self._promoter_thread = None
-        if promotion_interval:
-            self._promoter_thread = threading.Thread(
-                    target=self._promotrix, name='TroughClient-promoter')
-            self._promoter_thread.setDaemon(True)
-            self._promoter_thread.start()
-
-    def _promotrix(self):
-        while True:
-            time.sleep(self.promotion_interval)
-            try:
-                with self._dirty_segments_lock:
-                    dirty_segments = list(self._dirty_segments)
-                    self._dirty_segments.clear()
-                logging.info(
-                        'promoting %s trough segments', len(dirty_segments))
-                for segment_id in dirty_segments:
-                    try:
-                        self.promote(segment_id)
-                    except:
-                        logging.error(
-                                'problem promoting segment %s', segment_id,
-                                exc_info=True)
-            except:
-                logging.error(
-                        'caught exception doing segment promotion',
-                        exc_info=True)
-
-    def promote(self, segment_id):
-        url = os.path.join(self.segment_manager_url(), 'promote')
-        payload_dict = {'segment': segment_id}
-        response = requests.post(url, json=payload_dict, timeout=21600)
-        if response.status_code != 200:
-            raise Exception(
-                    'Received %s: %r in response to POST %s with data %s' % (
-                        response.status_code, response.text, url,
-                        json.dumps(payload_dict)))
-
-    @staticmethod
-    def sql_value(x):
-        if x is None:
-            return 'null'
-        elif isinstance(x, datetime.datetime):
-            return 'datetime(%r)' % x.isoformat()
-        elif isinstance(x, bool):
-            return int(x)
-        elif isinstance(x, str) or isinstance(x, bytes):
-            # the only character that needs escaped in sqlite string literals
-            # is single-quote, which is escaped as two single-quotes
-            if isinstance(x, bytes):
-                s = x.decode('utf-8')
-            else:
-                s = x
-            return "'" + s.replace("'", "''") + "'"
-        elif isinstance(x, (int, float)):
-            return x
-        else:
-            raise Exception(
-                    "don't know how to make an sql value from %r (%r)" % (
-                        x, type(x)))
-
-    def segment_manager_url(self):
-        master_node = self.svcreg.unique_service('trough-sync-master')
-        assert master_node
-        return master_node['url']
-
-    def write_url_nocache(self, segment_id, schema_id='default'):
-        provision_url = os.path.join(self.segment_manager_url(), 'provision')
-        payload_dict = {'segment': segment_id, 'schema': schema_id}
-        response = requests.post(provision_url, json=payload_dict, timeout=600)
-        if response.status_code != 200:
-            raise Exception(
-                    'Received %s: %r in response to POST %s with data %s' % (
-                        response.status_code, response.text, provision_url,
-                        json.dumps(payload_dict)))
-        result_dict = response.json()
-        # assert result_dict['schema'] == schema_id # previously provisioned?
-        return result_dict['write_url']
-
-    def read_url_nocache(self, segment_id):
-        reql = self.rr.table('services').get_all(
-                segment_id, index='segment').filter(
-                        {'role':'trough-read'}).filter(
-                                lambda svc: r.now().sub(
-                                    svc['last_heartbeat']).lt(svc['ttl'])
-                                ).order_by('load')
-        self.logger.debug('querying rethinkdb: %r', reql)
-        results = reql.run()
-        if results:
-            return results[0]['url']
-        else:
-            return None
-
-    def write_url(self, segment_id, schema_id='default'):
-        if not segment_id in self._write_url_cache:
-            self._write_url_cache[segment_id] = self.write_url_nocache(
-                    segment_id, schema_id)
-            self.logger.info(
-                    'segment %r write url is %r', segment_id,
-                    self._write_url_cache[segment_id])
-        return self._write_url_cache[segment_id]
-
-    def read_url(self, segment_id):
-        if not self._read_url_cache.get(segment_id):
-            self._read_url_cache[segment_id] = self.read_url_nocache(segment_id)
-            self.logger.info(
-                    'segment %r read url is %r', segment_id,
-                    self._read_url_cache[segment_id])
-        return self._read_url_cache[segment_id]
-
-    def write(self, segment_id, sql_tmpl, values=(), schema_id='default'):
-        write_url = self.write_url(segment_id, schema_id)
-        sql = sql_tmpl % tuple(self.sql_value(v) for v in values)
-        sql_bytes = sql.encode('utf-8')
-
-        try:
-            response = requests.post(
-                    write_url, sql_bytes, timeout=600,
-                    headers={'content-type': 'application/sql;charset=utf-8'})
-            if response.status_code != 200:
-                raise Exception(
-                        'Received %s: %r in response to POST %s with data %r' % (
-                            response.status_code, response.text, write_url, sql))
-            if segment_id not in self._dirty_segments:
-                with self._dirty_segments_lock:
-                    self._dirty_segments.add(segment_id)
-        except:
-            self._write_url_cache.pop(segment_id, None)
-            self.logger.error(
-                    'problem with trough write url %r', write_url,
-                    exc_info=True)
-            return
-        if response.status_code != 200:
-            self._write_url_cache.pop(segment_id, None)
-            self.logger.warning(
-                    'unexpected response %r %r %r from %r to sql=%r',
-                    response.status_code, response.reason, response.text,
-                    write_url, sql)
-            return
-        self.logger.debug('posted to %s: %r', write_url, sql)
-
-    def read(self, segment_id, sql_tmpl, values=()):
-        read_url = self.read_url(segment_id)
-        if not read_url:
-            return None
-        sql = sql_tmpl % tuple(self.sql_value(v) for v in values)
-        sql_bytes = sql.encode('utf-8')
-        try:
-            response = requests.post(
-                    read_url, sql_bytes, timeout=600,
-                    headers={'content-type': 'application/sql;charset=utf-8'})
-        except:
-            self._read_url_cache.pop(segment_id, None)
-            self.logger.error(
-                    'problem with trough read url %r', read_url, exc_info=True)
-            return None
-        if response.status_code != 200:
-            self._read_url_cache.pop(segment_id, None)
-            self.logger.warn(
-                    'unexpected response %r %r %r from %r to sql=%r',
-                    response.status_code, response.reason, response.text,
-                    read_url, sql)
-            return None
-        self.logger.trace(
-                'got %r from posting query %r to %r', response.text, sql,
-                read_url)
-        results = json.loads(response.text)
-        return results
-
-    def schema_exists(self, schema_id):
-        url = os.path.join(self.segment_manager_url(), 'schema', schema_id)
-        response = requests.get(url, timeout=60)
-        if response.status_code == 200:
-            return True
-        elif response.status_code == 404:
-            return False
-        else:
-            response.raise_for_status()
-
-    def register_schema(self, schema_id, sql):
-        url = os.path.join(
-                self.segment_manager_url(), 'schema', schema_id, 'sql')
-        response = requests.put(url, sql, timeout=600)
-        if response.status_code not in (201, 204):
-            raise Exception(
-                    'Received %s: %r in response to PUT %r with data %r' % (
-                        response.status_code, response.text, sql, url))
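
As an aside, the sql_value() helper deleted above (and presumably carried forward in trough.client, which this module was the basis for) renders Python values as sqlite literals. A standalone illustration of its string-escaping rule, using a hypothetical helper name:

    # illustration only: the single quote is the one character that needs
    # escaping in a sqlite string literal, and it is escaped by doubling it
    def sql_string_literal(s):
        return "'" + s.replace("'", "''") + "'"

    assert sql_string_literal('boof') == "'boof'"
    assert sql_string_literal("it's") == "'it''s'"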