mirror of
https://github.com/webrecorder/pywb.git
synced 2025-03-23 06:32:24 +01:00
260 lines
7.8 KiB
Python
260 lines
7.8 KiB
Python
from gevent.pool import Pool
|
|
import gevent
|
|
import json
|
|
import time
|
|
import os
|
|
|
|
from pywb.cdx.cdxops import process_cdx
|
|
from pywb.cdx.query import CDXQuery
|
|
|
|
from heapq import merge
|
|
from collections import deque
|
|
|
|
from rezag.indexsource import FileIndexSource
|
|
from pywb.utils.wbexception import NotFoundException
|
|
import six
|
|
import glob
|
|
|
|
|
|
#=============================================================================
|
|
class BaseAggregator(object):
    """Base class for aggregators that combine cdx iterators from sources.

    Subclasses provide ``_load_all()``, which returns one cdx iterator per
    source, and ``get_sources()``.
    """

    def __call__(self, params):
        """Run the cdx query described by ``params``.

        ``params`` is updated in place with an ``_all_src_params`` entry
        before the index is loaded.  A ``NotFoundException`` raised while
        loading is treated as an empty result.

        Returns the iterator produced by ``process_cdx``.
        """
        query = CDXQuery(params)
        self._set_src_params(params)

        try:
            cdx_iter = self.load_index(query.params)
        except NotFoundException:
            cdx_iter = iter([])

        return process_cdx(cdx_iter, query)

    def _set_src_params(self, params):
        """Collect ``param.*`` entries into ``params['_all_src_params']``.

        A key of the form ``param.<src>.<name>`` is stored as
        ``src_params[src][name]``.  Keys with any other number of dotted
        parts after the ``param.`` prefix are stored under the default
        source ``''`` using the first part as the name.
        """
        src_params = {}
        for param, value in six.iteritems(params):
            if not param.startswith('param.'):
                continue

            parts = param.split('.', 3)[1:]

            if len(parts) == 2:
                src, name = parts
            else:
                src = ''
                name = parts[0]

            src_params.setdefault(src, {})[name] = value

        params['_all_src_params'] = src_params

    def load_child_source(self, name, source, all_params):
        """Load the index of one child ``source`` and tag results with ``name``.

        The child receives a shallow copy of ``all_params`` with
        ``_src_params`` set to the entry for ``name`` in
        ``all_params['_all_src_params']`` (or ``None``).  Copying keeps one
        child's settings from being overwritten by the next child that is
        loaded with the same shared dict.

        A ``NotFoundException`` from the child yields an empty result.

        Returns a generator over the child's cdx entries, each with its
        ``source`` field set to ``name`` or prefixed with ``name + '.'``.
        """
        try:
            params = dict(all_params)
            params['_src_params'] = all_params['_all_src_params'].get(name)
            cdx_iter = source.load_index(params)
        except NotFoundException:
            print('Not found in ' + name)
            cdx_iter = iter([])

        def add_name(cdx_iter):
            for cdx in cdx_iter:
                if 'source' in cdx:
                    cdx['source'] = name + '.' + cdx['source']
                else:
                    cdx['source'] = name
                yield cdx

        return add_name(cdx_iter)

    def load_index(self, params):
        """Return a single cdx iterator combining all iterators of ``_load_all``."""
        iter_list = list(self._load_all(params))

        # optimization: if only a single entry (or empty) just load directly
        if not iter_list:
            return iter([])

        if len(iter_list) == 1:
            return iter_list[0]

        return merge(*iter_list)

    def _load_all(self, params):  #pragma: no cover
        """Return an iterable of cdx iterators; must be overridden."""
        raise NotImplementedError()

    def get_sources(self, params):  #pragma: no cover
        """Return the sources to query; must be overridden."""
        raise NotImplementedError()
|
|
|
|
|
|
#=============================================================================
|
|
class BaseSourceListAggregator(BaseAggregator):
    """Aggregator backed by a mapping of source name to index source."""

    def __init__(self, sources, **kwargs):
        # extra keyword arguments are accepted but not used by this class
        self.sources = sources

    def get_all_sources(self, params):
        """Return the complete mapping of sources given at construction."""
        return self.sources

    def get_sources(self, params):
        """Return ``(name, source)`` pairs to query.

        When ``params['sources']`` holds a non-empty comma-separated list of
        names, only sources with those names are returned (as a list, in the
        order of the mapping's keys).  Otherwise all items of the mapping
        are returned.
        """
        all_sources = self.get_all_sources(params)
        requested = params.get('sources')

        if requested:
            wanted = tuple(requested.split(','))
            return [(key, all_sources[key])
                    for key in all_sources.keys()
                    if key in wanted]

        return all_sources.items()
|
|
|
|
|
|
#=============================================================================
|
|
class SeqAggMixin(object):
    """Mixin that loads every source one after another in the caller."""

    def __init__(self, *args, **kwargs):
        super(SeqAggMixin, self).__init__(*args, **kwargs)

    def _load_all(self, params):
        """Return a list with the result of ``load_child_source`` per source.

        The sources from ``get_sources`` are fully materialized before the
        first one is loaded.
        """
        loaded = []
        for src_name, src in list(self.get_sources(params)):
            loaded.append(self.load_child_source(src_name, src, params))

        return loaded
|
|
|
|
|
|
#=============================================================================
|
|
class SimpleAggregator(SeqAggMixin, BaseSourceListAggregator):
    """Source-list aggregator that loads its sources sequentially."""
|
|
|
|
|
|
#=============================================================================
|
|
class TimeoutMixin(object):
    """Mixin that skips sources which recently reported too many errors.

    Keyword arguments read by ``__init__``:
      t_count    -- number of recorded errors at which a source is skipped
                    (default 3)
      t_duration -- age in seconds after which a recorded error is dropped
                    (default 20)
    """

    def __init__(self, *args, **kwargs):
        super(TimeoutMixin, self).__init__(*args, **kwargs)
        self.t_count = kwargs.get('t_count', 3)
        self.t_dura = kwargs.get('t_duration', 20)
        # source name -> deque of time.time() stamps of recorded errors
        self.timeouts = {}

    def is_timed_out(self, name):
        """Return True if source ``name`` should currently be skipped.

        One entry is removed from the left of the source's deque for every
        recorded stamp older than ``t_dura`` seconds; the source is skipped
        when at least ``t_count`` entries remain.
        """
        recent = self.timeouts.get(name)
        if not recent:
            return False

        now = time.time()
        expired = sum(1 for stamp in recent if (now - stamp) > self.t_dura)
        for _ in range(expired):
            recent.popleft()

        if len(recent) < self.t_count:
            return False

        print('Skipping {0}, {1} timeouts in {2} seconds'.
              format(name, self.t_count, self.t_dura))
        return True

    def get_sources(self, params):
        """Yield the parent's ``(name, source)`` pairs that are not skipped."""
        for name, source in super(TimeoutMixin, self).get_sources(params):
            if self.is_timed_out(name):
                continue

            yield name, source

    def track_source_error(self, name):
        """Record the current time as an error for source ``name``."""
        now = time.time()
        self.timeouts.setdefault(name, deque()).append(now)
        print(name + ' timed out!')
|
|
|
|
|
|
#=============================================================================
|
|
class GeventAggMixin(object):
    """Mixin that loads all sources concurrently in a gevent pool.

    Keyword arguments read by ``__init__``:
      size    -- passed to ``Pool(size=...)`` (default None)
      timeout -- seconds to wait for all sources to finish (default 5.0)
    """

    def __init__(self, *args, **kwargs):
        super(GeventAggMixin, self).__init__(*args, **kwargs)
        self.pool = Pool(size=kwargs.get('size'))
        self.timeout = kwargs.get('timeout', 5.0)

    def track_source_error(self, name):
        """Hook called with the name of each source that gave no result.

        Does nothing here; another class in the MRO may override it.
        """
        pass

    def _load_all(self, params):
        """Spawn ``load_child_source`` for every source and collect results.

        ``params['_timeout']`` is set to ``self.timeout`` before the sources
        are resolved.  After waiting at most ``self.timeout`` seconds, the
        value of each job that produced one is returned; for every other
        job ``track_source_error`` is called with the source's name.

        Returns a list of the job values, in source order.
        """
        params['_timeout'] = self.timeout

        sources = list(self.get_sources(params))

        jobs = [self.pool.spawn(self.load_child_source, name, source, params)
                for name, source in sources]

        gevent.joinall(jobs, timeout=self.timeout)

        res = []
        # sources holds (name, source) pairs: unpack so that only the name,
        # not the whole pair, is reported to track_source_error
        for (name, _source), job in zip(sources, jobs):
            # a job's value stays None until it has finished successfully
            if job.value is not None:
                res.append(job.value)
            else:
                self.track_source_error(name)

        return res
|
|
|
|
|
|
#=============================================================================
|
|
class GeventTimeoutAggregator(TimeoutMixin, GeventAggMixin, BaseSourceListAggregator):
    """Source-list aggregator combining gevent loading with timeout tracking."""
|
|
|
|
|
|
#=============================================================================
|
|
class BaseDirectoryIndexAggregator(BaseAggregator):
    """Aggregator whose sources are the cdx files found in matching directories."""

    # filename suffixes that are picked up as index files
    CDX_EXT = ('.cdx', '.cdxj')

    def __init__(self, base_prefix, base_dir):
        self.base_prefix = base_prefix
        self.base_dir = base_dir

    def get_sources(self, params):
        """Return a list of ``(relative dir, FileIndexSource)`` pairs.

        ``base_dir`` is formatted with the source params when any are
        available: ``params['_src_params']`` if set (when part of another
        aggregator), otherwise the default entry ``''`` of
        ``params['_all_src_params']``.  The result is joined onto
        ``base_prefix`` and used as a glob pattern.

        Raises ``NotFoundException`` if listing the files fails for any
        reason.
        """
        fmt_args = (params.get('_src_params') or
                    params.get('_all_src_params', {}).get(''))

        sub_dir = self.base_dir.format(**fmt_args) if fmt_args else self.base_dir
        full_dir = os.path.join(self.base_prefix, sub_dir)

        try:
            return list(self._load_files(full_dir))
        except Exception:
            raise NotFoundException(full_dir)

    def _load_files(self, glob_dir):
        """Yield a named ``FileIndexSource`` per index file under ``glob_dir``.

        Every path matching the glob pattern is listed; each entry whose
        name ends with one of ``CDX_EXT`` is yielded, named by the matched
        directory's path relative to ``base_prefix``.
        """
        for matched_dir in glob.iglob(glob_dir):
            print(matched_dir)
            for entry in os.listdir(matched_dir):
                full_path = os.path.join(matched_dir, entry)
                if not full_path.endswith(self.CDX_EXT):
                    continue

                print('Adding ' + full_path)
                source_name = os.path.relpath(matched_dir, self.base_prefix)
                yield source_name, FileIndexSource(full_path)
|
|
|
|
class DirectoryIndexAggregator(SeqAggMixin, BaseDirectoryIndexAggregator):
    """Directory aggregator that loads its index files sequentially."""
|
|
|
|
|