diff --git a/recorder/filters.py b/recorder/filters.py index 3635a5ab..b2ffc65f 100644 --- a/recorder/filters.py +++ b/recorder/filters.py @@ -1,4 +1,5 @@ from pywb.utils.timeutils import timestamp_to_datetime, datetime_to_iso_date +import re # ============================================================================ @@ -38,3 +39,43 @@ class WriteDupePolicy(object): def __call__(self, cdx): return 'write' + +# ============================================================================ +# Skip Record Filters +# ============================================================================ +class SkipNothingFilter(object): + def skip_request(self, req_headers): + return False + + def skip_response(self, req_headers, resp_headers): + return False + + +# ============================================================================ +class CollectionFilter(SkipNothingFilter): + def __init__(self, accept_colls): + self.rx_accept_colls = re.compile(accept_colls) + + def skip_request(self, req_headers): + if req_headers.get('Recorder-Skip') == '1': + return True + + return False + + def skip_response(self, req_headers, resp_headers): + if not self.rx_accept_colls.match(resp_headers.get('WebAgg-Source-Coll', '')): + return True + + return False + + +# ============================================================================ +class SkipRangeRequestFilter(SkipNothingFilter): + def skip_request(self, req_headers): + range_ = req_headers.get('Range') + if range_ and not range_.lower().startswith('bytes=0-'): + return True + + return False + + diff --git a/recorder/recorderapp.py b/recorder/recorderapp.py index 818b0eed..e567a95d 100644 --- a/recorder/recorderapp.py +++ b/recorder/recorderapp.py @@ -1,16 +1,16 @@ -#from gevent import monkey; monkey.patch_all() -from webagg.utils import ReadFullyStream, StreamIter +from webagg.utils import StreamIter, chunk_encode_iter, BUFF_SIZE from webagg.inputrequest import DirectWSGIInputRequest from pywb.utils.statusandheaders import StatusAndHeadersParser from pywb.warc.recordloader import ArcWarcRecord from pywb.warc.recordloader import ArcWarcRecordLoader +from recorder.filters import SkipRangeRequestFilter, CollectionFilter + from six.moves.urllib.parse import parse_qsl import json import tempfile -import re from requests.structures import CaseInsensitiveDict import requests @@ -23,7 +23,7 @@ import gevent #============================================================================== class RecorderApp(object): - def __init__(self, upstream_host, writer, accept_colls='.*'): + def __init__(self, upstream_host, writer, skip_filters=None, **kwargs): self.upstream_host = upstream_host self.writer = writer @@ -32,7 +32,19 @@ class RecorderApp(object): self.write_queue = gevent.queue.Queue() gevent.spawn(self._write_loop) - self.rx_accept_colls = re.compile(accept_colls) + if not skip_filters: + skip_filters = self.create_default_filters(kwargs) + + self.skip_filters = skip_filters + + def create_default_filters(self, kwargs): + skip_filters = [SkipRangeRequestFilter()] + + accept_colls = kwargs.get('accept_colls') + if accept_colls: + skip_filters.append(CollectionFilter(accept_colls)) + + return skip_filters def _write_loop(self): while True: @@ -49,9 +61,6 @@ class RecorderApp(object): req_head, req_pay, resp_head, resp_pay, params = result - if not self.rx_accept_colls.match(resp_head.get('WebAgg-Source-Coll', '')): - return - req = self._create_req_record(req_head, req_pay, 'request') resp = self._create_resp_record(resp_head, resp_pay, 'response') @@ -109,7 +118,13 @@ class RecorderApp(object): params = dict(parse_qsl(environ.get('QUERY_STRING'))) - req_stream = ReqWrapper(input_buff, headers) + skipping = any(x.skip_request(headers) for x in self.skip_filters) + + if not skipping: + req_stream = ReqWrapper(input_buff, headers) + else: + req_stream = input_buff + data = None if input_buff: data = req_stream @@ -121,15 +136,29 @@ class RecorderApp(object): headers=headers, allow_redirects=False, stream=True) + res.raise_for_status() except Exception as e: - traceback.print_exc() + #traceback.print_exc() return self.send_error(e, start_response) start_response('200 OK', list(res.headers.items())) - resp_stream = RespWrapper(res.raw, res.headers, req_stream, params, self.write_queue) + if not skipping: + resp_stream = RespWrapper(res.raw, + res.headers, + req_stream, + params, + self.write_queue, + self.skip_filters) + else: + resp_stream = res.raw - return StreamIter(ReadFullyStream(resp_stream)) + resp_iter = StreamIter(resp_stream) + + if res.headers.get('Transfer-Encoding') == 'chunked': + resp_iter = chunk_encode_iter(resp_iter) + + return resp_iter #============================================================================== @@ -137,12 +166,19 @@ class Wrapper(object): def __init__(self, stream): self.stream = stream self.out = self._create_buffer() + self.interrupted = False def _create_buffer(self): return tempfile.SpooledTemporaryFile(max_size=512*1024) - def read(self, limit=-1): - buff = self.stream.read() + def read(self, *args, **kwargs): + try: + buff = self.stream.read(*args, **kwargs) + except Exception as e: + print('INTERRUPT READ') + self.interrupted = True + raise + self.out.write(buff) return buff @@ -151,32 +187,53 @@ class Wrapper(object): self.stream.close() except: traceback.print_exc() - finally: - self._after_close() - - def _after_close(self): - pass #============================================================================== class RespWrapper(Wrapper): def __init__(self, stream, headers, req, - params, queue): + params, queue, skip_filters): super(RespWrapper, self).__init__(stream) self.headers = headers self.req = req self.params = params self.queue = queue + self.skip_filters = skip_filters - def _after_close(self): - if not self.req: + def close(self): + try: + while True: + if not self.read(BUFF_SIZE): + break + + except Exception as e: + print(e) + self.interrupted = True + + finally: + try: + self.stream.close() + except Exception as e: + traceback.print_exc() + + self._write_to_file() + + def _write_to_file(self): + skipping = any(x.skip_response(self.req.headers, self.headers) + for x in self.skip_filters) + + if self.interrupted or skipping: + self.out.close() + self.req.out.close() + self.req.close() return try: entry = (self.req.headers, self.req.out, self.headers, self.out, self.params) self.queue.put(entry) + self.req.close() self.req = None except: traceback.print_exc() diff --git a/recorder/test/test_recorder.py b/recorder/test/test_recorder.py index d83f3375..08d400f8 100644 --- a/recorder/test/test_recorder.py +++ b/recorder/test/test_recorder.py @@ -61,8 +61,9 @@ class TestRecorder(LiveServerTests, TempDirTests, BaseTestClass): req_url = '/live/resource/postreq?url=' + url + other_params testapp = webtest.TestApp(recorder_app) resp = testapp.post(req_url, general_req_data.format(host=host, path=path).encode('utf-8')) - #gevent.sleep(0.1) - recorder_app._write_one() + + if not recorder_app.write_queue.empty(): + recorder_app._write_one() assert resp.headers['WebAgg-Source-Coll'] == 'live' diff --git a/recorder/warcwriter.py b/recorder/warcwriter.py index 99c93e06..aee12072 100644 --- a/recorder/warcwriter.py +++ b/recorder/warcwriter.py @@ -91,7 +91,7 @@ class BaseWARCWriter(object): req.rec_headers['WARC-Target-Uri'] = url req.rec_headers['WARC-Date'] = dt req.rec_headers['WARC-Type'] = 'request' - req.rec_headers['Content-Type'] = req.content_type + #req.rec_headers['Content-Type'] = req.content_type resp_id = resp.rec_headers.get('WARC-Record-ID') if resp_id: @@ -142,6 +142,9 @@ class BaseWARCWriter(object): self._line(out, b'WARC/1.0') for n, v in six.iteritems(record.rec_headers): + if n.lower() in ('content-length', 'content-type'): + continue + self._header(out, n, v) content_type = record.content_type