1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-23 06:32:24 +01:00
pywb/recorder/recorderapp.py
Ilya Kreymer 9adb8da3b7 recorder: add support for filtering collections to record by regex (default: .*)
add support for excluding certain headers when writing WARCs
tests: add first batch of tests for recorder, using live upstream server
2016-03-11 11:12:25 -08:00

208 lines
6.2 KiB
Python

#from gevent import monkey; monkey.patch_all()
from requests import request as remote_request
from requests.structures import CaseInsensitiveDict
from webagg.liverec import ReadFullyStream
from webagg.responseloader import StreamIter
from webagg.inputrequest import DirectWSGIInputRequest
from pywb.utils.statusandheaders import StatusAndHeadersParser
from pywb.warc.recordloader import ArcWarcRecord
from pywb.warc.recordloader import ArcWarcRecordLoader
from recorder.warcrecorder import SingleFileWARCRecorder, PerRecordWARCRecorder
from recorder.redisindexer import WritableRedisIndexer
from six.moves.urllib.parse import parse_qsl, quote
import json
import tempfile
import re
import traceback
import gevent.queue
import gevent
#==============================================================================
class RecorderApp(object):
    """WSGI application that proxies requests upstream and records them.

    Every incoming request is forwarded to ``upstream_host``.  The request
    body and the upstream response body are copied into buffers as they
    stream through (see ``ReqWrapper`` / ``RespWrapper``); the resulting
    entries are placed on ``write_queue`` and consumed by a background
    greenlet that builds records and passes them to ``writer``.
    """

    def __init__(self, upstream_host, writer, accept_colls='.*'):
        """Set up the app and start the background writer greenlet.

        :param upstream_host: URL prefix prepended to each request URI.
        :param writer: object providing ``write_req_resp(req, resp, params)``.
        :param accept_colls: regex matched (anchored at the start) against
            the ``WebAgg-Source-Coll`` response header; only entries whose
            header matches are written.
        """
        self.upstream_host = upstream_host
        self.writer = writer

        self.parser = StatusAndHeadersParser([], verify=False)

        # Compile the filter before spawning the writer so that _write_one
        # never relies on scheduling order to find this attribute.
        self.rx_accept_colls = re.compile(accept_colls)

        self.write_queue = gevent.queue.Queue()
        gevent.spawn(self._write_loop)

    def _write_loop(self):
        """Consume queued entries forever, one at a time."""
        while True:
            self._write_one()

    def _write_one(self):
        """Take one entry off the queue and write it if its collection matches.

        Errors are printed rather than raised so that a single bad entry
        does not stop the writer loop.  The buffered payloads are always
        closed, including when the entry is filtered out or fails.
        """
        req_pay = None
        resp_pay = None

        try:
            result = self.write_queue.get()

            req_head, req_pay, resp_head, resp_pay, params = result

            coll = resp_head.get('WebAgg-Source-Coll', '')
            if not self.rx_accept_colls.match(coll):
                # Collection is not accepted for recording: skip silently.
                return

            req = self._create_req_record(req_head, req_pay, 'request')
            resp = self._create_resp_record(resp_head, resp_pay, 'response')

            self.writer.write_req_resp(req, resp, params)

        except Exception:
            # 'Exception' (not a bare except) so that GreenletExit and
            # KeyboardInterrupt can still terminate the loop.
            traceback.print_exc()

        finally:
            # Close the payload buffers themselves (the records are built
            # around these same objects) so nothing leaks on the skip or
            # error paths; close each independently.
            for payload in (req_pay, resp_pay):
                if payload is None:
                    continue
                try:
                    payload.close()
                except Exception:
                    traceback.print_exc()

    def _create_req_record(self, req_headers, payload, type_, ct=''):
        """Build the request record from the buffered request payload.

        ``req_headers`` is used unchanged as the record's WARC headers.
        The current write position of ``payload`` is taken as the record
        length; the payload is then rewound and ``self.parser`` reads the
        status/headers block from its start.
        """
        len_ = payload.tell()
        payload.seek(0)

        warc_headers = req_headers
        status_headers = self.parser.parse(payload)

        record = ArcWarcRecord('warc', type_, warc_headers, payload,
                               status_headers, ct, len_)
        return record

    def _create_resp_record(self, resp_headers, payload, type_, ct=''):
        """Build the response record from the buffered response payload.

        ``resp_headers`` is accepted for symmetry with the request variant
        but is not used: the payload is parsed twice in sequence, the first
        header block becoming the WARC headers and the second the
        status/headers of the record.
        """
        len_ = payload.tell()
        payload.seek(0)

        # First header block in the payload -> WARC headers.
        warc_headers = self.parser.parse(payload)
        warc_headers = CaseInsensitiveDict(warc_headers.headers)

        # Second header block -> status line and headers.
        status_headers = self.parser.parse(payload)

        record = ArcWarcRecord('warc', type_, warc_headers, payload,
                               status_headers, ct, len_)
        return record

    def send_error(self, exc, start_response):
        """Start a ``400 Bad Request`` JSON response describing ``exc``.

        :returns: a one-element list holding the UTF-8 encoded JSON body.
        """
        # Encode first so Content-Length is the byte length actually sent.
        body = json.dumps({'error': repr(exc)}).encode('utf-8')
        headers = [('Content-Type', 'application/json; charset=utf-8'),
                   ('Content-Length', str(len(body)))]

        start_response('400 Bad Request', headers)
        return [body]

    def _get_request_uri(self, env):
        """Return the request URI (path plus query) for a WSGI environ.

        ``REQUEST_URI`` is used as-is when present; otherwise the URI is
        rebuilt from a quoted ``PATH_INFO`` and ``QUERY_STRING``.
        """
        req_uri = env.get('REQUEST_URI')
        if req_uri:
            return req_uri

        req_uri = quote(env.get('PATH_INFO', ''), safe='/~!$&\'()*+,;=:@')
        query = env.get('QUERY_STRING')
        if query:
            req_uri += '?' + query

        return req_uri

    def __call__(self, environ, start_response):
        """WSGI entry point: forward the request and stream back the response.

        The upstream response headers are passed through with a fixed
        ``200 OK`` status.  If the upstream request itself fails, a JSON
        ``400`` error is returned instead.
        """
        request_uri = self._get_request_uri(environ)

        input_req = DirectWSGIInputRequest(environ)
        headers = input_req.get_req_headers()
        method = input_req.get_req_method()
        input_buff = input_req.get_req_body()

        # QUERY_STRING may be absent from the environ; default to '' so
        # parse_qsl is never handed None.
        params = dict(parse_qsl(environ.get('QUERY_STRING', '')))

        req_stream = ReqWrapper(input_buff, headers)

        try:
            res = remote_request(url=self.upstream_host + request_uri,
                                 method=method,
                                 data=req_stream,
                                 headers=headers,
                                 allow_redirects=False,
                                 stream=True)
        except Exception as e:
            traceback.print_exc()
            # Nothing will be queued for this request, so release the
            # request buffer here.
            req_stream.out.close()
            return self.send_error(e, start_response)

        start_response('200 OK', list(res.headers.items()))

        resp_stream = RespWrapper(res.raw, res.headers, req_stream, params,
                                  self.write_queue)

        return StreamIter(ReadFullyStream(resp_stream))
#==============================================================================
class Wrapper(object):
    """File-like wrapper that copies everything read from a stream.

    Data returned by ``read`` is also written to ``self.out``, a spooled
    temporary file, so it can be replayed after the stream is consumed.
    """

    def __init__(self, stream):
        """
        :param stream: object with ``read`` and ``close`` methods.
        """
        self.stream = stream
        self.out = self._create_buffer()

    def _create_buffer(self):
        """Return the buffer for copied data.

        Held in memory up to 512 KiB, after which it rolls over to disk.
        """
        return tempfile.SpooledTemporaryFile(max_size=512*1024)

    def read(self, limit=-1):
        """Read from the wrapped stream, copy the data to ``out``, return it.

        :param limit: maximum amount to read; a negative value or ``None``
            reads everything the stream has left.
        """
        if limit is None or limit < 0:
            # Call read() with no argument for "read all": not every
            # stream implementation accepts -1 as a size.
            buff = self.stream.read()
        else:
            buff = self.stream.read(limit)

        self.out.write(buff)
        return buff

    def close(self):
        """Close the wrapped stream, then run the ``_after_close`` hook.

        A failure to close the stream is printed, not raised, and the hook
        runs regardless.
        """
        try:
            self.stream.close()
        except Exception:
            traceback.print_exc()
        finally:
            self._after_close()

    def _after_close(self):
        """Hook for subclasses, called once ``close`` has finished."""
        pass
#==============================================================================
class RespWrapper(Wrapper):
    """Response-side wrapper that queues the request/response pair on close."""

    def __init__(self, stream, headers, req,
                 params, queue):
        """
        :param stream: upstream response stream to read from.
        :param headers: response headers, queued alongside the payload.
        :param req: the request-side wrapper; its ``headers`` and ``out``
            attributes are queued together with the response.
        :param params: request parameters passed through to the queue entry.
        :param queue: queue receiving the finished entry.
        """
        super(RespWrapper, self).__init__(stream)
        self.headers = headers
        self.req = req
        self.params = params
        self.queue = queue

    def _after_close(self):
        """Queue ``(req headers, req buffer, resp headers, resp buffer, params)``.

        ``self.req`` is cleared after a successful put, so a repeated
        ``close`` does not queue the same pair twice.
        """
        if not self.req:
            return

        try:
            entry = (self.req.headers, self.req.out,
                     self.headers, self.out, self.params)
            self.queue.put(entry)
            self.req = None
        except Exception:
            # Best effort: report and carry on.  'Exception' rather than a
            # bare except so interpreter/greenlet exit signals propagate.
            traceback.print_exc()
#==============================================================================
class ReqWrapper(Wrapper):
    """Request-side wrapper that keeps only the ``WARC-*`` request headers."""

    def __init__(self, stream, req_headers):
        """
        :param stream: request body stream to read from.
        :param req_headers: mapping of request headers; it is not modified.
            Only entries whose name starts with ``WARC-`` (case-insensitive)
            are copied into ``self.headers``.
        """
        super(ReqWrapper, self).__init__(stream)

        # Build the filtered mapping in one pass.  Deleting from a
        # case-insensitive copy while iterating the original keys raises
        # KeyError when two names differ only by case.
        self.headers = CaseInsensitiveDict(
            [(name, value) for name, value in req_headers.items()
             if name.upper().startswith('WARC-')])