from gevent.pool import Pool
import gevent

from concurrent import futures

import json
import time
import os

from pywb.utils.timeutils import timestamp_now
from pywb.cdx.cdxops import process_cdx
from pywb.cdx.query import CDXQuery

from heapq import merge
from collections import deque
from itertools import chain

from webagg.indexsource import FileIndexSource
from pywb.utils.wbexception import NotFoundException, WbException

from webagg.utils import ParamFormatter, res_template

import six
import glob


#=============================================================================
class BaseAggregator(object):
    """Base class for cdx index aggregators.

    Queries one or more named child index sources, tags each cdx result
    with the name of the source it came from, and merges all results
    into a single sorted cdx iterator plus a collection of per-source
    errors.
    """

    def __call__(self, params):
        """Run a full cdx query described by ``params``.

        :param params: dict of cdx query parameters; a 'closest' value of
                       'now' is resolved to the current timestamp.
        :return: tuple of (processed cdx iterator,
                 dict mapping source name -> error repr)
        """
        if params.get('closest') == 'now':
            params['closest'] = timestamp_now()

        query = CDXQuery(params)

        cdx_iter, errs = self.load_index(query.params)

        cdx_iter = process_cdx(cdx_iter, query)
        return cdx_iter, dict(errs)

    def load_child_source(self, name, source, params):
        """Load cdx results from a single child ``source``.

        A ``WbException`` from the source is converted into an empty
        iterator plus an entry in the error list, so one failing source
        does not abort the whole aggregated query.

        :return: tuple of (cdx iterator with 'source' field set to
                 ``name``, list of (name, error repr) tuples)
        """
        try:
            params['_formatter'] = ParamFormatter(params, name)
            res = source.load_index(params)
            if isinstance(res, tuple):
                cdx_iter, err_list = res
            else:
                cdx_iter = res
                err_list = []
        except WbException as wbe:
            #print('Not found in ' + name)
            cdx_iter = iter([])
            err_list = [(name, repr(wbe))]

        def add_name(cdx):
            # prefix this aggregator's name onto any source name already
            # set by a nested aggregator ('outer:inner')
            if cdx.get('source'):
                cdx['source'] = name + ':' + cdx['source']
            else:
                cdx['source'] = name
            return cdx

        return (add_name(cdx) for cdx in cdx_iter), err_list

    def load_index(self, params):
        """Query every child source and merge the results.

        :return: tuple of (merged cdx iterator, iterator over all
                 (name, error repr) tuples from the child sources)
        """
        res_list = self._load_all(params)

        iter_list = [res[0] for res in res_list]
        err_list = chain(*[res[1] for res in res_list])

        #optimization: if only a single entry (or empty) just load directly
        if len(iter_list) <= 1:
            cdx_iter = iter_list[0] if iter_list else iter([])
        else:
            # heapq.merge assumes each child iterator is already sorted
            cdx_iter = merge(*(iter_list))

        return cdx_iter, err_list

    def _on_source_error(self, name):  #pragma: no cover
        """Hook invoked when a child source fails; no-op by default."""
        pass

    def _load_all(self, params):  #pragma: no cover
        """Load from every source; subclasses must return a list of
        (cdx_iter, err_list) tuples.
        """
        # was 'raise NotImplemented()': NotImplemented is a non-callable
        # singleton, not an exception -- NotImplementedError is correct
        raise NotImplementedError()

    def _iter_sources(self, params):  #pragma: no cover
        """Yield (name, source) pairs for the sources to query."""
        raise NotImplementedError()

    def get_source_list(self, params):
        """Return a description of the available sources as
        ``{'sources': {name: str(source), ...}}``.
        """
        srcs = self._iter_sources(params)
        result = [(name, str(value)) for name, value in srcs]
        result = {'sources': dict(result)}
        return result


#=============================================================================
class BaseSourceListAggregator(BaseAggregator):
    """Aggregator over a fixed mapping of {name: index source}."""

    def __init__(self, sources, **kwargs):
        self.sources = sources

    def get_all_sources(self, params):
        return self.sources

    def _iter_sources(self, params):
        sources = self.get_all_sources(params)
        srcs_list = params.get('sources')
        if not srcs_list:
            return sources.items()

        # 'sources' param is a comma-separated whitelist of source names
        sel_sources = tuple(srcs_list.split(','))
        return [(name, sources[name]) for name in sources.keys()
                if name in sel_sources]


#=============================================================================
class SeqAggMixin(object):
    """Mixin that queries the child sources sequentially."""

    def __init__(self, *args, **kwargs):
        super(SeqAggMixin, self).__init__(*args, **kwargs)

    def _load_all(self, params):
        sources = self._iter_sources(params)
        return [self.load_child_source(name, source, params)
                for name, source in sources]


#=============================================================================
class SimpleAggregator(SeqAggMixin, BaseSourceListAggregator):
    pass


#=============================================================================
class TimeoutMixin(object):
    """Mixin that temporarily skips sources which time out repeatedly.

    A source is skipped once it has accumulated ``t_count`` timeouts
    within the last ``t_duration`` seconds.
    """

    def __init__(self, *args, **kwargs):
        super(TimeoutMixin, self).__init__(*args, **kwargs)
        self.t_count = kwargs.get('t_count', 3)
        self.t_dura = kwargs.get('t_duration', 20)
        # per-source deque of timestamps of recent timeouts
        self.timeouts = {}

    def is_timed_out(self, name):
        """Return True if ``name`` has timed out too often recently."""
        timeout_deq = self.timeouts.get(name)
        if not timeout_deq:
            return False

        the_time = time.time()
        # purge timeout records that fell outside the tracking window
        for t in list(timeout_deq):
            if (the_time - t) > self.t_dura:
                timeout_deq.popleft()

        if len(timeout_deq) >= self.t_count:
            print('Skipping {0}, {1} timeouts in {2} seconds'.
                  format(name, self.t_count, self.t_dura))
            return True

        return False

    def _iter_sources(self, params):
        sources = super(TimeoutMixin, self)._iter_sources(params)
        for name, source in sources:
            if not self.is_timed_out(name):
                yield name, source

    def _on_source_error(self, name):
        the_time = time.time()

        if name not in self.timeouts:
            self.timeouts[name] = deque()

        self.timeouts[name].append(the_time)
        print(name + ' timed out!')


#=============================================================================
class GeventMixin(object):
    """Mixin that queries all child sources concurrently via a gevent
    pool, bounded by a shared timeout.
    """

    def __init__(self, *args, **kwargs):
        super(GeventMixin, self).__init__(*args, **kwargs)
        self.pool = Pool(size=kwargs.get('size'))
        self.timeout = kwargs.get('timeout', 5.0)

    def _load_all(self, params):
        # advertise the timeout to child sources via params
        params['_timeout'] = self.timeout

        sources = list(self._iter_sources(params))

        def do_spawn(name, source):
            return self.pool.spawn(self.load_child_source,
                                   name, source, params)

        jobs = [do_spawn(name, source) for name, source in sources]

        gevent.joinall(jobs, timeout=self.timeout)

        results = []
        for (name, source), job in zip(sources, jobs):
            if job.value is not None:
                results.append(job.value)
            else:
                # job did not complete within the timeout window
                results.append((iter([]), [(name, 'timeout')]))
                self._on_source_error(name)

        return results


#=============================================================================
class GeventTimeoutAggregator(TimeoutMixin, GeventMixin,
                              BaseSourceListAggregator):
    pass


#=============================================================================
class BaseDirectoryIndexSource(BaseAggregator):
    """Aggregator exposing every .cdx/.cdxj file found under a
    (templated, glob-expandable) directory as a FileIndexSource.
    """

    CDX_EXT = ('.cdx', '.cdxj')

    def __init__(self, base_prefix, base_dir=''):
        self.base_prefix = base_prefix
        self.base_dir = base_dir

    def _iter_sources(self, params):
        the_dir = res_template(self.base_dir, params)
        the_dir = os.path.join(self.base_prefix, the_dir)
        try:
            sources = list(self._load_files(the_dir))
        except Exception:
            # any failure scanning the directory is reported as not found
            raise NotFoundException(the_dir)

        return sources

    def _load_files(self, glob_dir):
        """Yield (relative name, FileIndexSource) for each cdx file in
        every directory matching ``glob_dir``.
        """
        for the_dir in glob.iglob(glob_dir):
            for name in os.listdir(the_dir):
                filename = os.path.join(the_dir, name)

                if filename.endswith(self.CDX_EXT):
                    print('Adding ' + filename)
                    rel_path = os.path.relpath(the_dir, self.base_prefix)
                    if rel_path == '.':
                        full_name = name
                    else:
                        full_name = rel_path + '/' + name

                    yield full_name, FileIndexSource(filename)

    def __str__(self):
        return 'file_dir'


class DirectoryIndexSource(SeqAggMixin, BaseDirectoryIndexSource):
    pass