from pywb.utils.loaders import extract_client_cookie
from pywb.utils.loaders import extract_post_query, append_post_query
from pywb.utils.loaders import LimitReader
from pywb.utils.statusandheaders import StatusAndHeadersParser

from six.moves.urllib.parse import urlsplit
from six import StringIO
import six


#=============================================================================
class WSGIInputRequest(object):
    """Read-only view of an incoming request backed by a WSGI environ dict.

    Exposes the request method, a header dict rewritten for a given
    target url, and the request body stream.
    """

    def __init__(self, env):
        """Store the WSGI environ dict that describes the request."""
        self.env = env

    def get_req_method(self):
        """Return the upper-cased ``REQUEST_METHOD`` from the environ."""
        return self.env['REQUEST_METHOD'].upper()

    def get_req_headers(self, url):
        """Build a dict of request headers, rewritten to target ``url``.

        ``Host``, ``Origin`` and ``X-Forwarded-Proto`` are replaced with
        values derived from ``url``. Other ``HTTP_*`` keys, plus
        ``CONTENT_LENGTH`` and ``CONTENT_TYPE``, are converted from
        environ form to title-cased, dash-separated header names.
        Entries with a falsy value are omitted.

        :param url: url whose scheme and netloc replace the client's
        :return: dict mapping header name to header value
        """
        headers = {}

        splits = urlsplit(url)

        for name, value in six.iteritems(self.env):
            if name == 'HTTP_HOST':
                name = 'Host'
                value = splits.netloc

            elif name == 'HTTP_ORIGIN':
                name = 'Origin'
                value = (splits.scheme + '://' + splits.netloc)

            elif name == 'HTTP_X_CSRFTOKEN':
                name = 'X-CSRFToken'
                # Prefer the 'csrftoken' cookie value when one is found;
                # otherwise the header value sent by the client is kept.
                # NOTE(review): presumably extract_client_cookie reads the
                # request's Cookie header from the environ -- confirm.
                cookie_val = extract_client_cookie(self.env, 'csrftoken')
                if cookie_val:
                    value = cookie_val

            elif name == 'HTTP_X_FORWARDED_PROTO':
                name = 'X-Forwarded-Proto'
                value = splits.scheme

            elif name.startswith('HTTP_'):
                # HTTP_ACCEPT_ENCODING -> Accept-Encoding
                name = name[5:].title().replace('_', '-')

            elif name in ('CONTENT_LENGTH', 'CONTENT_TYPE'):
                # These two headers carry no HTTP_ prefix in the environ
                name = name.title().replace('_', '-')

            else:
                # Not a request header (e.g. wsgi.* or server variables)
                value = None

            if value:
                headers[name] = value

        return headers

    def get_req_body(self):
        """Return a stream for the request body, or ``None``.

        When a content length is set, ``wsgi.input`` is wrapped in a
        ``LimitReader`` constructed with that length. When only a
        ``Transfer-Encoding`` header is set, the raw input stream is
        returned. With neither, or with no input stream, returns ``None``.

        :raises ValueError: if the content length is not a valid integer
        """
        input_ = self.env.get('wsgi.input')
        if not input_:
            return None

        len_ = self._get_content_length()
        enc = self._get_header('Transfer-Encoding')

        if len_:
            data = LimitReader(input_, int(len_))
        elif enc:
            data = input_
        else:
            data = None

        return data

    def _get_content_type(self):
        """Return the request's content type, or ``None`` if not set."""
        return self.env.get('CONTENT_TYPE')

    def _get_content_length(self):
        """Return the request's content length (as stored), or ``None``."""
        return self.env.get('CONTENT_LENGTH')

    def _get_header(self, name):
        """Look up a header by its HTTP name via the ``HTTP_*`` environ key."""
        return self.env.get('HTTP_' + name.upper().replace('-', '_'))

    def include_post_query(self, url):
        """Return ``url`` with the POST body appended as a query, if any.

        For non-POST requests ``url`` is returned unchanged. Otherwise the
        body is passed to ``extract_post_query`` together with a buffer.
        If a post query is returned, the buffer replaces ``wsgi.input`` in
        the environ and the query is appended to ``url``.

        :param url: the request url
        :return: ``url``, possibly extended with the POST query
        """
        if self.get_req_method() != 'POST':
            return url

        mime = self._get_content_type()
        # Drop any parameters (e.g. '; charset=...') from the content type
        mime = mime.split(';')[0] if mime else ''
        length = self._get_content_length()
        stream = self.env['wsgi.input']

        # NOTE(review): wsgi.input is a byte stream under Python 3; confirm
        # that a text StringIO is the right buffer type for this helper.
        buffered_stream = StringIO()

        post_query = extract_post_query('POST', mime, length, stream,
                                        buffered_stream=buffered_stream)

        if post_query:
            # Swap in the buffer so the body is still available downstream
            # (presumably the helper has consumed the original stream).
            self.env['wsgi.input'] = buffered_stream
            url = append_post_query(url, post_query)

        return url


#=============================================================================
class POSTInputRequest(WSGIInputRequest):
    """Request whose method and headers are parsed from ``wsgi.input``.

    The input stream is parsed with ``StatusAndHeadersParser`` at
    construction time; method and header lookups are answered from the
    parsed result instead of from the environ.
    """

    def __init__(self, env):
        """Store ``env`` and parse the status line and headers from its input.

        :param env: WSGI environ dict; must contain ``wsgi.input``
        """
        super(POSTInputRequest, self).__init__(env)

        parser = StatusAndHeadersParser([], verify=False)

        self.status_headers = parser.parse(self.env['wsgi.input'])

    def get_req_method(self):
        """Return the ``protocol`` field of the parsed status line.

        NOTE(review): presumably this holds the HTTP method when the
        parsed line is a request line -- confirm against the parser.
        """
        return self.status_headers.protocol

    def get_req_headers(self, url):
        """Return the parsed headers as a dict; ``url`` is not used here."""
        return {n: v for n, v in self.status_headers.headers}

    def _get_content_type(self):
        """Return the parsed ``Content-Type`` header value."""
        return self.status_headers.get_header('Content-Type')

    def _get_content_length(self):
        """Return the parsed ``Content-Length`` header value."""
        return self.status_headers.get_header('Content-Length')

    def _get_header(self, name):
        """Return the parsed value of the header called ``name``."""
        return self.status_headers.get_header(name)