1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-24 23:19:52 +01:00
pywb/recorder/recorderapp.py

208 lines
6.2 KiB
Python
Raw Normal View History

#from gevent import monkey; monkey.patch_all()
2016-03-09 14:33:36 -08:00
from requests import request as remote_request
from requests.structures import CaseInsensitiveDict
from webagg.liverec import ReadFullyStream
from webagg.responseloader import StreamIter
from webagg.inputrequest import DirectWSGIInputRequest
from pywb.utils.statusandheaders import StatusAndHeadersParser
from pywb.warc.recordloader import ArcWarcRecord
from pywb.warc.recordloader import ArcWarcRecordLoader
from recorder.warcrecorder import SingleFileWARCRecorder, PerRecordWARCRecorder
from recorder.redisindexer import WritableRedisIndexer
from six.moves.urllib.parse import parse_qsl, quote
2016-03-09 14:33:36 -08:00
import json
import tempfile
import re
2016-03-09 14:33:36 -08:00
import traceback
import gevent.queue
import gevent
#==============================================================================
class RecorderApp(object):
    """WSGI app that proxies requests upstream and records the exchanges.

    Each incoming request is forwarded to ``upstream_host`` with ``requests``.
    While the request body and the upstream response stream through, they are
    copied into temporary buffers by ``ReqWrapper`` / ``RespWrapper``. When the
    response stream is closed, the buffered pair is put on ``write_queue``.
    A background gevent greenlet then turns the pair into WARC request/response
    records and passes them to ``writer.write_req_resp``.
    """

    def __init__(self, upstream_host, writer, accept_colls='.*'):
        """
        :param upstream_host: URL prefix; the incoming request URI is appended
            to it to form the upstream URL.
        :param writer: object whose ``write_req_resp(req, resp, params)`` is
            called for each accepted request/response pair.
        :param accept_colls: regex matched (``re.match``) against the upstream
            response's ``WebAgg-Source-Coll`` header; pairs that do not match
            are dropped instead of written.
        """
        self.upstream_host = upstream_host
        self.writer = writer
        self.parser = StatusAndHeadersParser([], verify=False)
        # Compile before spawning the writer so the loop never sees it unset.
        self.rx_accept_colls = re.compile(accept_colls)
        self.write_queue = gevent.queue.Queue()
        gevent.spawn(self._write_loop)

    def _write_loop(self):
        """Process queued request/response pairs forever."""
        while True:
            self._write_one()

    def _write_one(self):
        """Take one pair off the queue and write it as WARC records.

        Errors are printed and swallowed so that one bad pair cannot stop the
        writer loop. The request/response payload buffers are always closed,
        whether the pair was written, rejected by ``accept_colls``, or failed.
        """
        req_pay = None
        resp_pay = None
        try:
            req_head, req_pay, resp_head, resp_pay, params = self.write_queue.get()

            if not self.rx_accept_colls.match(resp_head.get('WebAgg-Source-Coll', '')):
                print('COLL', resp_head)
                return

            req = self._create_req_record(req_head, req_pay, 'request')
            resp = self._create_resp_record(resp_head, resp_pay, 'response')

            self.writer.write_req_resp(req, resp, params)
        # Catch Exception rather than using a bare except: a bare except would
        # also swallow gevent's GreenletExit and make the writer unkillable.
        except Exception:
            traceback.print_exc()
        finally:
            # The payloads are passed as the records' stream argument, so
            # closing them here releases the buffers on every path, including
            # the early return for rejected collections.
            self._close_quietly(req_pay)
            self._close_quietly(resp_pay)

    @staticmethod
    def _close_quietly(stream):
        """Close *stream* if it is not None, printing (not raising) errors."""
        if stream is None:
            return
        try:
            stream.close()
        except Exception:
            traceback.print_exc()

    def _create_req_record(self, req_headers, payload, type_, ct=''):
        """Build a WARC record from a buffered request.

        :param req_headers: used as-is as the record's WARC headers.
        :param payload: seekable buffer; its current position is taken as the
            record length. It is rewound, and an HTTP header block is parsed
            from its start as the record's status headers.
        :returns: ``ArcWarcRecord`` wrapping *payload*.
        """
        len_ = payload.tell()
        payload.seek(0)
        status_headers = self.parser.parse(payload)
        return ArcWarcRecord('warc', type_, req_headers, payload,
                             status_headers, ct, len_)

    def _create_resp_record(self, resp_headers, payload, type_, ct=''):
        """Build a WARC record from a buffered upstream response.

        The payload is expected to start with a WARC header block, which is
        parsed and used as the record's WARC headers. It is then expected to
        contain an HTTP status/header block, which is parsed as the record's
        status headers. *resp_headers* is currently unused.

        :returns: ``ArcWarcRecord`` wrapping *payload*.
        """
        len_ = payload.tell()
        payload.seek(0)
        warc_headers = CaseInsensitiveDict(self.parser.parse(payload).headers)
        status_headers = self.parser.parse(payload)
        return ArcWarcRecord('warc', type_, warc_headers, payload,
                             status_headers, ct, len_)

    def send_error(self, exc, start_response):
        """Respond with ``400 Bad Request`` and a JSON body ``{"error": repr(exc)}``."""
        body = json.dumps({'error': repr(exc)}).encode('utf-8')
        # Content-Length is taken from the encoded bytes actually sent.
        headers = [('Content-Type', 'application/json; charset=utf-8'),
                   ('Content-Length', str(len(body)))]
        start_response('400 Bad Request', headers)
        return [body]

    def _get_request_uri(self, env):
        """Return ``REQUEST_URI`` if set, else rebuild it from the WSGI env.

        The rebuilt form is the percent-quoted ``PATH_INFO`` followed by
        ``?QUERY_STRING`` when a query string is present.
        """
        req_uri = env.get('REQUEST_URI')
        if req_uri:
            return req_uri

        req_uri = quote(env.get('PATH_INFO', ''), safe='/~!$&\'()*+,;=:@')
        query = env.get('QUERY_STRING')
        if query:
            req_uri += '?' + query
        return req_uri

    def __call__(self, environ, start_response):
        """Proxy the request upstream, stream back the response and record it."""
        request_uri = self._get_request_uri(environ)

        input_req = DirectWSGIInputRequest(environ)
        headers = input_req.get_req_headers()
        method = input_req.get_req_method()
        input_buff = input_req.get_req_body()

        # Default to '': parse_qsl(None) raises on Python 2.
        params = dict(parse_qsl(environ.get('QUERY_STRING', '')))
        req_stream = ReqWrapper(input_buff, headers)

        res = None
        try:
            res = remote_request(url=self.upstream_host + request_uri,
                                 method=method,
                                 data=req_stream,
                                 headers=headers,
                                 allow_redirects=False,
                                 stream=True)
            # Do not relay upstream errors as 200 OK, and do not queue their
            # bodies to be written as if they were captured records.
            res.raise_for_status()
        except Exception as e:
            traceback.print_exc()
            # Nothing will be queued for this request, so release its buffer
            # and any open upstream connection now.
            self._close_quietly(req_stream.out)
            self._close_quietly(res)
            return self.send_error(e, start_response)

        start_response('200 OK', list(res.headers.items()))
        resp_stream = RespWrapper(res.raw, res.headers, req_stream, params,
                                  self.write_queue)

        return StreamIter(ReadFullyStream(resp_stream))
#==============================================================================
class Wrapper(object):
    """File-like wrapper that copies everything read from *stream* into a
    temporary buffer (``self.out``) so that it can be recorded afterwards."""

    def __init__(self, stream):
        self.stream = stream
        self.out = self._create_buffer()

    def _create_buffer(self):
        """Return the copy buffer: kept in memory up to 512 KiB, then rolled
        over to a temporary file."""
        return tempfile.SpooledTemporaryFile(max_size=512*1024)

    def read(self, limit=-1):
        """Read up to *limit* bytes from the wrapped stream, append them to
        ``self.out`` and return them.

        A negative or ``None`` *limit* reads everything that remains. The limit
        is honoured so that callers reading in chunks keep streaming, instead
        of pulling the whole body into memory on the first call.
        """
        if limit is None or limit < 0:
            buff = self.stream.read()
        else:
            buff = self.stream.read(limit)
        self.out.write(buff)
        return buff

    def close(self):
        """Close the wrapped stream, then always run ``_after_close``.

        ``self.out`` is intentionally left open so its contents can still be
        consumed after the stream is closed.
        """
        try:
            self.stream.close()
        except Exception:
            traceback.print_exc()
        finally:
            self._after_close()

    def _after_close(self):
        """Hook run after ``close()``; does nothing by default."""
        pass
#==============================================================================
class RespWrapper(Wrapper):
    """Copies the upstream response body while it streams and, once closed,
    queues the buffered request/response pair for the WARC writer."""

    def __init__(self, stream, headers, req,
                 params, queue):
        """
        :param stream: upstream response body stream to read from.
        :param headers: upstream response headers, queued with the pair.
        :param req: request-side wrapper whose ``headers`` and ``out`` buffer
            are queued together with this response.
        :param params: extra parameters queued with the pair.
        :param queue: queue that receives the pair tuple on close.
        """
        super(RespWrapper, self).__init__(stream)
        self.headers = headers
        self.req = req
        self.params = params
        self.queue = queue

    def _after_close(self):
        """Queue ``(req headers, req buffer, resp headers, resp buffer, params)``.

        ``self.req`` is cleared after a successful put, so a repeated
        ``close()`` does not queue the same pair twice.
        """
        if not self.req:
            return

        try:
            entry = (self.req.headers, self.req.out,
                     self.headers, self.out, self.params)
            self.queue.put(entry)
            self.req = None
        # Not a bare except: BaseException subclasses (e.g. greenlet kills,
        # KeyboardInterrupt) must propagate.
        except Exception:
            traceback.print_exc()
#==============================================================================
class ReqWrapper(Wrapper):
    """Copies the incoming request body while it streams and keeps only the
    request's ``WARC-*`` headers (matched case-insensitively)."""

    def __init__(self, stream, req_headers):
        """
        :param stream: request body stream to read from.
        :param req_headers: mapping of incoming request headers; only names
            starting with ``WARC-`` (any case) are kept in ``self.headers``.
        """
        super(ReqWrapper, self).__init__(stream)
        # Build the filtered mapping directly instead of copying everything and
        # deleting. The delete loop raised KeyError when two non-WARC names
        # differed only by case, because both map to one CaseInsensitiveDict
        # entry.
        self.headers = CaseInsensitiveDict(
            (name, value) for name, value in req_headers.items()
            if name.upper().startswith('WARC-'))