diff --git a/webagg/aggregator.py b/webagg/aggregator.py
index cb7cf10a..73560ae0 100644
--- a/webagg/aggregator.py
+++ b/webagg/aggregator.py
@@ -15,7 +15,7 @@ from heapq import merge
 from collections import deque
 from itertools import chain
 
-from webagg.indexsource import FileIndexSource
+from webagg.indexsource import FileIndexSource, RedisIndexSource
 from pywb.utils.wbexception import NotFoundException, WbException
 
 from webagg.utils import ParamFormatter, res_template
@@ -51,14 +51,14 @@ class BaseAggregator(object):
             cdx_iter = iter([])
             err_list = [(name, repr(wbe))]
 
-        def add_name(cdx):
+        def add_name(cdx, name):
             if cdx.get('source'):
                 cdx['source'] = name + ':' + cdx['source']
             else:
                 cdx['source'] = name
             return cdx
 
-        return (add_name(cdx) for cdx in cdx_iter), err_list
+        return (add_name(cdx, name) for cdx in cdx_iter), err_list
 
     def load_index(self, params):
         res_list = self._load_all(params)
@@ -271,3 +271,27 @@ class CacheDirectoryIndexSource(DirectoryIndexSource):
         files = list(files)
         self.cached_file_list[the_dir] = (stat, files)
         return files
+
+
+#=============================================================================
+class RedisMultiKeyIndexSource(SeqAggMixin, BaseAggregator, RedisIndexSource):
+    """Aggregate over every redis key matching the templated key pattern.
+
+    Each matching key is wrapped in its own RedisIndexSource that shares
+    this instance's redis client.
+    """
+
+    def _iter_sources2(self, params):
+        """Yield a ('', RedisIndexSource) pair for each matching redis key."""
+        redis_key_pattern = res_template(self.redis_key_template, params)
+
+        for key in self.redis.scan_iter(match=redis_key_pattern):
+            # scan_iter() yields bytes unless the client was created with
+            # decode_responses=True, in which case keys are already str
+            if isinstance(key, bytes):
+                key = key.decode('utf-8')
+            yield '', RedisIndexSource(None, self.redis, key)
+
+    def _iter_sources(self, params):
+        """Return the sources from _iter_sources2() as a list."""
+        return list(self._iter_sources2(params))
diff --git a/webagg/indexsource.py b/webagg/indexsource.py
index b37604ba..afb500e6 100644
--- a/webagg/indexsource.py
+++ b/webagg/indexsource.py
@@ -103,19 +103,39 @@ class LiveIndexSource(BaseIndexSource):
 
 #=============================================================================
 class RedisIndexSource(BaseIndexSource):
-    def __init__(self, redis_url):
+    def __init__(self, redis_url, redis=None, key_prefix=None):
+        """Use the given redis client, or build one from redis_url.
+
+        redis_url is only parsed when no client is passed in; parsing it
+        also replaces key_prefix with the prefix found in the url.
+        """
+        if redis_url and not redis:
+            redis, key_prefix = self.parse_redis_url(redis_url)
+
+        # keep the attribute the previous __init__ exposed; it now holds the
+        # url exactly as passed in (any key prefix is not stripped)
+        self.redis_url = redis_url
+        self.redis = redis
+        self.redis_key_template = key_prefix
+
+    @staticmethod
+    def parse_redis_url(redis_url):
+        """Split off an optional key prefix; return (client, prefix)."""
         parts = redis_url.split('/')
         key_prefix = ''
         if len(parts) > 4:
             key_prefix = parts[4]
             redis_url = 'redis://' + parts[2] + '/' + parts[3]
 
-        self.redis_url = redis_url
-        self.redis_key_template = key_prefix
-        self.redis = redis.StrictRedis.from_url(redis_url)
+        red = redis.StrictRedis.from_url(redis_url)
+        return red, key_prefix
 
     def load_index(self, params):
-        z_key = res_template(self.redis_key_template, params)
+        return self.load_key_index(self.redis_key_template, params)
+
+    def load_key_index(self, key_template, params):
+        """Like load_index(), but with an explicit key template."""
+        z_key = res_template(key_template, params)
         index_list = self.redis.zrangebylex(z_key,
                                             b'[' + params['key'],
                                             b'(' + params['end_key'])