from pywb.utils.loaders import extract_post_query, append_post_query
from pywb.utils.loaders import LimitReader
from pywb.utils.statusandheaders import StatusAndHeadersParser

from six.moves.urllib.parse import urlsplit, quote
from six import iteritems, StringIO
from io import BytesIO


#=============================================================================
class DirectWSGIInputRequest(object):
    def __init__(self, env):
        self.env = env

    def get_req_method(self):
        return self.env['REQUEST_METHOD'].upper()

    def get_req_protocol(self):
        return self.env['SERVER_PROTOCOL']

    def get_req_headers(self):
        headers = {}

        for name, value in iteritems(self.env):
            # will be set by requests to match actual host
            if name == 'HTTP_HOST':
                continue

            elif name.startswith('HTTP_'):
                name = name[5:].title().replace('_', '-')

            elif name in ('CONTENT_LENGTH', 'CONTENT_TYPE'):
                name = name.title().replace('_', '-')

            else:
                value = None

            if value:
                headers[name] = value

        return headers

    def get_req_body(self):
        input_ = self.env['wsgi.input']
        len_ = self._get_content_length()
        enc = self._get_header('Transfer-Encoding')

        if len_:
            data = LimitReader(input_, int(len_))
        elif enc:
            data = input_
        else:
            data = None

        return data

    def _get_content_type(self):
        return self.env.get('CONTENT_TYPE')

    def _get_content_length(self):
        return self.env.get('CONTENT_LENGTH')

    def _get_header(self, name):
        return self.env.get('HTTP_' + name.upper().replace('-', '_'))

    def include_post_query(self, url):
        if not url or self.get_req_method() != 'POST':
            return url

        mime = self._get_content_type()
        #mime = mime.split(';')[0] if mime else ''
        length = self._get_content_length()
        stream = self.env['wsgi.input']

        buffered_stream = BytesIO()

        post_query = extract_post_query('POST', mime, length, stream,
                                        buffered_stream=buffered_stream,
                                        environ=self.env)

        if post_query:
            self.env['wsgi.input'] = buffered_stream
            url = append_post_query(url, post_query)

        return url

    def get_full_request_uri(self):
        req_uri = self.env.get('REQUEST_URI')
        if req_uri and not self.env.get('SCRIPT_NAME'):
            return req_uri

        req_uri = quote(self.env.get('PATH_INFO', ''), safe='/~!$&\'()*+,;=:@')
        query = self.env.get('QUERY_STRING')
        if query:
            req_uri += '?' + query

        return req_uri

    def reconstruct_request(self, url=None):
        buff = StringIO()
        buff.write(self.get_req_method())
        buff.write(' ')
        buff.write(self.get_full_request_uri())
        buff.write(' ')
        buff.write(self.get_req_protocol())
        buff.write('\r\n')

        headers = self.get_req_headers()

        if url:
            parts = urlsplit(url)
            buff.write('Host: ')
            buff.write(parts.netloc)
            buff.write('\r\n')

        for name, value in iteritems(headers):
            if name.lower() == 'host':
                continue

            buff.write(name)
            buff.write(': ')
            buff.write(value)
            buff.write('\r\n')

        buff.write('\r\n')
        buff = buff.getvalue().encode('latin-1')

        body = self.get_req_body()
        if body:
            buff += body.read()

        return buff


#=============================================================================
class POSTInputRequest(DirectWSGIInputRequest):
    def __init__(self, env):
        self.env = env

        parser = StatusAndHeadersParser([], verify=False)

        self.status_headers = parser.parse(self.env['wsgi.input'])

    def get_req_method(self):
        return self.status_headers.protocol

    def get_req_headers(self):
        headers = {}
        for n, v in self.status_headers.headers:
            headers[n] = v

        return headers

    def get_full_request_uri(self):
        return self.status_headers.statusline.split(' ', 1)[0]

    def get_req_protocol(self):
        return self.status_headers.statusline.split(' ', 1)[-1]

    def _get_content_type(self):
        return self.status_headers.get_header('Content-Type')

    def _get_content_length(self):
        return self.status_headers.get_header('Content-Length')

    def _get_header(self, name):
        return self.status_headers.get_header(name)