class CDXFilter(object):
    """Callable predicate built from one cdx filter expression.

    Expression syntax, as parsed by ``__init__``::

        [!][=|~][field:]pattern

    * a leading ``!`` inverts the result of the comparison
    * ``=`` selects exact string equality and ``~`` selects substring
      containment; with neither prefix the pattern is compiled as a
      regular expression and applied with ``match`` (anchored at the
      start of the value)
    * ``field:`` restricts the comparison to ``cdx.get(field, '')``; the
      name is looked up in ``CDXObject.CDX_ALT_FIELDS`` and used as given
      if it is not found there.  Without a field the comparison is made
      against ``str(cdx)``.

    The expression is split on its first ``:`` in every mode, so a
    pattern containing ``:`` is always read as ``field:pattern``.
    """

    def __init__(self, string):
        expr = string

        # optional negation marker
        negate = expr.startswith('!')
        if negate:
            expr = expr[1:]
        self.invert = negate

        # optional comparison-mode marker; regex is the fallback
        if expr.startswith('='):
            comparer, expr = self.exact, expr[1:]
        elif expr.startswith('~'):
            comparer, expr = self.contains, expr[1:]
        else:
            comparer = self.rx_match
        self.compare_func = comparer

        # optional "field:" prefix; an empty field means "whole cdx"
        name, colon, remainder = expr.partition(':')
        if colon:
            self.field = CDXObject.CDX_ALT_FIELDS.get(name, name)
            expr = remainder
        else:
            self.field = ''

        # only the attribute needed by the selected mode is set
        if comparer == self.rx_match:
            self.regex = re.compile(expr)
        else:
            self.filter_str = expr

    def __call__(self, cdx):
        """Return True if ``cdx`` passes this filter (after inversion)."""
        target = cdx.get(self.field, '') if self.field else cdx
        return self.compare_func(str(target)) ^ self.invert

    def exact(self, val):
        """True when ``val`` equals the filter string."""
        return self.filter_str == val

    def contains(self, val):
        """True when the filter string occurs anywhere in ``val``."""
        return self.filter_str in val

    def rx_match(self, val):
        """True when the compiled regex matches at the start of ``val``."""
        return self.regex.match(val) is not None
def load_index(self, params):
    """Return an iterator over one synthetic cdx entry for a live url.

    The entry carries ``urlkey`` (decoded from ``params['key']``),
    ``timestamp`` (from ``timestamp_now()``), ``url``, ``load_url``
    (``self.proxy_url`` rendered through ``res_template``),
    ``is_live`` set to ``'true'`` and ``mime``.

    When a ``filter`` was requested and the caller supplied no
    ``content_type``, a best-effort HEAD request is made so that
    ``status`` and ``mime`` can be filled in for filtering.  Any failure
    of that probe is logged and otherwise ignored.

    :param params: query parameters; ``url`` and ``key`` are required,
        ``is_fuzzy``, ``content_type``, ``filter`` and ``_timeout`` are
        optional.
    :return: iterator yielding exactly one ``CDXObject``.
    :raises NotFoundException: if ``params['is_fuzzy']`` is set.
    """
    # local import so the block does not depend on file-level imports
    import logging

    # no fuzzy match for live resources
    if params.get('is_fuzzy'):
        raise NotFoundException(params['url'] + '*')

    cdx = CDXObject()
    cdx['urlkey'] = params.get('key').decode('utf-8')
    cdx['timestamp'] = timestamp_now()
    cdx['url'] = params['url']
    cdx['load_url'] = res_template(self.proxy_url, params)
    cdx['is_live'] = 'true'

    mime = params.get('content_type', '')

    if params.get('filter') and not mime:
        try:
            # bound the probe with the same '_timeout' param used for
            # other remote lookups, so a stalled server cannot hang the
            # query (None keeps the previous no-timeout behavior)
            # NOTE(review): assumes self.sesh is a requests-style session
            res = self.sesh.head(cdx['url'],
                                 timeout=params.get('_timeout'))

            # 405 means HEAD itself is not allowed, so the status says
            # nothing about the resource
            if res.status_code != 405:
                cdx['status'] = str(res.status_code)

            content_type = res.headers.get('Content-Type')
            if content_type:
                # drop parameters such as '; charset=...' and any
                # whitespace left around the media type
                mime = content_type.split(';')[0].strip()

        except Exception:
            # deliberate best-effort: the entry is still returned, just
            # without status/mime, but the failure is no longer silent
            logging.getLogger(__name__).debug(
                'HEAD probe failed for %s', cdx['url'], exc_info=True)

    cdx['mime'] = mime

    return iter([cdx])