From 4bb35567095dae3837c09738b382b63cfcbcfc52 Mon Sep 17 00:00:00 2001
From: Noah Levitt <nlevitt@archive.org>
Date: Tue, 10 May 2016 23:11:47 +0000
Subject: [PATCH] implement enforcement of Warcprox-Meta header block rules;
 includes automated tests

---
 setup.py               |   2 +-
 tests/test_warcprox.py | 159 +++++++++++++++++++++++++++++++++++------
 warcprox/__init__.py   |  53 ++++++++------
 warcprox/main.py       |   5 +-
 warcprox/mitmproxy.py  |  32 +++++----
 warcprox/warcproxy.py  | 146 +++++++++++++++++++++++++++++++++++--
 6 files changed, 331 insertions(+), 66 deletions(-)

diff --git a/setup.py b/setup.py
index e543c33..6584d18 100755
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,7 @@ except:
     deps.append('futures')
 
 setuptools.setup(name='warcprox',
-        version='2.0.dev8',
+        version='2.0.dev9',
         description='WARC writing MITM HTTP/S proxy',
         url='https://github.com/internetarchive/warcprox',
         author='Noah Levitt',
diff --git a/tests/test_warcprox.py b/tests/test_warcprox.py
index a17d2ea..45933b5 100755
--- a/tests/test_warcprox.py
+++ b/tests/test_warcprox.py
@@ -1,24 +1,24 @@
 #!/usr/bin/env python
-#
-# tests/test_warcprox.py - automated tests for warcprox
-#
-# Copyright (C) 2013-2016 Internet Archive
-#
-# This program is free software; you can redistribute it and/or
-# modify it under the terms of the GNU General Public License
-# as published by the Free Software Foundation; either version 2
-# of the License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program; if not, write to the Free Software
-# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
-# USA.
-#
+'''
+tests/test_warcprox.py - automated tests for warcprox
+
+Copyright (C) 2013-2016 Internet Archive
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation; either version 2
+of the License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
+USA.
+'''
 
 import pytest
 import threading
@@ -58,7 +58,8 @@ import certauth.certauth
 import warcprox
 
 logging.basicConfig(stream=sys.stdout, level=logging.INFO,
-        format='%(asctime)s %(process)d %(levelname)s %(threadName)s %(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s')
+        format='%(asctime)s %(process)d %(levelname)s %(threadName)s '
+        '%(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s')
 logging.getLogger("requests.packages.urllib3").setLevel(logging.WARN)
 warnings.simplefilter("ignore", category=requests.packages.urllib3.exceptions.InsecureRequestWarning)
 warnings.simplefilter("ignore", category=requests.packages.urllib3.exceptions.InsecurePlatformWarning)
@@ -137,8 +138,8 @@ def cert(request):
 
 @pytest.fixture(scope="module")
 def http_daemon(request):
-    http_daemon = http_server.HTTPServer(('localhost', 0),
-            RequestHandlerClass=_TestHttpRequestHandler)
+    http_daemon = http_server.HTTPServer(
+            ('localhost', 0), RequestHandlerClass=_TestHttpRequestHandler)
     logging.info('starting http://{}:{}'.format(http_daemon.server_address[0], http_daemon.server_address[1]))
     http_daemon_thread = threading.Thread(name='HttpDaemonThread',
             target=http_daemon.serve_forever)
@@ -725,6 +726,118 @@ def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies,
     finally:
         fh.close()
 
+def test_block_rules(http_daemon, https_daemon, warcprox_, archiving_proxies):
+    rules = [
+        {
+            "host": "localhost",
+            "url_match": "STRING_MATCH",
+            "value": "bar",
+        },
+        {
+            "url_match": "SURT_MATCH",
+            "value": "http://(localhost:%s,)/fuh/" % (http_daemon.server_port),
+        },
+        {
+            "url_match": "SURT_MATCH",
+            # this rule won't match because of http scheme, https port
+            "value": "http://(localhost:%s,)/fuh/" % (https_daemon.server_port),
+        },
+        {
+            "host": "badhost.com",
+        },
+    ]
+    request_meta = {"blocks":rules}
+    headers = {"Warcprox-Meta":json.dumps(request_meta)}
+
+    # blocked by STRING_MATCH rule
+    url = 'http://localhost:{}/bar'.format(http_daemon.server_port)
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[0]}
+
+    # not blocked
+    url = 'http://localhost:{}/m/n'.format(http_daemon.server_port)
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 200
+
+    # blocked by SURT_MATCH
+    url = 'http://localhost:{}/fuh/guh'.format(http_daemon.server_port)
+    # logging.info("%s => %s", repr(url), repr(warcprox.warcproxy.Url(url).surt))
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[1]}
+
+    # not blocked (no trailing slash)
+    url = 'http://localhost:{}/fuh'.format(http_daemon.server_port)
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    # 404 because server set up at the top of this file doesn't handle this url
+    assert response.status_code == 404
+
+    # not blocked because surt scheme does not match (differs from heritrix
+    # behavior where https urls are coerced to http surt form)
+    url = 'https://localhost:{}/fuh/guh'.format(https_daemon.server_port)
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True,
+            verify=False)
+    assert response.status_code == 200
+
+    # blocked by blanket host block
+    url = 'http://badhost.com/'
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[3]}
+
+    # blocked by blanket host block
+    url = 'https://badhost.com/'
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True,
+            verify=False)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[3]}
+
+    # blocked by blanket host block
+    url = 'http://badhost.com:1234/'
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[3]}
+
+    # blocked by blanket host block
+    url = 'http://foo.bar.badhost.com/'
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[3]}
+
+    # host block also applies to subdomains
+    url = 'https://foo.bar.badhost.com/'
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True,
+            verify=False)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[3]}
+
+    # blocked by blanket host block
+    url = 'http://foo.bar.badhost.com:1234/'
+    response = requests.get(
+            url, proxies=archiving_proxies, headers=headers, stream=True)
+    assert response.status_code == 403
+    assert response.content.startswith(b"request rejected by warcprox: blocked by rule found in Warcprox-Meta header:")
+    assert json.loads(response.headers['warcprox-meta']) == {"blocked-by-rule":rules[3]}
+
+
 # XXX this test relies on a tor proxy running at localhost:9050 with a working
 # connection to the internet, and relies on a third party site (facebook) being
 # up and behaving a certain way
diff --git a/warcprox/__init__.py b/warcprox/__init__.py
index 394d8b8..1eeb9a4 100644
--- a/warcprox/__init__.py
+++ b/warcprox/__init__.py
@@ -1,23 +1,23 @@
-#
-# warcprox/__init__.py - warcprox package main file, contains some utility code
-#
-# Copyright (C) 2013-2016 Internet Archive
-#
-# This program is free software; you can redistribute it and/or
-# modify it under the terms of the GNU General Public License
-# as published by the Free Software Foundation; either version 2
-# of the License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program; if not, write to the Free Software
-# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
-# USA.
-#
+"""
+warcprox/__init__.py - warcprox package main file, contains some utility code
+
+Copyright (C) 2013-2016 Internet Archive
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation; either version 2
+of the License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301,
+USA.
+"""
 
 from argparse import Namespace as _Namespace
 from pkg_resources import get_distribution as _get_distribution
@@ -47,6 +47,19 @@ def gettid():
     except:
         return "n/a"
 
+class RequestBlockedByRule(Exception):
+    """
+    An exception raised when a request should be blocked to respect a
+    Warcprox-Meta rule.
+    """
+    def __init__(self, msg):
+        self.msg = msg
+    def __str__(self):
+        return "%s: %s" % (self.__class__.__name__, self.msg)
+
+# logging level more fine-grained than logging.DEBUG==10
+TRACE = 5
+
 import warcprox.controller as controller
 import warcprox.playback as playback
 import warcprox.dedup as dedup
diff --git a/warcprox/main.py b/warcprox/main.py
index d8529b5..00d2d85 100644
--- a/warcprox/main.py
+++ b/warcprox/main.py
@@ -109,6 +109,7 @@ def _build_arg_parser(prog=os.path.basename(sys.argv[0])):
     arg_parser.add_argument('--version', action='version',
             version="warcprox {}".format(warcprox.__version__))
     arg_parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
+    arg_parser.add_argument('--trace', dest='trace', action='store_true')
     arg_parser.add_argument('-q', '--quiet', dest='quiet', action='store_true')
 
     return arg_parser
@@ -232,7 +233,9 @@ def main(argv=sys.argv):
     '''
     args = parse_args(argv)
 
-    if args.verbose:
+    if args.trace:
+        loglevel = warcprox.TRACE
+    elif args.verbose:
         loglevel = logging.DEBUG
     elif args.quiet:
         loglevel = logging.WARNING
diff --git a/warcprox/mitmproxy.py b/warcprox/mitmproxy.py
index 333bad7..3950e4e 100644
--- a/warcprox/mitmproxy.py
+++ b/warcprox/mitmproxy.py
@@ -241,6 +241,8 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
                             self.hostname)
                     raise
 
+        return self._remote_server_sock
+
     def _transition_to_ssl(self):
         self.request = self.connection = ssl.wrap_socket(self.connection,
                 server_side=True, certfile=self.server.ca.cert_for_host(self.hostname))
@@ -262,9 +264,7 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
         '''
         self.is_connect = True
         try:
-            # Connect to destination first
             self._determine_host_port()
-            self._connect_to_remote_server()
 
             # If successful, let's do this!
             self.send_response(200, 'Connection established')
@@ -305,19 +305,23 @@ class MitmProxyHandler(http_server.BaseHTTPRequestHandler):
         return result
 
     def do_COMMAND(self):
-        if not self.is_connect:
-            try:
-                # Connect to destination
-                self._determine_host_port()
-                self._connect_to_remote_server()
-                assert self.url
-            except Exception as e:
-                self.logger.error("problem processing request {}: {}".format(repr(self.requestline), e))
-                self.send_error(500, str(e))
-                return
-        else:
-            # if self.is_connect we already connected in do_CONNECT
+        if self.is_connect:
             self.url = self._construct_tunneled_url()
+        else:
+            self._determine_host_port()
+            assert self.url
+
+        try:
+            # Connect to destination
+            self._connect_to_remote_server()
+        except warcprox.RequestBlockedByRule as e:
+            # limit enforcers have already sent the appropriate response
+            self.logger.info("%s: %s", repr(self.requestline), e)
+            return
+        except Exception as e:
+            self.logger.error("problem processing request {}: {}".format(repr(self.requestline), e))
+            self.send_error(500, str(e))
+            return
 
         try:
             self._proxy_request()
diff --git a/warcprox/warcproxy.py b/warcprox/warcproxy.py
index 0549179..d342774 100644
--- a/warcprox/warcproxy.py
+++ b/warcprox/warcproxy.py
@@ -45,18 +45,135 @@ import warcprox
 import datetime
 import concurrent.futures
 import resource
+import ipaddress
+import surt
+
+class Url:
+    def __init__(self, url):
+        self.url = url
+        self._surt = None
+        self._host = None
+
+    @property
+    def surt(self):
+        if not self._surt:
+            hurl = surt.handyurl.parse(self.url)
+            surt.GoogleURLCanonicalizer.canonicalize(hurl)
+            hurl.query = None
+            hurl.hash = None
+            self._surt = hurl.getURLString(surt=True, trailing_comma=True)
+        return self._surt
+
+    @property
+    def host(self):
+        if not self._host:
+            self._host = surt.handyurl.parse(self.url).host
+        return self._host
+
+    def matches_ip_or_domain(self, ip_or_domain):
+        """Returns true if
+           - ip_or_domain is an ip address and self.host is the same ip address
+           - ip_or_domain is a domain and self.host is the same domain
+           - ip_or_domain is a domain and self.host is a subdomain of it
+        """
+        if ip_or_domain == self.host:
+            return True
+
+        # if either ip_or_domain or self.host are ip addresses, and they're not
+        # identical (previous check), not a match
+        try:
+            ipaddress.ip_address(ip_or_domain)
+            return False
+        except:
+            pass
+        try:
+            ipaddress.ip_address(self.host)
+            return False
+        except:
+            pass
+
+        # if we get here, we're looking at two hostnames
+        # XXX do we need to handle case of one punycoded idn, other not?
+        domain_parts = ip_or_domain.split(".")
+        host_parts = self.host.split(".")
+
+        return host_parts[-len(domain_parts):] == domain_parts
 
 class WarcProxyHandler(warcprox.mitmproxy.MitmProxyHandler):
     # self.server is WarcProxy
     logger = logging.getLogger("warcprox.warcprox.WarcProxyHandler")
 
+    # XXX nearly identical to brozzler.site.Site._scope_rule_applies() but
+    # there's no obvious common dependency where this code should go... TBD
+    def _scope_rule_applies(self, rule):
+        u = Url(self.url)
+
+        if "host" in rule and not u.matches_ip_or_domain(rule["host"]):
+            return False
+        if "url_match" in rule:
+            if rule["url_match"] == "STRING_MATCH":
+                return u.url.find(rule["value"]) >= 0
+            elif rule["url_match"] == "REGEX_MATCH":
+                try:
+                    return re.fullmatch(rule["value"], u.url)
+                except Exception as e:
+                    self.logger.warn(
+                            "caught exception matching against regex %s: %s",
+                            rule["value"], e)
+                    return False
+            elif rule["url_match"] == "SURT_MATCH":
+                return u.surt.startswith(rule["value"])
+            else:
+                self.logger.warn("invalid rule.url_match=%s", rule.url_match)
+                return False
+        else:
+            if "host" in rule:
+                # we already know that it matches from earlier check
+                return True
+            else:
+                self.logger.warn("unable to make sense of scope rule %s", rule)
+                return False
+
+    def _enforce_blocks(self, warcprox_meta):
+        """
+        Sends a 403 response and raises warcprox.RequestBlockedByRule if the
+        url is blocked by a rule in warcprox_meta.
+        """
+        if warcprox_meta and "blocks" in warcprox_meta:
+            for rule in warcprox_meta["blocks"]:
+                if self._scope_rule_applies(rule):
+                    body = ("request rejected by warcprox: blocked by "
+                            "rule found in Warcprox-Meta header: %s"
+                            % rule).encode("utf-8")
+                    self.send_response(403, "Forbidden")
+                    self.send_header("Content-Type", "text/plain;charset=utf-8")
+                    self.send_header("Connection", "close")
+                    self.send_header("Content-Length", len(body))
+                    response_meta = {"blocked-by-rule":rule}
+                    self.send_header(
+                            "Warcprox-Meta",
+                            json.dumps(response_meta, separators=(",",":")))
+                    self.end_headers()
+                    if self.command != "HEAD":
+                        self.wfile.write(body)
+                    self.connection.close()
+                    raise warcprox.RequestBlockedByRule(
+                            "%s 403 %s %s -- blocked by rule in Warcprox-Meta "
+                            "request header %s" % (
+                                self.client_address[0], self.command,
+                                self.url, rule))
+
     def _enforce_limits(self, warcprox_meta):
+        """
+        Sends a 420 response and raises warcprox.RequestBlockedByRule if a
+        limit specified in warcprox_meta is reached.
+        """
         if warcprox_meta and "limits" in warcprox_meta:
             for item in warcprox_meta["limits"].items():
                 key, limit = item
                 bucket0, bucket1, bucket2 = key.rsplit(".", 2)
                 value = self.server.stats_db.value(bucket0, bucket1, bucket2)
-                self.logger.debug("warcprox_meta['limits']=%s stats['%s']=%s recorded_url_q.qsize()=%s", 
+                self.logger.debug("warcprox_meta['limits']=%s stats['%s']=%s recorded_url_q.qsize()=%s",
                         warcprox_meta['limits'], key, value, self.server.recorded_url_q.qsize())
                 if value and value >= limit:
                     body = "request rejected by warcprox: reached limit {}={}\n".format(key, limit).encode("utf-8")
@@ -70,20 +187,35 @@ class WarcProxyHandler(warcprox.mitmproxy.MitmProxyHandler):
                     if self.command != "HEAD":
                         self.wfile.write(body)
                     self.connection.close()
-                    self.logger.info("%s 420 %s %s -- reached limit %s=%s", self.client_address[0], self.command, self.url, key, limit)
-                    return True
-        return False
+                    raise warcprox.RequestBlockedByRule(
+                            "%s 420 %s %s -- reached limit %s=%s" % (
+                                self.client_address[0], self.command,
+                                self.url, key, limit))
+
+    def _connect_to_remote_server(self):
+        '''
+        Wraps MitmProxyHandler._connect_to_remote_server, first enforcing
+        limits and block rules in the Warcprox-Meta request header, if any.
+        Raises warcprox.RequestBlockedByRule if a rule has been enforced.
+        Otherwise calls MitmProxyHandler._connect_to_remote_server, which
+        initializes self._remote_server_sock.
+        '''
+        if 'Warcprox-Meta' in self.headers:
+            warcprox_meta = json.loads(self.headers['Warcprox-Meta'])
+            self._enforce_limits(warcprox_meta)
+            self._enforce_blocks(warcprox_meta)
+        return warcprox.mitmproxy.MitmProxyHandler._connect_to_remote_server(self)
 
     def _proxy_request(self):
         warcprox_meta = None
         raw_warcprox_meta = self.headers.get('Warcprox-Meta')
+        self.logger.log(
+                warcprox.TRACE, 'request for %s Warcprox-Meta header: %s',
+                self.url, repr(raw_warcprox_meta))
         if raw_warcprox_meta:
             warcprox_meta = json.loads(raw_warcprox_meta)
             del self.headers['Warcprox-Meta']
 
-        if self._enforce_limits(warcprox_meta):
-            return
-
         remote_ip = self._remote_server_sock.getpeername()[0]
         timestamp = datetime.datetime.utcnow()