def raise_on_self_redirect(self, params, cdx, status_code, location_url):
    """
    Check if response is a 3xx redirect to the same url
    If so, reject this capture to avoid causing redirect loop

    The comparison is case-insensitive and ignores the scheme and a
    trailing slash, so ``http://host/path`` redirecting to
    ``https://host/path/`` is also treated as a self redirect.

    :param params: request params; ``params['url']`` is the requested url
    :param cdx: index entry for the capture; ``cdx['url']`` supplies the
        host used to resolve a host-relative ``Location``
    :param status_code: response status code as a string (eg. ``'302'``)
    :param location_url: value of the ``Location`` header, may be empty
    :raises WbException: if the redirect target is the requested url
    """
    # only 3xx responses are checked; 304 is explicitly excluded
    if not status_code or not status_code.startswith('3'):
        return

    if status_code == '304' or not location_url:
        return

    def comparable(url):
        # drop a leading 'scheme://' (only when it really is a prefix,
        # not a '://' buried in the query) and any trailing slash
        scheme, sep, rest = url.partition('://')
        if sep and scheme.isalnum():
            url = rest
        return url.rstrip('/')

    request_url = params['url'].lower()
    target_url = location_url.lower()

    if target_url.startswith('//'):
        # scheme-relative: host is already present, scheme is ignored
        target_url = target_url[2:]
    elif target_url.startswith('/'):
        # host-relative: resolve against the host of the capture
        host = urlsplit(cdx['url']).netloc.lower()
        target_url = host + target_url

    if comparable(request_url) == comparable(target_url):
        msg = 'Self Redirect {0} -> {1}'
        raise WbException(msg.format(request_url, target_url))
res.decode('utf-8') @@ -117,6 +142,9 @@ class WARCPathLoader(BaseLoader): self.resolve_loader = ResolvingLoader(self.resolvers, no_record_parse=True) + + self.headers_parser = StatusAndHeadersParser([], verify=False) + self.cdx_source = cdx_source def cdx_index_source(self, *args, **kwargs): @@ -140,12 +168,23 @@ class WARCPathLoader(BaseLoader): if not cdx.get('filename') or cdx.get('offset') is None: return None - cdx._formatter = params.get('_formatter') + cdx._formatter = ParamFormatter(params, cdx.get('source')) + failed_files = [] headers, payload = (self.resolve_loader. load_headers_and_payload(cdx, failed_files, self.cdx_index_source)) + + if cdx.get('status', '').startswith('3'): + status_headers = self.headers_parser.parse(payload.stream) + self.raise_on_self_redirect(params, cdx, + status_headers.get_statuscode(), + status_headers.get_header('Location')) + http_headers_buff = status_headers.to_bytes() + else: + http_headers_buff = None + warc_headers = payload.rec_headers if headers != payload: @@ -163,7 +202,7 @@ class WARCPathLoader(BaseLoader): headers.stream.close() - return (warc_headers, None, payload.stream) + return (warc_headers, http_headers_buff, payload.stream) def __str__(self): return 'WARCPathLoader' @@ -184,8 +223,6 @@ class LiveWebLoader(BaseLoader): if not load_url: return None - #recorder = HeaderRecorder(self.SKIP_HEADERS) - input_req = params['_input_req'] req_headers = input_req.get_req_headers() @@ -195,13 +232,6 @@ class LiveWebLoader(BaseLoader): if cdx.get('memento_url'): req_headers['Accept-Datetime'] = datetime_to_http_date(dt) - # if different url, ensure origin is not set - # may need to add other headers - if load_url != cdx['url']: - if 'Origin' in req_headers: - splits = urlsplit(load_url) - req_headers['Origin'] = splits.scheme + '://' + splits.netloc - method = input_req.get_req_method() data = input_req.get_req_body() @@ -230,6 +260,11 @@ class LiveWebLoader(BaseLoader): cdx['source'] = 
upstream_res.headers.get('WebAgg-Source-Coll') return None, upstream_res.headers, upstream_res.raw + self.raise_on_self_redirect(params, cdx, + str(upstream_res.status_code), + upstream_res.headers.get('Location')) + + if upstream_res.raw.version == 11: version = '1.1' else: