diff --git a/webagg/inputrequest.py b/webagg/inputrequest.py index f15de60b..207db397 100644 --- a/webagg/inputrequest.py +++ b/webagg/inputrequest.py @@ -4,7 +4,7 @@ from pywb.utils.statusandheaders import StatusAndHeadersParser from six.moves.urllib.parse import urlsplit, quote from six import iteritems -from io import BytesIO +from io import BytesIO, StringIO #============================================================================= @@ -15,6 +15,9 @@ class DirectWSGIInputRequest(object): def get_req_method(self): return self.env['REQUEST_METHOD'].upper() + def get_req_protocol(self): + return self.env['SERVER_PROTOCOL'] + def get_req_headers(self): headers = {} @@ -92,6 +95,38 @@ class DirectWSGIInputRequest(object): return req_uri + def reconstruct_request(self, url=None): + buff = StringIO() + buff.write(self.get_req_method()) + buff.write(' ') + buff.write(self.get_full_request_uri()) + buff.write(' ') + buff.write(self.get_req_protocol()) + buff.write('\r\n') + + headers = self.get_req_headers() + + if url: + parts = urlsplit(url) + buff.write('Host: ') + buff.write(parts.netloc) + buff.write('\r\n') + + for name, value in iteritems(headers): + buff.write(name) + buff.write(': ') + buff.write(value) + buff.write('\r\n') + + buff.write('\r\n') + buff = buff.getvalue().encode('latin-1') + + body = self.get_req_body() + if body: + buff += body.read() + + return buff + #============================================================================= class POSTInputRequest(DirectWSGIInputRequest): @@ -112,6 +147,12 @@ class POSTInputRequest(DirectWSGIInputRequest): return headers + def get_full_request_uri(self): + return self.status_headers.statusline.split(' ', 1)[0] + + def get_req_protocol(self): + return self.status_headers.statusline.split(' ', 1)[-1] + def _get_content_type(self): return self.status_headers.get_header('Content-Type') diff --git a/webagg/test/test_handlers.py b/webagg/test/test_handlers.py index 5b9e510f..cefaa99d 100644 --- a/webagg/test/test_handlers.py +++ b/webagg/test/test_handlers.py @@ -16,7 +16,6 @@ from pywb.utils.bufferedreaders import ChunkedDataReader from io import BytesIO import webtest -import bottle from .testutils import to_path diff --git a/webagg/test/test_inputreq.py b/webagg/test/test_inputreq.py new file mode 100644 index 00000000..7aca5b6a --- /dev/null +++ b/webagg/test/test_inputreq.py @@ -0,0 +1,71 @@ +from webagg.inputrequest import DirectWSGIInputRequest, POSTInputRequest +from bottle import Bottle, request, response +import webtest +import traceback + + +#============================================================================= +class InputReqApp(object): + def __init__(self): + self.application = Bottle() + self.application.default_error_handler = self.err_handler + + @self.application.route('/test/', 'ANY') + def direct_input_request(url=''): + inputreq = DirectWSGIInputRequest(request.environ) + response['Content-Type'] = 'text/plain; charset=utf-8' + return inputreq.reconstruct_request(url) + + @self.application.route('/test-postreq', 'POST') + def post_fullrequest(): + params = dict(request.query) + inputreq = POSTInputRequest(request.environ) + response['Content-Type'] = 'text/plain; charset=utf-8' + return inputreq.reconstruct_request(params.get('url')) + + def err_handler(self, out): + print(out) + traceback.print_exc() + + +#============================================================================= +class TestInputReq(object): + def setup(self): + self.app = InputReqApp() + self.testapp = webtest.TestApp(self.app.application) + + def test_get_direct(self): + res = self.testapp.get('/test/http://example.com/', headers={'Foo': 'Bar'}) + assert res.text == '\ +GET /test/http://example.com/ HTTP/1.0\r\n\ +Host: example.com\r\n\ +Foo: Bar\r\n\ +\r\n\ +' + + def test_post_direct(self): + res = self.testapp.post('/test/http://example.com/', headers={'Foo': 'Bar'}, params='ABC') + lines = res.text.split('\r\n') + assert lines[0] == 'POST /test/http://example.com/ HTTP/1.0' + assert 'Host: example.com' in lines + assert 'Content-Length: 3' in lines + assert 'Content-Type: application/x-www-form-urlencoded' in lines + assert 'Foo: Bar' in lines + + assert 'ABC' in lines + + def test_post_req(self): + postdata = '\ +GET /example.html HTTP/1.0\r\n\ +Foo: Bar\r\n\ +\r\n\ +' + res = self.testapp.post('/test-postreq?url=http://example.com/', params=postdata) + + assert res.text == '\ +GET /example.html HTTP/1.0\r\n\ +Host: example.com\r\n\ +Foo: Bar\r\n\ +\r\n\ +' +