From d38bb5a1fd317c3252e26a2f71168c7549bc18d9 Mon Sep 17 00:00:00 2001 From: Ilya Kreymer Date: Mon, 21 Mar 2016 11:47:12 -0700 Subject: [PATCH] filters: add extensible 'skip filters', with default filters to accept certain collections, filter out recording of range requests. Opportunity to skip recording at request or response time RespWrapper handles reading stream fully on close() (no need for old ReadFullyStream), skips recording if read was interrupted/incomplete writer: avoiding writing duplicate content-length/content-type headers --- recorder/filters.py | 41 +++++++++++++ recorder/recorderapp.py | 101 ++++++++++++++++++++++++++------- recorder/test/test_recorder.py | 5 +- recorder/warcwriter.py | 5 +- 4 files changed, 127 insertions(+), 25 deletions(-) diff --git a/recorder/filters.py b/recorder/filters.py index 3635a5ab..b2ffc65f 100644 --- a/recorder/filters.py +++ b/recorder/filters.py @@ -1,4 +1,5 @@ from pywb.utils.timeutils import timestamp_to_datetime, datetime_to_iso_date +import re # ============================================================================ @@ -38,3 +39,43 @@ class WriteDupePolicy(object): def __call__(self, cdx): return 'write' + +# ============================================================================ +# Skip Record Filters +# ============================================================================ +class SkipNothingFilter(object): + def skip_request(self, req_headers): + return False + + def skip_response(self, req_headers, resp_headers): + return False + + +# ============================================================================ +class CollectionFilter(SkipNothingFilter): + def __init__(self, accept_colls): + self.rx_accept_colls = re.compile(accept_colls) + + def skip_request(self, req_headers): + if req_headers.get('Recorder-Skip') == '1': + return True + + return False + + def skip_response(self, req_headers, resp_headers): + if not self.rx_accept_colls.match(resp_headers.get('WebAgg-Source-Coll', '')): + return True + + return False + + +# ============================================================================ +class SkipRangeRequestFilter(SkipNothingFilter): + def skip_request(self, req_headers): + range_ = req_headers.get('Range') + if range_ and not range_.lower().startswith('bytes=0-'): + return True + + return False + + diff --git a/recorder/recorderapp.py b/recorder/recorderapp.py index 818b0eed..e567a95d 100644 --- a/recorder/recorderapp.py +++ b/recorder/recorderapp.py @@ -1,16 +1,16 @@ -#from gevent import monkey; monkey.patch_all() -from webagg.utils import ReadFullyStream, StreamIter +from webagg.utils import StreamIter, chunk_encode_iter, BUFF_SIZE from webagg.inputrequest import DirectWSGIInputRequest from pywb.utils.statusandheaders import StatusAndHeadersParser from pywb.warc.recordloader import ArcWarcRecord from pywb.warc.recordloader import ArcWarcRecordLoader +from recorder.filters import SkipRangeRequestFilter, CollectionFilter + from six.moves.urllib.parse import parse_qsl import json import tempfile -import re from requests.structures import CaseInsensitiveDict import requests @@ -23,7 +23,7 @@ import gevent #============================================================================== class RecorderApp(object): - def __init__(self, upstream_host, writer, accept_colls='.*'): + def __init__(self, upstream_host, writer, skip_filters=None, **kwargs): self.upstream_host = upstream_host self.writer = writer @@ -32,7 +32,19 @@ class RecorderApp(object): self.write_queue = gevent.queue.Queue() gevent.spawn(self._write_loop) - self.rx_accept_colls = re.compile(accept_colls) + if not skip_filters: + skip_filters = self.create_default_filters(kwargs) + + self.skip_filters = skip_filters + + def create_default_filters(self, kwargs): + skip_filters = [SkipRangeRequestFilter()] + + accept_colls = kwargs.get('accept_colls') + if accept_colls: + skip_filters.append(CollectionFilter(accept_colls)) + + return skip_filters def _write_loop(self): while True: @@ -49,9 +61,6 @@ class RecorderApp(object): req_head, req_pay, resp_head, resp_pay, params = result - if not self.rx_accept_colls.match(resp_head.get('WebAgg-Source-Coll', '')): - return - req = self._create_req_record(req_head, req_pay, 'request') resp = self._create_resp_record(resp_head, resp_pay, 'response') @@ -109,7 +118,13 @@ class RecorderApp(object): params = dict(parse_qsl(environ.get('QUERY_STRING'))) - req_stream = ReqWrapper(input_buff, headers) + skipping = any(x.skip_request(headers) for x in self.skip_filters) + + if not skipping: + req_stream = ReqWrapper(input_buff, headers) + else: + req_stream = input_buff + data = None if input_buff: data = req_stream @@ -121,15 +136,29 @@ class RecorderApp(object): headers=headers, allow_redirects=False, stream=True) + res.raise_for_status() except Exception as e: - traceback.print_exc() + #traceback.print_exc() return self.send_error(e, start_response) start_response('200 OK', list(res.headers.items())) - resp_stream = RespWrapper(res.raw, res.headers, req_stream, params, self.write_queue) + if not skipping: + resp_stream = RespWrapper(res.raw, + res.headers, + req_stream, + params, + self.write_queue, + self.skip_filters) + else: + resp_stream = res.raw - return StreamIter(ReadFullyStream(resp_stream)) + resp_iter = StreamIter(resp_stream) + + if res.headers.get('Transfer-Encoding') == 'chunked': + resp_iter = chunk_encode_iter(resp_iter) + + return resp_iter #============================================================================== @@ -137,12 +166,19 @@ class Wrapper(object): def __init__(self, stream): self.stream = stream self.out = self._create_buffer() + self.interrupted = False def _create_buffer(self): return tempfile.SpooledTemporaryFile(max_size=512*1024) - def read(self, limit=-1): - buff = self.stream.read() + def read(self, *args, **kwargs): + try: + buff = self.stream.read(*args, **kwargs) + except Exception as e: + print('INTERRUPT READ') + self.interrupted = True + raise + self.out.write(buff) return buff @@ -151,32 +187,53 @@ class Wrapper(object): self.stream.close() except: traceback.print_exc() - finally: - self._after_close() - - def _after_close(self): - pass #============================================================================== class RespWrapper(Wrapper): def __init__(self, stream, headers, req, - params, queue): + params, queue, skip_filters): super(RespWrapper, self).__init__(stream) self.headers = headers self.req = req self.params = params self.queue = queue + self.skip_filters = skip_filters - def _after_close(self): - if not self.req: + def close(self): + try: + while True: + if not self.read(BUFF_SIZE): + break + + except Exception as e: + print(e) + self.interrupted = True + + finally: + try: + self.stream.close() + except Exception as e: + traceback.print_exc() + + self._write_to_file() + + def _write_to_file(self): + skipping = any(x.skip_response(self.req.headers, self.headers) + for x in self.skip_filters) + + if self.interrupted or skipping: + self.out.close() + self.req.out.close() + self.req.close() return try: entry = (self.req.headers, self.req.out, self.headers, self.out, self.params) self.queue.put(entry) + self.req.close() self.req = None except: traceback.print_exc() diff --git a/recorder/test/test_recorder.py b/recorder/test/test_recorder.py index d83f3375..08d400f8 100644 --- a/recorder/test/test_recorder.py +++ b/recorder/test/test_recorder.py @@ -61,8 +61,9 @@ class TestRecorder(LiveServerTests, TempDirTests, BaseTestClass): req_url = '/live/resource/postreq?url=' + url + other_params testapp = webtest.TestApp(recorder_app) resp = testapp.post(req_url, general_req_data.format(host=host, path=path).encode('utf-8')) - #gevent.sleep(0.1) - recorder_app._write_one() + + if not recorder_app.write_queue.empty(): + recorder_app._write_one() assert resp.headers['WebAgg-Source-Coll'] == 'live' diff --git a/recorder/warcwriter.py b/recorder/warcwriter.py index 99c93e06..aee12072 100644 --- a/recorder/warcwriter.py +++ b/recorder/warcwriter.py @@ -91,7 +91,7 @@ class BaseWARCWriter(object): req.rec_headers['WARC-Target-Uri'] = url req.rec_headers['WARC-Date'] = dt req.rec_headers['WARC-Type'] = 'request' - req.rec_headers['Content-Type'] = req.content_type + #req.rec_headers['Content-Type'] = req.content_type resp_id = resp.rec_headers.get('WARC-Record-ID') if resp_id: @@ -142,6 +142,9 @@ class BaseWARCWriter(object): self._line(out, b'WARC/1.0') for n, v in six.iteritems(record.rec_headers): + if n.lower() in ('content-length', 'content-type'): + continue + self._header(out, n, v) content_type = record.content_type