from webagg.utils import StreamIter, chunk_encode_iter, BUFF_SIZE
from webagg.inputrequest import DirectWSGIInputRequest

from pywb.utils.statusandheaders import StatusAndHeadersParser
from pywb.warc.recordloader import ArcWarcRecord
from pywb.warc.recordloader import ArcWarcRecordLoader

from recorder.filters import SkipRangeRequestFilter, CollectionFilter

from six.moves.urllib.parse import parse_qsl

import json
import tempfile

from requests.structures import CaseInsensitiveDict
import requests

import traceback

import gevent.queue
import gevent


#==============================================================================
class RecorderApp(object):
    """WSGI application that forwards requests upstream and records them.

    Each incoming request is sent to ``upstream_host`` with ``requests``.
    Unless a skip filter applies, the request body and the upstream response
    body are copied into temporary buffers while they stream through, and the
    resulting pair is put on ``write_queue``.  A greenlet spawned in
    ``__init__`` consumes the queue and hands request/response records to
    ``writer.write_req_resp``.
    """

    def __init__(self, upstream_host, writer, skip_filters=None, **kwargs):
        """Set up the parser, the write queue and the writer greenlet.

        :param upstream_host: URL prefix that the full request URI is
            appended to when forwarding.
        :param writer: object providing ``write_req_resp(req, resp, params)``.
        :param skip_filters: optional list of filters; when empty or ``None``
            the defaults from ``create_default_filters`` are used.
        :param kwargs: extra options; only ``accept_colls`` is read here.
        """
        self.upstream_host = upstream_host

        self.writer = writer
        self.parser = StatusAndHeadersParser([], verify=False)

        self.write_queue = gevent.queue.Queue()
        gevent.spawn(self._write_loop)

        if not skip_filters:
            skip_filters = self.create_default_filters(kwargs)

        self.skip_filters = skip_filters

    def create_default_filters(self, kwargs):
        """Return the default filter list.

        Always contains a ``SkipRangeRequestFilter``; a ``CollectionFilter``
        is appended when ``kwargs['accept_colls']`` is set and truthy.
        """
        skip_filters = [SkipRangeRequestFilter()]

        accept_colls = kwargs.get('accept_colls')
        if accept_colls:
            skip_filters.append(CollectionFilter(accept_colls))

        return skip_filters

    def _write_loop(self):
        """Consume the write queue forever, logging failures and continuing."""
        while True:
            try:
                self._write_one()
            except Exception:
                # Only ordinary errors are swallowed: GreenletExit and
                # KeyboardInterrupt derive from BaseException and must be
                # allowed to terminate the loop.
                traceback.print_exc()

    def _write_one(self):
        """Take one queued entry, build both records and pass them on.

        Blocks until an entry is available.  The buffers belonging to the
        entry are closed afterwards whether or not writing succeeded.
        """
        req = None
        resp = None
        req_pay = None
        resp_pay = None
        try:
            result = self.write_queue.get()

            req_head, req_pay, resp_head, resp_pay, params = result

            req = self._create_req_record(req_head, req_pay, 'request')
            resp = self._create_resp_record(resp_head, resp_pay, 'response')

            self.writer.write_req_resp(req, resp, params)

        finally:
            # Close each stream on its own so that a failure closing one
            # does not leave the other open.  If a record was never built,
            # fall back to the raw payload buffer taken from the queue.
            for record, payload in ((req, req_pay), (resp, resp_pay)):
                try:
                    stream = record.stream if record is not None else payload
                    if stream is not None:
                        stream.close()
                except Exception:
                    traceback.print_exc()

    def _create_req_record(self, req_headers, payload, type_, ct=''):
        """Build an ``ArcWarcRecord`` for the buffered request.

        :param req_headers: headers used as the record's WARC headers.
        :param payload: seekable buffer positioned at its end; the position
            is taken as the record length before rewinding.
        :param type_: record type string, e.g. ``'request'``.
        :param ct: content type passed through to the record.
        """
        len_ = payload.tell()
        payload.seek(0)

        warc_headers = req_headers

        # the buffered payload is parsed for a status line and headers
        status_headers = self.parser.parse(payload)

        record = ArcWarcRecord('warc', type_, warc_headers, payload,
                               status_headers, ct, len_)
        return record

    def _create_resp_record(self, resp_headers, payload, type_, ct=''):
        """Build an ``ArcWarcRecord`` for the buffered response.

        The payload is parsed twice in sequence: the first header block is
        used as the WARC headers, the second as the status headers.
        (presumably the upstream returns a WARC record wrapping the HTTP
        response -- confirm against the upstream service.)  ``resp_headers``
        is accepted but not used here.
        """
        len_ = payload.tell()
        payload.seek(0)

        warc_headers = self.parser.parse(payload)
        warc_headers = CaseInsensitiveDict(warc_headers.headers)

        status_headers = self.parser.parse(payload)

        record = ArcWarcRecord('warc', type_, warc_headers, payload,
                               status_headers, ct, len_)
        return record

    def send_error(self, exc, start_response):
        """Start a ``400 Bad Request`` JSON response describing ``exc``.

        :returns: single-element list holding the encoded JSON body.
        """
        body = json.dumps({'error': repr(exc)}).encode('utf-8')

        # Content-Length must describe the bytes actually sent
        headers = [('Content-Type', 'application/json; charset=utf-8'),
                   ('Content-Length', str(len(body)))]

        start_response('400 Bad Request', headers)
        return [body]

    def __call__(self, environ, start_response):
        """WSGI entry point: forward the request and stream back the reply.

        When no filter asks to skip the request, the request body and the
        upstream response are wrapped so that they are buffered for
        recording as they are read.  Any exception raised while contacting
        upstream (including an error status via ``raise_for_status``) is
        reported through ``send_error``.
        """
        input_req = DirectWSGIInputRequest(environ)
        headers = input_req.get_req_headers()
        method = input_req.get_req_method()
        request_uri = input_req.get_full_request_uri()

        input_buff = input_req.get_req_body()

        # QUERY_STRING may be absent from the environ
        params = dict(parse_qsl(environ.get('QUERY_STRING') or ''))

        skipping = any(x.skip_request(headers) for x in self.skip_filters)

        if not skipping:
            req_stream = ReqWrapper(input_buff, headers)
        else:
            req_stream = input_buff

        data = None
        if input_buff:
            data = req_stream

        res = None
        try:
            res = requests.request(url=self.upstream_host + request_uri,
                                   method=method,
                                   data=data,
                                   headers=headers,
                                   allow_redirects=False,
                                   stream=True)
            res.raise_for_status()
        except Exception as e:
            #traceback.print_exc()
            # Nothing will be recorded: drop the capture buffer and, since
            # the response was opened with stream=True, close it explicitly.
            try:
                if not skipping:
                    req_stream.out.close()
                if res is not None:
                    res.close()
            except Exception:
                traceback.print_exc()

            return self.send_error(e, start_response)

        start_response('200 OK', list(res.headers.items()))

        if not skipping:
            resp_stream = RespWrapper(res.raw,
                                      res.headers,
                                      req_stream,
                                      params,
                                      self.write_queue,
                                      self.skip_filters)
        else:
            resp_stream = res.raw

        resp_iter = StreamIter(resp_stream)

        if res.headers.get('Transfer-Encoding') == 'chunked':
            resp_iter = chunk_encode_iter(resp_iter)

        return resp_iter


#==============================================================================
class Wrapper(object):
    """Stream wrapper that copies everything read into a side buffer.

    ``out`` receives a copy of every chunk returned by ``read``;
    ``interrupted`` is set when a read on the wrapped stream raises.
    """

    def __init__(self, stream):
        self.stream = stream
        self.out = self._create_buffer()
        self.interrupted = False

    def _create_buffer(self):
        """Return a temporary file that stays in memory up to 512 KiB."""
        return tempfile.SpooledTemporaryFile(max_size=512*1024)

    def read(self, *args, **kwargs):
        """Read from the wrapped stream, mirroring the data into ``out``.

        A failing read marks the wrapper as interrupted and re-raises.
        """
        try:
            buff = self.stream.read(*args, **kwargs)
        except Exception:
            print('INTERRUPT READ')
            self.interrupted = True
            raise

        self.out.write(buff)
        return buff

    def close(self):
        """Close the wrapped stream, logging (not raising) any error."""
        try:
            self.stream.close()
        except Exception:
            traceback.print_exc()


#==============================================================================
class RespWrapper(Wrapper):
    """Wrapper for the upstream response that queues the capture on close."""

    def __init__(self, stream, headers, req, params, queue, skip_filters):
        """
        :param stream: upstream response stream to read from.
        :param headers: upstream response headers.
        :param req: wrapper of the matching request; its ``headers`` and
            ``out`` attributes are used when queueing.
        :param params: query parameters passed along with the queued entry.
        :param queue: queue that receives the captured entry.
        :param skip_filters: filters consulted through ``skip_response``.
        """
        super(RespWrapper, self).__init__(stream)
        self.headers = headers
        self.req = req
        self.params = params
        self.queue = queue
        self.skip_filters = skip_filters

    def close(self):
        """Drain what is left of the stream, close it and queue the capture.

        Reading to the end ensures that ``out`` holds everything the
        upstream sent; an error while draining marks the capture as
        interrupted.
        """
        try:
            while True:
                if not self.read(BUFF_SIZE):
                    break

        except Exception as e:
            print(e)
            self.interrupted = True

        finally:
            try:
                self.stream.close()
            except Exception:
                traceback.print_exc()

            self._write_to_file()

    def _write_to_file(self):
        """Queue the request/response buffers, or discard them.

        The buffers are closed and dropped when the read was interrupted or
        when any filter's ``skip_response`` returns a truthy value.
        Otherwise the tuple ``(request headers, request buffer, response
        headers, response buffer, params)`` is put on the queue.
        """
        skipping = any(x.skip_response(self.req.headers, self.headers)
                       for x in self.skip_filters)

        if self.interrupted or skipping:
            self.out.close()
            self.req.out.close()
            self.req.close()
            return

        try:
            entry = (self.req.headers, self.req.out,
                     self.headers, self.out, self.params)
            self.queue.put(entry)
            self.req.close()
            self.req = None
        except Exception:
            traceback.print_exc()


#==============================================================================
class ReqWrapper(Wrapper):
    """Wrapper for the request body that also keeps the ``WARC-*`` headers."""

    def __init__(self, stream, req_headers):
        """
        :param stream: request body stream to read from.
        :param req_headers: mapping of request headers; only the entries
            whose name starts with ``WARC-`` (case-insensitively) are kept
            in ``self.headers``.
        """
        super(ReqWrapper, self).__init__(stream)

        # Build the filtered mapping directly: deleting from a
        # case-insensitive copy fails when two incoming names differ only
        # by case.
        self.headers = CaseInsensitiveDict(
            (name, req_headers[name])
            for name in req_headers.keys()
            if name.upper().startswith('WARC-'))