1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-20 02:39:13 +01:00
pywb/inputrequest.py
Ilya Kreymer 398e8f1a77 inputrequest: add input request handling (direct wsgi headers) or as a prepared post request
add timemap link output
rename source_name -> source
2016-02-24 14:22:29 -08:00

137 lines
3.7 KiB
Python

from pywb.utils.loaders import extract_client_cookie
from pywb.utils.loaders import extract_post_query, append_post_query
from pywb.utils.loaders import LimitReader
from pywb.utils.statusandheaders import StatusAndHeadersParser
from six.moves.urllib.parse import urlsplit
from six import StringIO
import six
#=============================================================================
class WSGIInputRequest(object):
    """Expose an incoming WSGI request (its ``environ`` dict) as a request
    method, a header dict and a body stream, rewritten for a target url.
    """

    def __init__(self, env):
        """
        :param dict env: the WSGI environ of the incoming request
        """
        self.env = env

    def get_req_method(self):
        """Return ``REQUEST_METHOD`` from the environ, upper-cased."""
        return self.env['REQUEST_METHOD'].upper()

    def get_req_headers(self, url):
        """Build a dict of request headers from the WSGI environ.

        ``Host``, ``Origin`` and ``X-Forwarded-Proto`` are replaced with
        values derived from ``url``. ``X-CSRFToken`` is replaced with the
        ``csrftoken`` client cookie when one is found. Other ``HTTP_*``
        keys, plus ``CONTENT_LENGTH`` and ``CONTENT_TYPE``, are converted
        to ``Title-Case`` header names. All remaining environ keys and any
        header with an empty value are omitted.

        :param str url: the target url the headers are being prepared for
        :return: mapping of header name to value
        :rtype: dict
        """
        headers = {}
        splits = urlsplit(url)

        for name, value in self.env.items():
            if name == 'HTTP_HOST':
                name = 'Host'
                value = splits.netloc

            elif name == 'HTTP_ORIGIN':
                name = 'Origin'
                value = splits.scheme + '://' + splits.netloc

            elif name == 'HTTP_X_CSRFTOKEN':
                name = 'X-CSRFToken'
                # Look the cookie up on the instance environ; there is no
                # local ``env`` in this method.
                cookie_val = extract_client_cookie(self.env, 'csrftoken')
                if cookie_val:
                    value = cookie_val

            elif name == 'HTTP_X_FORWARDED_PROTO':
                name = 'X-Forwarded-Proto'
                value = splits.scheme

            elif name.startswith('HTTP_'):
                # e.g. HTTP_ACCEPT_ENCODING -> Accept-Encoding
                name = name[5:].title().replace('_', '-')

            elif name in ('CONTENT_LENGTH', 'CONTENT_TYPE'):
                name = name.title().replace('_', '-')

            else:
                # not a request header (server / wsgi.* variables)
                value = None

            if value:
                headers[name] = value

        return headers

    def get_req_body(self):
        """Return a readable stream for the request body, or ``None``.

        If a content length is set, the input stream is wrapped in a
        ``LimitReader`` capped at that length. Otherwise, if a
        ``Transfer-Encoding`` header is present, the raw input stream is
        returned as-is. With neither (or with no input stream), ``None``
        is returned.
        """
        input_ = self.env.get('wsgi.input')
        if not input_:
            return None

        len_ = self._get_content_length()
        enc = self._get_header('Transfer-Encoding')

        if len_:
            data = LimitReader(input_, int(len_))
        elif enc:
            data = input_
        else:
            data = None

        return data

    def _get_content_type(self):
        """Return the ``CONTENT_TYPE`` environ value, or ``None``."""
        return self.env.get('CONTENT_TYPE')

    def _get_content_length(self):
        """Return the ``CONTENT_LENGTH`` environ value, or ``None``."""
        return self.env.get('CONTENT_LENGTH')

    def _get_header(self, name):
        """Return the value of request header ``name``, or ``None``.

        The header name is mapped to its environ key, e.g.
        ``Transfer-Encoding`` -> ``HTTP_TRANSFER_ENCODING``.
        """
        return self.env.get('HTTP_' + name.upper().replace('-', '_'))

    def include_post_query(self, url):
        """Return ``url`` with the POST body query appended, if any.

        For non-POST requests ``url`` is returned unchanged. Otherwise the
        body is passed to ``extract_post_query``; when it yields a query,
        ``wsgi.input`` is replaced with the buffered stream and the query
        is appended to ``url`` via ``append_post_query``.

        :param str url: the request url
        :return: the url, possibly extended with the post query
        :rtype: str
        """
        if self.get_req_method() != 'POST':
            return url

        mime = self._get_content_type()
        # strip any parameters, e.g. '; charset=utf-8'
        mime = mime.split(';')[0] if mime else ''
        length = self._get_content_length()
        stream = self.env['wsgi.input']

        # NOTE(review): presumably extract_post_query copies the consumed
        # body into buffered_stream so it can be read again -- confirm.
        buffered_stream = StringIO()

        post_query = extract_post_query('POST', mime, length, stream,
                                        buffered_stream=buffered_stream)

        if post_query:
            self.env['wsgi.input'] = buffered_stream
            url = append_post_query(url, post_query)

        return url
#=============================================================================
class POSTInputRequest(WSGIInputRequest):
    """Input request whose method and headers are parsed from the content
    of ``wsgi.input`` rather than taken from the WSGI environ itself.
    """

    def __init__(self, env):
        """Parse the status line and headers from ``env['wsgi.input']``.

        :param dict env: the WSGI environ carrying the prepared request
        """
        super(POSTInputRequest, self).__init__(env)

        # empty list of accepted status-line prefixes, verification off
        header_parser = StatusAndHeadersParser([], verify=False)
        self.status_headers = header_parser.parse(env['wsgi.input'])

    def get_req_method(self):
        """Return the ``protocol`` field of the parsed status line."""
        return self.status_headers.protocol

    def get_req_headers(self, url):
        """Return the parsed headers as a dict; ``url`` is not used."""
        return {name: value for name, value in self.status_headers.headers}

    def _get_content_type(self):
        """Return the parsed ``Content-Type`` header."""
        return self._get_header('Content-Type')

    def _get_content_length(self):
        """Return the parsed ``Content-Length`` header."""
        return self._get_header('Content-Length')

    def _get_header(self, name):
        """Look up header ``name`` in the parsed headers."""
        return self.status_headers.get_header(name)