diff --git a/pywb/apps/rewriterapp.py b/pywb/apps/rewriterapp.py
index fc78ac9e..13397561 100644
--- a/pywb/apps/rewriterapp.py
+++ b/pywb/apps/rewriterapp.py
@@ -374,7 +374,7 @@ class RewriterApp(object):

         urlrewriter.rewrite_opts['ua_string'] = environ.get('HTTP_USER_AGENT')

-        result = content_rw(record, urlrewriter, cookie_rewriter, head_insert_func, cdx)
+        result = content_rw(record, urlrewriter, cookie_rewriter, head_insert_func, cdx, environ)

         status_headers, gen, is_rw = result

diff --git a/pywb/rewrite/content_rewriter.py b/pywb/rewrite/content_rewriter.py
index 0918e9a1..f77cc13e 100644
--- a/pywb/rewrite/content_rewriter.py
+++ b/pywb/rewrite/content_rewriter.py
@@ -175,8 +175,9 @@ class BaseContentRewriter(object):

     def __call__(self, record, url_rewriter, cookie_rewriter,
                  head_insert_func=None,
-                 cdx=None):
+                 cdx=None, environ=None):

+        environ = environ or {}
         rwinfo = RewriteInfo(record, self, url_rewriter, cookie_rewriter)
         content_rewriter = None

@@ -192,6 +193,22 @@ class BaseContentRewriter(object):

         gen = None

+        # check if decoding is needed
+        if not rwinfo.is_content_rw:
+            content_encoding = rwinfo.record.http_headers.get_header('Content-Encoding')
+            accept_encoding = environ.get('HTTP_ACCEPT_ENCODING', '')
+
+            # Accept-Encoding is a comma-separated list of codings, each with
+            # optional ';q=' params: compare whole tokens case-insensitively
+            # instead of doing a substring test on the raw header value
+            accepted = set(enc.split(';', 1)[0].strip().lower()
+                           for enc in accept_encoding.split(','))
+
+            # if content-encoding is set but encoding is not in accept encoding,
+            # enable content_rw force decompression
+            if content_encoding and content_encoding.strip().lower() not in accepted:
+                rwinfo.is_content_rw = True
+
         if content_rewriter:
             gen = content_rewriter(rwinfo)
         elif rwinfo.is_content_rw:
diff --git a/pywb/rewrite/test/test_content_rewriter.py b/pywb/rewrite/test/test_content_rewriter.py
index 53f395d3..f7cd42bb 100644
--- a/pywb/rewrite/test/test_content_rewriter.py
+++ b/pywb/rewrite/test/test_content_rewriter.py
@@ -20,6 +20,7 @@ from pywb import get_test_dir
 import os
 import json
 import pytest
+import six


 # ============================================================================
@@ -45,7 +46,8 @@ class TestContentRewriter(object):

         warc_headers = warc_headers or {}

-        payload = payload.encode('utf-8')
+        if isinstance(payload, six.text_type):
+            payload = payload.encode('utf-8')

         http_headers = StatusAndHeaders('200 OK', headers, protocol='HTTP/1.0')

@@ -57,7 +59,7 @@ class TestContentRewriter(object):

     def rewrite_record(self, headers, content, ts, url='http://example.com/',
                        prefix='http://localhost:8080/prefix/', warc_headers=None,
-                       request_url=None, is_live=None, use_js_proxy=True):
+                       request_url=None, is_live=None, use_js_proxy=True, environ=None):

         record = self._create_response_record(url, headers, content, warc_headers)

@@ -73,9 +75,9 @@ class TestContentRewriter(object):
             cdx['is_live'] = is_live

         if use_js_proxy:
-            return self.js_proxy_content_rewriter(record, url_rewriter, None, cdx=cdx)
+            return self.js_proxy_content_rewriter(record, url_rewriter, None, cdx=cdx, environ=environ)
         else:
-            return self.content_rewriter(record, url_rewriter, None, cdx=cdx)
+            return self.content_rewriter(record, url_rewriter, None, cdx=cdx, environ=environ)

     def test_rewrite_html(self, headers):
         content = ''
@@ -269,6 +271,42 @@ class TestContentRewriter(object):

         assert ('Transfer-Encoding', 'chunked') not in headers.headers

+    def test_brotli_accepted_no_change(self):
+        """Brotli payload passes through untouched when the client accepts 'br'."""
+        brotli = pytest.importorskip('brotli')
+        content = brotli.compress('ABCDEFG'.encode('utf-8'))
+
+        headers = {'Content-Type': 'application/octet-stream',
+                   'Content-Encoding': 'br',
+                   'Content-Length': str(len(content))
+                  }
+
+        headers, gen, is_rw = self.rewrite_record(headers, content, ts='201701mp_',
+                                                  environ={'HTTP_ACCEPT_ENCODING': 'gzip, deflate, br'})
+
+        assert headers['Content-Encoding'] == 'br'
+        assert headers['Content-Length'] == str(len(content))
+
+        assert brotli.decompress(b''.join(gen)).decode('utf-8') == 'ABCDEFG'
+
+    def test_brotli_not_accepted_auto_decode(self):
+        """Brotli payload is decoded when no Accept-Encoding is supplied."""
+        brotli = pytest.importorskip('brotli')
+        content = brotli.compress('ABCDEFG'.encode('utf-8'))
+
+        headers = {'Content-Type': 'application/octet-stream',
+                   'Content-Encoding': 'br',
+                   'Content-Length': str(len(content))
+                  }
+
+        headers, gen, is_rw = self.rewrite_record(headers, content, ts='201701mp_')
+
+        assert 'Content-Encoding' not in headers
+        assert 'Content-Length' not in headers
+        assert headers['X-Archive-Orig-Content-Encoding'] == 'br'
+
+        assert b''.join(gen).decode('utf-8') == 'ABCDEFG'
+
     def test_rewrite_json(self):
         headers = {'Content-Type': 'application/json'}
         content = '/**/ jQuery_ABC({"foo": "bar"});'