1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-25 23:47:47 +01:00
pywb/pywb/webagg/aggregator.py
Ilya Kreymer 06b1134be5 aggregator: support 'invert_sources' option to exclude source list, rather than include
can be set explicitly or via '!' on the sources list
tests: test invert sources
filters: include params to skip_response() filter
warc headers: change headers for recording from other source to: WARC-Source-URI and WARC-Created-Date
2017-06-01 07:45:02 -07:00

367 lines
12 KiB
Python

from gevent.pool import Pool
import gevent
import json
import time
import os
from warcio.timeutils import timestamp_now
from pywb.cdx.cdxops import process_cdx
from pywb.cdx.query import CDXQuery
from heapq import merge
from collections import deque
from itertools import chain
from pywb.webagg.indexsource import FileIndexSource, RedisIndexSource
from pywb.utils.wbexception import NotFoundException, WbException
from pywb.webagg.utils import ParamFormatter, res_template
import six
import glob
#=============================================================================
class BaseAggregator(object):
    """Base class for aggregators that combine CDX results from several sources.

    Subclasses decide which sources to query (``_iter_sources``) and how to
    query them (``_load_all``); this class handles request preprocessing,
    per-source error capture, source-name tagging and merging of results.
    """

    def __call__(self, params):
        """Run a lookup for *params* and return ``(cdx_iter, errors)``.

        ``closest='now'`` is replaced with the current timestamp and a
        ``content_type`` param is turned into a ``=mime:`` filter before the
        ``CDXQuery`` is built. The merged iterator is post-processed with
        ``process_cdx``; *errors* is a dict built from the error pairs
        collected by ``load_index``.
        """
        if params.get('closest') == 'now':
            params['closest'] = timestamp_now()

        content_type = params.get('content_type')
        if content_type:
            # NOTE: this replaces any 'filter' value already present in params
            params['filter'] = '=mime:' + content_type

        query = CDXQuery(params)

        cdx_iter, errs = self.load_index(query.params)

        cdx_iter = process_cdx(cdx_iter, query)
        return cdx_iter, dict(errs)

    def load_child_source(self, name, source, params):
        """Query one *source* and tag its results with *name*.

        Returns ``(cdx_iter, err_list)``. A source may return either a plain
        iterator or an ``(iterator, errors)`` tuple. A ``WbException`` raised
        by the source does not propagate: it becomes an empty iterator plus a
        ``(name, repr(exc))`` error entry.

        Unless ``params['nosource'] == 'true'``, each cdx gets its ``source``
        field set to *name* (optionally suffixed with ``:<src_coll>``),
        prepended to any existing ``source`` value.
        """
        try:
            # NOTE: these are written into the caller's params dict itself
            params['_name'] = name
            params['_formatter'] = ParamFormatter(params, name)
            res = source.load_index(params)
            if isinstance(res, tuple):
                cdx_iter, err_list = res
            else:
                cdx_iter = res
                err_list = []
        except WbException as wbe:
            cdx_iter = iter([])
            err_list = [(name, repr(wbe))]

        def add_name(cdx, name):
            # nested aggregators already set 'source'; keep it as a suffix
            if cdx.get('source'):
                cdx['source'] = name + ':' + cdx['source']
            else:
                cdx['source'] = name
            return cdx

        if params.get('nosource') != 'true':
            # a collection selected as "name:coll" (see yield_sources) is
            # reflected in the source tag
            src_coll = params.get('param.' + name + '.src_coll')
            if src_coll:
                name += ':' + src_coll

            cdx_iter = (add_name(cdx, name) for cdx in cdx_iter)

        return cdx_iter, err_list

    def load_index(self, params):
        """Load all sources via ``_load_all`` and merge their results.

        Returns ``(cdx_iter, errors)``. *errors* is a lazy ``itertools.chain``
        over every source's error list. Multiple iterators are combined with
        ``heapq.merge``, which yields sorted output only if each child
        iterator is itself sorted.
        """
        res_list = self._load_all(params)

        iter_list = [res[0] for res in res_list]
        err_list = chain(*[res[1] for res in res_list])

        # optimization: a single (or no) iterator needs no merging
        if len(iter_list) <= 1:
            cdx_iter = iter_list[0] if iter_list else iter([])
        else:
            cdx_iter = merge(*iter_list)

        return cdx_iter, err_list

    def _on_source_error(self, name):  #pragma: no cover
        """Hook called when source *name* fails; no-op by default."""
        pass

    def _load_all(self, params):  #pragma: no cover
        """Return a list of ``(cdx_iter, err_list)`` pairs, one per source."""
        raise NotImplementedError()

    def _iter_sources(self, params):  #pragma: no cover
        """Return an iterable of ``(name, source)`` pairs to query."""
        raise NotImplementedError()

    def get_source_list(self, params):
        """Return ``{'sources': {name: str(source), ...}}`` for the selected sources."""
        sources = self._iter_sources(params)
        return {'sources': {name: str(value) for name, value in sources}}
#=============================================================================
class BaseSourceListAggregator(BaseAggregator):
    """Aggregator over a fixed ``name -> source`` mapping.

    Which sources are queried can be narrowed per request via the
    ``sources_key`` param: a comma-separated list of names, or ``*`` for all.
    The selection can be inverted (query everything *except* the listed
    names) with the ``invert_sources`` option or a leading ``!``.
    """

    def __init__(self, sources, **kwargs):
        self.sources = sources
        self.sources_key = kwargs.get('sources_key', 'sources')
        self.invert_sources = kwargs.get('invert_sources', False)

    def get_all_sources(self, params):
        """Return the full ``name -> source`` mapping."""
        return self.sources

    def _iter_sources(self, params):
        """Return an iterator of ``(name, source)`` pairs selected by *params*."""
        inverted = self.invert_sources
        selection = params.get(self.sources_key)

        # a leading '!' flips the selection for this request only
        if selection and selection[0] == '!':
            inverted = True
            selection = selection[1:]

        select_everything = not selection or selection == '*'
        if select_everything:
            # "everything" inverted means "nothing"
            if inverted:
                return iter([])
            return six.iteritems(self.get_all_sources(params))

        if inverted:
            return self.yield_invert_sources(selection, params)
        return self.yield_sources(selection, params)

    def yield_sources(self, sel_sources, params):
        """Yield ``(name, source)`` for each listed name that exists.

        An entry ``name:coll`` that is not itself a source name selects
        source ``name`` and stores ``coll`` in
        ``params['param.<name>.src_coll']``.
        """
        sources = self.get_all_sources(params)

        for entry in sel_sources.split(','):
            if entry in sources:
                yield (entry, sources[entry])
                continue

            if ':' not in entry:
                continue

            base_name, coll = entry.split(':', 1)
            if base_name in sources:
                params['param.' + base_name + '.src_coll'] = coll
                yield (base_name, sources[base_name])

    def yield_invert_sources(self, sel_sources, params):
        """Yield ``(name, source)`` for every source NOT in the listed names.

        Any ``:coll`` suffix on a listed entry is ignored.
        """
        sources = self.get_all_sources(params)
        excluded = tuple(entry.split(':', 1)[0]
                         for entry in sel_sources.split(','))

        for name in six.iterkeys(sources):
            if name in excluded:
                continue
            yield (name, sources[name])
#=============================================================================
class SeqAggMixin(object):
    """Mixin that queries each selected source one after another, in order."""

    def __init__(self, *args, **kwargs):
        # cooperative init: pass all arguments along the MRO unchanged
        super(SeqAggMixin, self).__init__(*args, **kwargs)

    def _load_all(self, params):
        """Return ``load_child_source`` results for every selected source."""
        results = []
        for name, source in self._iter_sources(params):
            results.append(self.load_child_source(name, source, params))
        return results
#=============================================================================
class SimpleAggregator(SeqAggMixin, BaseSourceListAggregator):
    """Source-list aggregator that queries its selected sources sequentially."""
#=============================================================================
class TimeoutMixin(object):
    """Mixin that temporarily skips sources which keep failing.

    Every failure reported via ``_on_source_error`` is timestamped. A source
    with at least ``t_count`` failures within the last ``t_duration`` seconds
    is left out of ``_iter_sources`` until older failures age out.
    """

    def __init__(self, *args, **kwargs):
        super(TimeoutMixin, self).__init__(*args, **kwargs)
        self.t_count = kwargs.get('t_count', 3)
        self.t_dura = kwargs.get('t_duration', 20)
        # source name -> deque of failure timestamps (time.time() values)
        self.timeouts = {}

    def is_timed_out(self, name):
        """Return True if *name* has reached the failure threshold recently.

        As a side effect, failure timestamps older than ``t_dura`` seconds
        are discarded.
        """
        timeout_deq = self.timeouts.get(name)
        if not timeout_deq:
            return False

        the_time = time.time()

        # Keep only failures inside the window. Filtering by value (rather
        # than popping the leftmost entry whenever some expired entry is
        # seen) stays correct even if the wall clock stepped backwards
        # between appends.
        recent = [t for t in timeout_deq if (the_time - t) <= self.t_dura]
        if len(recent) != len(timeout_deq):
            timeout_deq.clear()
            timeout_deq.extend(recent)

        if len(timeout_deq) >= self.t_count:
            print('Skipping {0}, {1} timeouts in {2} seconds'.
                  format(name, self.t_count, self.t_dura))
            return True

        return False

    def _iter_sources(self, params):
        """Yield the parent's selected sources, minus any currently timed out."""
        sources = super(TimeoutMixin, self)._iter_sources(params)
        for name, source in sources:
            if not self.is_timed_out(name):
                yield name, source

    def _on_source_error(self, name):
        """Record a failure of source *name* at the current time."""
        self.timeouts.setdefault(name, deque()).append(time.time())
        print(name + ' timed out!')
#=============================================================================
class GeventMixin(object):
    """Mixin that queries all selected sources concurrently on a gevent pool.

    Sources that have produced no value after ``timeout`` seconds are
    reported with a ``'timeout'`` error and passed to ``_on_source_error``.
    """

    def __init__(self, *args, **kwargs):
        super(GeventMixin, self).__init__(*args, **kwargs)
        self.pool = Pool(size=kwargs.get('size'))
        self.timeout = kwargs.get('timeout', 5.0)

    def _load_all(self, params):
        """Spawn ``load_child_source`` per source and collect results in order."""
        params['_timeout'] = self.timeout

        # materialize the selection so it can be paired with the jobs below
        sources = list(self._iter_sources(params))

        jobs = []
        for name, source in sources:
            jobs.append(self.pool.spawn(self.load_child_source,
                                        name, source, params))

        gevent.joinall(jobs, timeout=self.timeout)

        results = []
        for (name, _source), job in zip(sources, jobs):
            value = job.value
            if value is None:
                # no value yet (a greenlet that raised also has no value):
                # substitute an empty result and record the failure
                results.append((iter([]), [(name, 'timeout')]))
                self._on_source_error(name)
            else:
                results.append(value)

        return results
#=============================================================================
class GeventTimeoutAggregator(TimeoutMixin, GeventMixin, BaseSourceListAggregator):
    """Source-list aggregator that queries sources concurrently and skips
    sources that have recently failed too often."""
#=============================================================================
class BaseDirectoryIndexSource(BaseAggregator):
    """Aggregate the index files found in a (possibly templated) directory.

    ``base_dir`` is expanded with ``res_template`` against the request params,
    joined onto ``base_prefix`` and glob-expanded. Every file ending in
    ``FileIndexSource.CDX_EXT`` in each matching directory becomes one source.
    """

    def __init__(self, base_prefix, base_dir=''):
        self.base_prefix = base_prefix
        self.base_dir = base_dir

    def _iter_sources(self, params):
        """Return a list of ``(name, FileIndexSource)`` pairs.

        Raises ``NotFoundException`` if listing the directories fails.
        """
        the_dir = res_template(self.base_dir, params)
        the_dir = os.path.join(self.base_prefix, the_dir)
        try:
            sources = list(self._load_files(the_dir))
        except Exception:
            raise NotFoundException(the_dir)

        return sources

    def _load_files(self, glob_dir):
        """Yield index sources from every directory matching *glob_dir*."""
        for the_dir in glob.iglob(glob_dir):
            for result in self._load_files_single_dir(the_dir):
                yield result

    def _load_files_single_dir(self, the_dir):
        """Yield ``(name, FileIndexSource)`` for index files in *the_dir*.

        The name is the file name, prefixed with the directory's path
        relative to ``base_prefix`` when that is not ``.``.
        """
        for name in os.listdir(the_dir):
            filename = os.path.join(the_dir, name)
            if filename.endswith(FileIndexSource.CDX_EXT):
                rel_path = os.path.relpath(the_dir, self.base_prefix)
                if rel_path == '.':
                    full_name = name
                else:
                    full_name = rel_path + '/' + name
                yield full_name, FileIndexSource(filename)

    def __repr__(self):
        return '{0}(file://{1})'.format(self.__class__.__name__,
                                        os.path.join(self.base_prefix, self.base_dir))

    def __str__(self):
        return 'file_dir'

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return (self.base_prefix == other.base_prefix and
                self.base_dir == other.base_dir)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent
        return not self.__eq__(other)

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3; hash
        # on exactly the fields __eq__ compares so equal objects hash equal.
        return hash((self.base_prefix, self.base_dir))

    @classmethod
    def init_from_string(cls, value):
        """Return an instance for an existing local directory path, else None."""
        if '://' not in value and os.path.isdir(value):
            return cls(value)

    @classmethod
    def init_from_config(cls, config):
        """Build from a config dict with ``type == 'file'``; None otherwise."""
        if config['type'] != 'file':
            return

        return cls.init_from_string(config['path'])
#=============================================================================
class DirectoryIndexSource(SeqAggMixin, BaseDirectoryIndexSource):
    """Directory index source whose files are queried sequentially."""
#=============================================================================
class CacheDirectoryIndexSource(DirectoryIndexSource):
    """DirectoryIndexSource that caches each directory's file listing.

    A cached listing is reused until the directory's inode or modification
    time changes. If the directory cannot be stat'ed, it is always re-listed.
    """

    def __init__(self, *args, **kwargs):
        super(CacheDirectoryIndexSource, self).__init__(*args, **kwargs)
        # dir path -> (change-detection key, list of (name, FileIndexSource))
        self.cached_file_list = {}

    def _load_files_single_dir(self, the_dir):
        """Return the (possibly cached) list of index sources in *the_dir*."""
        try:
            st = os.stat(the_dir)
            # Compare inode + full-precision mtime rather than whole
            # os.stat_result objects: their equality uses the tuple form,
            # which truncates timestamps to whole seconds (missing changes
            # made in the same second) and includes atime (which can change
            # merely from reading the directory).
            dir_key = (st.st_ino, getattr(st, 'st_mtime_ns', st.st_mtime))
        except Exception:
            dir_key = 0

        cached = self.cached_file_list.get(the_dir)
        if cached:
            last_key, files = cached
            if dir_key and last_key == dir_key:
                print('Dir {0} unchanged'.format(the_dir))
                return files

        files = super(CacheDirectoryIndexSource, self)._load_files_single_dir(the_dir)
        files = list(files)
        self.cached_file_list[the_dir] = (dir_key, files)
        return files
#=============================================================================
class BaseRedisMultiKeyIndexSource(BaseAggregator, RedisIndexSource):
    """Aggregate several redis-backed indexes selected by a key template.

    ``redis_key_template`` is expanded against the request params. If the
    result contains ``*``, matching keys come from ``scan_keys``; otherwise
    the single expanded key is used directly.
    """

    def _iter_sources(self, params):
        """Yield ``(key, source)`` for each selected redis key."""
        redis_key_pattern = res_template(self.redis_key_template, params)

        if '*' in redis_key_pattern:
            keys = self.scan_keys(redis_key_pattern, params)
        else:
            # keys are handled as bytes here and decoded below
            keys = [redis_key_pattern.encode('utf-8')]

        for raw_key in keys:
            key = raw_key.decode('utf-8')
            source = self._get_source_for_key(key)
            if source:
                yield key, source

    def _get_source_for_key(self, key):
        """Return a RedisIndexSource reading *key* on this connection."""
        return RedisIndexSource(None, self.redis, key)

    def __str__(self):
        return 'redis-multikey'
#=============================================================================
class RedisMultiKeyIndexSource(SeqAggMixin, BaseRedisMultiKeyIndexSource):
    """Multi-key redis index source whose keys are queried sequentially."""