1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-23 06:32:24 +01:00
pywb/recorder/recorderapp.py
Ilya Kreymer d38bb5a1fd filters: add extensible 'skip filters', with default filters to accept certain collections, filter out
recording of range requests. Opportunity to skip recording at request or response time
RespWrapper handles reading stream fully on close() (no need for old ReadFullyStream),
skips recording if read was interrupted/incomplete
writer: avoiding writing duplicate content-length/content-type headers
2016-03-21 11:47:12 -07:00

252 lines
7.4 KiB
Python

from webagg.utils import StreamIter, chunk_encode_iter, BUFF_SIZE
from webagg.inputrequest import DirectWSGIInputRequest
from pywb.utils.statusandheaders import StatusAndHeadersParser
from pywb.warc.recordloader import ArcWarcRecord
from pywb.warc.recordloader import ArcWarcRecordLoader
from recorder.filters import SkipRangeRequestFilter, CollectionFilter
from six.moves.urllib.parse import parse_qsl
import json
import tempfile
from requests.structures import CaseInsensitiveDict
import requests
import traceback
import gevent.queue
import gevent
#==============================================================================
class RecorderApp(object):
def __init__(self, upstream_host, writer, skip_filters=None, **kwargs):
self.upstream_host = upstream_host
self.writer = writer
self.parser = StatusAndHeadersParser([], verify=False)
self.write_queue = gevent.queue.Queue()
gevent.spawn(self._write_loop)
if not skip_filters:
skip_filters = self.create_default_filters(kwargs)
self.skip_filters = skip_filters
def create_default_filters(self, kwargs):
skip_filters = [SkipRangeRequestFilter()]
accept_colls = kwargs.get('accept_colls')
if accept_colls:
skip_filters.append(CollectionFilter(accept_colls))
return skip_filters
def _write_loop(self):
while True:
try:
self._write_one()
except:
traceback.print_exc()
def _write_one(self):
req = None
resp = None
try:
result = self.write_queue.get()
req_head, req_pay, resp_head, resp_pay, params = result
req = self._create_req_record(req_head, req_pay, 'request')
resp = self._create_resp_record(resp_head, resp_pay, 'response')
self.writer.write_req_resp(req, resp, params)
finally:
try:
if req:
req.stream.close()
if resp:
resp.stream.close()
except Exception as e:
traceback.print_exc()
def _create_req_record(self, req_headers, payload, type_, ct=''):
len_ = payload.tell()
payload.seek(0)
warc_headers = req_headers
status_headers = self.parser.parse(payload)
record = ArcWarcRecord('warc', type_, warc_headers, payload,
status_headers, ct, len_)
return record
def _create_resp_record(self, resp_headers, payload, type_, ct=''):
len_ = payload.tell()
payload.seek(0)
warc_headers = self.parser.parse(payload)
warc_headers = CaseInsensitiveDict(warc_headers.headers)
status_headers = self.parser.parse(payload)
record = ArcWarcRecord('warc', type_, warc_headers, payload,
status_headers, ct, len_)
return record
def send_error(self, exc, start_response):
message = json.dumps({'error': repr(exc)})
headers = [('Content-Type', 'application/json; charset=utf-8'),
('Content-Length', str(len(message)))]
start_response('400 Bad Request', headers)
return [message.encode('utf-8')]
def __call__(self, environ, start_response):
input_req = DirectWSGIInputRequest(environ)
headers = input_req.get_req_headers()
method = input_req.get_req_method()
request_uri = input_req.get_full_request_uri()
input_buff = input_req.get_req_body()
params = dict(parse_qsl(environ.get('QUERY_STRING')))
skipping = any(x.skip_request(headers) for x in self.skip_filters)
if not skipping:
req_stream = ReqWrapper(input_buff, headers)
else:
req_stream = input_buff
data = None
if input_buff:
data = req_stream
try:
res = requests.request(url=self.upstream_host + request_uri,
method=method,
data=data,
headers=headers,
allow_redirects=False,
stream=True)
res.raise_for_status()
except Exception as e:
#traceback.print_exc()
return self.send_error(e, start_response)
start_response('200 OK', list(res.headers.items()))
if not skipping:
resp_stream = RespWrapper(res.raw,
res.headers,
req_stream,
params,
self.write_queue,
self.skip_filters)
else:
resp_stream = res.raw
resp_iter = StreamIter(resp_stream)
if res.headers.get('Transfer-Encoding') == 'chunked':
resp_iter = chunk_encode_iter(resp_iter)
return resp_iter
#==============================================================================
class Wrapper(object):
def __init__(self, stream):
self.stream = stream
self.out = self._create_buffer()
self.interrupted = False
def _create_buffer(self):
return tempfile.SpooledTemporaryFile(max_size=512*1024)
def read(self, *args, **kwargs):
try:
buff = self.stream.read(*args, **kwargs)
except Exception as e:
print('INTERRUPT READ')
self.interrupted = True
raise
self.out.write(buff)
return buff
def close(self):
try:
self.stream.close()
except:
traceback.print_exc()
#==============================================================================
class RespWrapper(Wrapper):
def __init__(self, stream, headers, req,
params, queue, skip_filters):
super(RespWrapper, self).__init__(stream)
self.headers = headers
self.req = req
self.params = params
self.queue = queue
self.skip_filters = skip_filters
def close(self):
try:
while True:
if not self.read(BUFF_SIZE):
break
except Exception as e:
print(e)
self.interrupted = True
finally:
try:
self.stream.close()
except Exception as e:
traceback.print_exc()
self._write_to_file()
def _write_to_file(self):
skipping = any(x.skip_response(self.req.headers, self.headers)
for x in self.skip_filters)
if self.interrupted or skipping:
self.out.close()
self.req.out.close()
self.req.close()
return
try:
entry = (self.req.headers, self.req.out,
self.headers, self.out, self.params)
self.queue.put(entry)
self.req.close()
self.req = None
except:
traceback.print_exc()
#==============================================================================
class ReqWrapper(Wrapper):
def __init__(self, stream, req_headers):
super(ReqWrapper, self).__init__(stream)
self.headers = CaseInsensitiveDict(req_headers)
for n in req_headers.keys():
if not n.upper().startswith('WARC-'):
del self.headers[n]