from gevent.pool import Pool
import gevent
import json
import time
import os

from pywb.cdx.cdxops import process_cdx
from pywb.cdx.query import CDXQuery

from heapq import merge
from collections import deque

from rezag.indexsource import FileIndexSource
from pywb.utils.wbexception import NotFoundException

import six
import glob


#=============================================================================
class BaseAggregator(object):
    """Base class that merges the CDX iterators of several child sources.

    Subclasses provide ``get_sources(params)`` (yielding ``(name, source)``
    pairs) and ``_load_all(params)`` (returning one iterator per source).
    """

    def __call__(self, params):
        """Run a query described by ``params`` and return the processed
        cdx iterator.

        ``params`` is modified in place: ``_all_src_params`` is added by
        ``_set_src_params``. A ``NotFoundException`` raised while loading
        is treated as an empty result.
        """
        query = CDXQuery(params)

        self._set_src_params(params)

        try:
            cdx_iter = self.load_index(query.params)
        except NotFoundException:
            cdx_iter = iter([])

        return process_cdx(cdx_iter, query)

    def _set_src_params(self, params):
        """Collect ``param.*`` entries of ``params`` into a nested dict and
        store it as ``params['_all_src_params']``.

        ``param.<src>.<name>`` is stored under ``[<src>][<name>]``;
        ``param.<name>`` is stored under the default source ``''``.
        """
        src_params = {}
        for param, value in six.iteritems(params):
            if not param.startswith('param.'):
                continue

            parts = param.split('.', 3)[1:]

            if len(parts) == 2:
                src, name = parts
            else:
                # NOTE(review): a key with more than two dots after the
                # prefix (e.g. 'param.a.b.c') also lands here and is stored
                # under the default source with name 'a' -- confirm whether
                # maxsplit was meant to be 2.
                src = ''
                name = parts[0]

            src_params.setdefault(src, {})[name] = value

        params['_all_src_params'] = src_params

    def load_child_source(self, name, source, all_params):
        """Load the index of one child ``source`` and tag each cdx entry
        with the source ``name``.

        The child receives its own shallow copy of ``all_params`` with
        ``_src_params`` set to the parameters addressed to ``name``, so
        sibling sources (which may be loaded concurrently or lazily) never
        see each other's ``_src_params``.

        Returns a generator of cdx entries; it is empty if the source
        raises ``NotFoundException``.
        """
        all_src_params = all_params.get('_all_src_params') or {}

        params = dict(all_params)
        params['_src_params'] = all_src_params.get(name)

        try:
            cdx_iter = source.load_index(params)
        except NotFoundException:
            print('Not found in ' + name)
            cdx_iter = iter([])

        def add_name(cdx_iter):
            # Prefix an existing 'source' value so nested aggregators
            # produce dotted paths such as 'outer.inner'.
            for cdx in cdx_iter:
                if 'source' in cdx:
                    cdx['source'] = name + '.' + cdx['source']
                else:
                    cdx['source'] = name
                yield cdx

        return add_name(cdx_iter)

    def load_index(self, params):
        """Return a single iterator over the entries of all child sources,
        combined with ``heapq.merge``.
        """
        iter_list = list(self._load_all(params))

        # optimization: if only a single entry (or empty) just load directly
        if len(iter_list) <= 1:
            cdx_iter = iter_list[0] if iter_list else iter([])
        else:
            cdx_iter = merge(*iter_list)

        return cdx_iter

    def _load_all(self, params):  #pragma: no cover
        """Return one cdx iterator per source; must be overridden."""
        raise NotImplementedError()

    def get_sources(self, params):  #pragma: no cover
        """Return ``(name, source)`` pairs; must be overridden."""
        raise NotImplementedError()


#=============================================================================
class BaseSourceListAggregator(BaseAggregator):
    """Aggregator over a fixed mapping of source name to source object."""

    def __init__(self, sources, **kwargs):
        self.sources = sources

    def get_all_sources(self, params):
        """Return the full name -> source mapping."""
        return self.sources

    def get_sources(self, params):
        """Return ``(name, source)`` pairs, optionally restricted to the
        comma-separated names given in ``params['sources']``.
        """
        sources = self.get_all_sources(params)
        srcs_list = params.get('sources')
        if not srcs_list:
            return sources.items()

        sel_sources = tuple(srcs_list.split(','))

        return [(name, source) for name, source in sources.items()
                if name in sel_sources]


#=============================================================================
class SeqAggMixin(object):
    """Mixin that loads the child sources one after another."""

    def __init__(self, *args, **kwargs):
        super(SeqAggMixin, self).__init__(*args, **kwargs)

    def _load_all(self, params):
        # Materialize the source list before loading any of them.
        sources = list(self.get_sources(params))
        return [self.load_child_source(name, source, params)
                for name, source in sources]


#=============================================================================
class SimpleAggregator(SeqAggMixin, BaseSourceListAggregator):
    pass


#=============================================================================
class TimeoutMixin(object):
    """Mixin that skips a source after repeated recent errors.

    A source is skipped once it has ``t_count`` recorded errors (default 3)
    that are each no older than ``t_duration`` seconds (default 20).
    """

    def __init__(self, *args, **kwargs):
        super(TimeoutMixin, self).__init__(*args, **kwargs)
        self.t_count = kwargs.get('t_count', 3)
        self.t_dura = kwargs.get('t_duration', 20)
        # source name -> deque of error timestamps, oldest first
        self.timeouts = {}

    def is_timed_out(self, name):
        """Return True if ``name`` should currently be skipped."""
        timeout_deq = self.timeouts.get(name)
        if not timeout_deq:
            return False

        the_time = time.time()

        # Timestamps are appended in order, so expired ones sit at the left.
        while timeout_deq and (the_time - timeout_deq[0]) > self.t_dura:
            timeout_deq.popleft()

        if len(timeout_deq) >= self.t_count:
            print('Skipping {0}, {1} timeouts in {2} seconds'.
                  format(name, self.t_count, self.t_dura))
            return True

        return False

    def get_sources(self, params):
        """Yield the ``(name, source)`` pairs of the parent class that are
        not currently timed out.
        """
        sources = super(TimeoutMixin, self).get_sources(params)

        for name, source in sources:
            if not self.is_timed_out(name):
                yield name, source

    def track_source_error(self, name):
        """Record an error for source ``name`` at the current time."""
        the_time = time.time()
        self.timeouts.setdefault(name, deque()).append(the_time)
        print(name + ' timed out!')


#=============================================================================
class GeventAggMixin(object):
    """Mixin that loads all child sources concurrently in a gevent pool.

    Keyword arguments: ``size`` (pool size, passed to ``Pool``) and
    ``timeout`` (seconds to wait for all sources, default 5.0).
    """

    def __init__(self, *args, **kwargs):
        super(GeventAggMixin, self).__init__(*args, **kwargs)
        self.pool = Pool(size=kwargs.get('size'))
        self.timeout = kwargs.get('timeout', 5.0)

    def track_source_error(self, name):
        """Hook called with the name of each source that produced no
        result; no-op by default.
        """
        pass

    def _load_all(self, params):
        """Spawn one greenlet per source, wait up to ``self.timeout``
        seconds, and return the iterators of the sources that finished.

        Sources without a result are reported to ``track_source_error``
        by name.
        """
        params['_timeout'] = self.timeout

        sources = list(self.get_sources(params))

        jobs = [self.pool.spawn(self.load_child_source, name, source, params)
                for name, source in sources]

        gevent.joinall(jobs, timeout=self.timeout)

        res = []
        # ``sources`` holds (name, source) pairs; unpack so that only the
        # name string is reported on failure.
        for (name, _source), job in zip(sources, jobs):
            # load_child_source always returns a generator, so a value of
            # None means the greenlet did not finish or raised.
            if job.value is not None:
                res.append(job.value)
            else:
                self.track_source_error(name)

        return res


#=============================================================================
class GeventTimeoutAggregator(TimeoutMixin, GeventAggMixin, BaseSourceListAggregator):
    pass


#=============================================================================
class BaseDirectoryIndexAggregator(BaseAggregator):
    """Aggregator whose sources are the CDX files found in directories
    matching ``base_prefix`` joined with ``base_dir``.

    ``base_dir`` may contain ``str.format`` placeholders filled from the
    source params, and glob wildcards.
    """

    CDX_EXT = ('.cdx', '.cdxj')

    def __init__(self, base_prefix, base_dir):
        self.base_prefix = base_prefix
        self.base_dir = base_dir

    def get_sources(self, params):
        """Return a list of ``(relative_dir, FileIndexSource)`` pairs.

        Raises ``NotFoundException`` if listing the directories fails.
        """
        # see if specific params (when part of another agg)
        src_params = params.get('_src_params')
        if not src_params:
            # try default param. settings
            src_params = params.get('_all_src_params', {}).get('')

        if src_params:
            the_dir = self.base_dir.format(**src_params)
        else:
            the_dir = self.base_dir

        the_dir = os.path.join(self.base_prefix, the_dir)

        try:
            sources = list(self._load_files(the_dir))
        except Exception:
            raise NotFoundException(the_dir)

        return sources

    def _load_files(self, glob_dir):
        """Yield ``(relative_dir, FileIndexSource)`` for every file ending
        in one of ``CDX_EXT`` inside each directory matching ``glob_dir``.
        """
        for the_dir in glob.iglob(glob_dir):
            print(the_dir)
            for name in os.listdir(the_dir):
                filename = os.path.join(the_dir, name)

                if filename.endswith(self.CDX_EXT):
                    print('Adding ' + filename)
                    rel_path = os.path.relpath(the_dir, self.base_prefix)
                    yield rel_path, FileIndexSource(filename)


class DirectoryIndexAggregator(SeqAggMixin, BaseDirectoryIndexAggregator):
    pass