mirror of
https://github.com/webrecorder/pywb.git
synced 2025-03-23 06:32:24 +01:00
use exc str instead of repr for error message for consistency all tests pass on py2 and py3 again!
170 lines
4.6 KiB
Python
170 lines
4.6 KiB
Python
from pywb.utils.loaders import extract_post_query, append_post_query
|
|
from pywb.utils.loaders import LimitReader
|
|
from pywb.utils.statusandheaders import StatusAndHeadersParser
|
|
|
|
from six.moves.urllib.parse import urlsplit, quote
|
|
from six import iteritems, StringIO
|
|
from io import BytesIO
|
|
|
|
|
|
#=============================================================================
|
|
class DirectWSGIInputRequest(object):
    """Read-only view of an incoming HTTP request held in a WSGI environ.

    Every accessor reads from ``self.env``. Subclasses can override the
    small ``get_req_*`` / ``_get_*`` accessors to source the request from
    somewhere else while reusing ``get_req_body``, ``include_post_query``
    and ``reconstruct_request``.
    """

    def __init__(self, env):
        """Keep a reference to the WSGI environ dict.

        :param env: WSGI environ mapping; it is stored as-is (not copied)
            and ``include_post_query`` may replace its ``wsgi.input``.
        """
        self.env = env

    def get_req_method(self):
        """Return ``REQUEST_METHOD`` from the environ, upper-cased."""
        return self.env['REQUEST_METHOD'].upper()

    def get_req_protocol(self):
        """Return ``SERVER_PROTOCOL`` from the environ unchanged."""
        return self.env['SERVER_PROTOCOL']

    def get_req_headers(self):
        """Build a dict of request headers from the environ.

        ``HTTP_*`` keys lose the prefix and are converted to
        ``Title-Case-With-Dashes``; ``CONTENT_LENGTH`` and ``CONTENT_TYPE``
        are converted the same way. All other environ keys are ignored, as
        are headers whose value is empty/falsy.

        :return: dict mapping header name to value
        """
        headers = {}

        for name, value in self.env.items():
            # will be set by requests to match actual host
            if name == 'HTTP_HOST':
                continue

            if name.startswith('HTTP_'):
                name = name[5:].title().replace('_', '-')
            elif name in ('CONTENT_LENGTH', 'CONTENT_TYPE'):
                name = name.title().replace('_', '-')
            else:
                # not a header-carrying environ key
                continue

            if value:
                headers[name] = value

        return headers

    def get_req_body(self):
        """Return a readable object for the request body, or ``None``.

        * A truthy Content-Length yields ``LimitReader(input, int(length))``.
        * Otherwise, a truthy Transfer-Encoding header yields the raw
          ``wsgi.input`` stream.
        * Otherwise there is no body and ``None`` is returned.

        :raises KeyError: if the environ has no ``wsgi.input``
        :raises ValueError: if Content-Length is not an integer string
        """
        input_ = self.env['wsgi.input']
        len_ = self._get_content_length()
        enc = self._get_header('Transfer-Encoding')

        if len_:
            return LimitReader(input_, int(len_))

        if enc:
            return input_

        return None

    def _get_content_type(self):
        """Return the environ's ``CONTENT_TYPE`` or ``None`` if absent."""
        return self.env.get('CONTENT_TYPE')

    def _get_content_length(self):
        """Return the environ's ``CONTENT_LENGTH`` or ``None`` if absent."""
        return self.env.get('CONTENT_LENGTH')

    def _get_header(self, name):
        """Look up an arbitrary request header in the environ.

        :param name: header name such as ``'Transfer-Encoding'``; it is
            mapped to the ``HTTP_TRANSFER_ENCODING`` style environ key.
        :return: the header value, or ``None`` if not present
        """
        return self.env.get('HTTP_' + name.upper().replace('-', '_'))

    def include_post_query(self, url):
        """Return ``url`` with the POST payload folded in as a query.

        For an empty ``url`` or a non-POST request, ``url`` is returned
        untouched and the input stream is not read.

        Otherwise ``extract_post_query`` is given the mime type (with any
        ``;`` parameters removed), the content length, the input stream
        and a fresh ``BytesIO`` as ``buffered_stream``. When it returns a
        truthy query, ``wsgi.input`` is replaced by that buffer and the
        query is appended with ``append_post_query``.
        NOTE(review): presumably the helper copies what it consumed into
        the buffer so the body can be read again -- confirm against
        ``extract_post_query``.

        :param url: target url, may be empty/``None``
        :return: the possibly extended url
        """
        if not url or self.get_req_method() != 'POST':
            return url

        mime = self._get_content_type()
        mime = mime.split(';')[0] if mime else ''
        length = self._get_content_length()
        stream = self.env['wsgi.input']

        buffered_stream = BytesIO()

        post_query = extract_post_query('POST', mime, length, stream,
                                        buffered_stream=buffered_stream)

        if post_query:
            self.env['wsgi.input'] = buffered_stream
            url = append_post_query(url, post_query)

        return url

    def get_full_request_uri(self):
        """Return the request target (path plus optional query string).

        ``REQUEST_URI`` is returned verbatim when it is set and
        ``SCRIPT_NAME`` is empty/absent. Otherwise the value is rebuilt
        from a percent-quoted ``PATH_INFO`` followed by ``?QUERY_STRING``
        when a query string is present.
        """
        req_uri = self.env.get('REQUEST_URI')
        if req_uri and not self.env.get('SCRIPT_NAME'):
            return req_uri

        req_uri = quote(self.env.get('PATH_INFO', ''), safe='/~!$&\'()*+,;=:@')
        query = self.env.get('QUERY_STRING')
        if query:
            req_uri += '?' + query

        return req_uri

    def reconstruct_request(self, url=None):
        """Serialize the request back into raw HTTP/1.x wire format.

        The output is the request line, an optional ``Host`` header taken
        from the netloc of ``url``, all headers from ``get_req_headers``
        except any ``Host`` header, a blank line, and finally the body
        from ``get_req_body`` if there is one. The head is encoded as
        latin-1.

        :param url: optional url whose netloc becomes the ``Host`` header;
            when falsy no ``Host`` header is written
        :return: the serialized request as bytes
        """
        lines = [' '.join((self.get_req_method(),
                           self.get_full_request_uri(),
                           self.get_req_protocol()))]

        headers = self.get_req_headers()

        if url:
            lines.append('Host: ' + urlsplit(url).netloc)

        for name, value in headers.items():
            # Host is only ever emitted from ``url`` above
            if name.lower() == 'host':
                continue

            lines.append(name + ': ' + value)

        # every line is CRLF-terminated, then a blank line ends the head
        buff = ('\r\n'.join(lines) + '\r\n\r\n').encode('latin-1')

        body = self.get_req_body()
        if body:
            buff += body.read()

        return buff
|
|
|
|
|
|
#=============================================================================
|
|
class POSTInputRequest(DirectWSGIInputRequest):
    """Request view whose request line and headers come from ``wsgi.input``.

    The constructor runs a ``StatusAndHeadersParser`` over the input
    stream; every accessor then reads from the parsed result instead of
    the WSGI environ. Body handling, ``include_post_query`` and
    ``reconstruct_request`` are inherited.

    NOTE(review): presumably the stream carries a complete HTTP request
    whose first line is ``METHOD URI PROTOCOL``, with the parser exposing
    METHOD as ``.protocol`` and the remainder as ``.statusline`` --
    confirm against ``StatusAndHeadersParser``.
    """

    def __init__(self, env):
        """Store ``env`` and parse the head from its ``wsgi.input``.

        :param env: WSGI environ mapping containing ``wsgi.input``
        """
        super(POSTInputRequest, self).__init__(env)

        parser = StatusAndHeadersParser([], verify=False)

        # the parser reads from the stream; the inherited get_req_body
        # later reads from the same ``wsgi.input`` object
        self.status_headers = parser.parse(self.env['wsgi.input'])

    def get_req_method(self):
        """Return the parsed ``protocol`` field as the request method.

        NOTE(review): unlike the base class, the value is not upper-cased.
        """
        return self.status_headers.protocol

    def get_req_headers(self):
        """Return the parsed headers as a dict.

        When a name occurs more than once, the last value wins.
        """
        return {name: value for name, value in self.status_headers.headers}

    def get_full_request_uri(self):
        """Return the part of the parsed status line before the first space."""
        return self.status_headers.statusline.split(' ', 1)[0]

    def get_req_protocol(self):
        """Return the part of the parsed status line after the first space.

        If the status line contains no space, the whole line is returned.
        """
        return self.status_headers.statusline.split(' ', 1)[-1]

    def _get_content_type(self):
        """Return the parsed ``Content-Type`` header."""
        return self.status_headers.get_header('Content-Type')

    def _get_content_length(self):
        """Return the parsed ``Content-Length`` header."""
        return self.status_headers.get_header('Content-Length')

    def _get_header(self, name):
        """Return the parsed header called ``name``."""
        return self.status_headers.get_header(name)
|
|
|
|
|
|
|