diff --git a/pywb/apps/frontendapp.py b/pywb/apps/frontendapp.py index 41f5232f..9aabe7be 100644 --- a/pywb/apps/frontendapp.py +++ b/pywb/apps/frontendapp.py @@ -17,13 +17,14 @@ from pywb.recorder.recorderapp import RecorderApp from pywb.utils.loaders import load_yaml_config from pywb.utils.geventserver import GeventServer from pywb.utils.io import StreamIter +from pywb.utils.wbexception import NotFoundException, WbException from pywb.warcserver.warcserver import WarcServer from pywb.rewrite.templateview import BaseInsertView from pywb.apps.static_handler import StaticHandler -from pywb.apps.rewriterapp import RewriterApp, UpstreamException +from pywb.apps.rewriterapp import RewriterApp from pywb.apps.wbrequestresponse import WbResponse import os @@ -382,8 +383,8 @@ class FrontEndApp(object): wb_url_str = wb_url_str.replace('timemap/{0}/'.format(timemap_output), '') try: response = self.rewriterapp.render_content(wb_url_str, metadata, environ) - except UpstreamException as ue: - response = self.rewriterapp.handle_error(environ, ue) + except WbException as wbe: + response = self.rewriterapp.handle_error(environ, wbe) raise HTTPException(response=response) return response @@ -446,7 +447,7 @@ class FrontEndApp(object): :param dict environ: The WSGI environment dictionary for the request :param str msg: The error message """ - raise NotFound(response=self.rewriterapp._error_response(environ, msg)) + raise NotFound(response=self.rewriterapp._error_response(environ, NotFoundException(msg))) def _check_refer_redirect(self, environ): """Returns a WbResponse for a HTTP 307 redirection if the HTTP referer header is the same as the HTTP host header @@ -513,7 +514,7 @@ class FrontEndApp(object): if self.debug: traceback.print_exc() - response = self.rewriterapp._error_response(environ, 'Internal Error: ' + str(e), '500 Server Error') + response = self.rewriterapp._error_response(environ, WbException('Internal Error: ' + str(e))) return response(environ, start_response) @classmethod diff --git a/pywb/apps/rewriterapp.py b/pywb/apps/rewriterapp.py index 68a64df3..a6565993 100644 --- a/pywb/apps/rewriterapp.py +++ b/pywb/apps/rewriterapp.py @@ -27,7 +27,11 @@ from pywb.warcserver.index.cdxobject import CDXObject class UpstreamException(WbException): def __init__(self, status_code, url, details): super(UpstreamException, self).__init__(url=url, msg=details) - self.status_code = status_code + self._status_code = status_code + + @property + def status_code(self): + return self._status_code # ============================================================================ @@ -502,24 +506,24 @@ class RewriterApp(object): top_url += wb_url.to_str(mod='') return top_url - def handle_error(self, environ, ue): - if ue.status_code == 404: - return self._not_found_response(environ, ue.url) - + def handle_error(self, environ, wbe): + if wbe.status_code == 404: + return self._not_found_response(environ, wbe.url) else: - status = str(ue.status_code) + ' ' + HTTP_STATUS_CODES.get(ue.status_code, 'Unknown Error') - return self._error_response(environ, ue.url, ue.msg, - status=status) + return self._error_response(environ, wbe) def _not_found_response(self, environ, url): resp = self.not_found_view.render_to_string(environ, url=url) return WbResponse.text_response(resp, status='404 Not Found', content_type='text/html') - def _error_response(self, environ, msg='', details='', status='404 Not Found'): + def _error_response(self, environ, wbe): + status = wbe.status() + resp = self.error_view.render_to_string(environ, - err_msg=msg, - err_details=details) + err_msg=wbe.url, + err_details=wbe.msg, + err_status=wbe.status_code) return WbResponse.text_response(resp, status=status, content_type='text/html') diff --git a/pywb/default_config.yaml b/pywb/default_config.yaml index 812ac9b3..5bd54aa1 100644 --- a/pywb/default_config.yaml +++ b/pywb/default_config.yaml @@ -3,8 +3,11 @@ collections_root: collections # Per-Collection Paths archive_paths: archive index_paths: indexes +acl_paths: acl static_path: static +default_access: allow + templates_dir: templates # Template HTML diff --git a/pywb/utils/binsearch.py b/pywb/utils/binsearch.py index 6f0b431e..add2f423 100644 --- a/pywb/utils/binsearch.py +++ b/pywb/utils/binsearch.py @@ -89,9 +89,13 @@ def linearsearch(iter_, key, prev_size=0, compare_func=cmp): matched = True break - # no matches, so return empty iterator + # no matches, so pop last line, but return rest of prev lines, if any if not matched: - return iter([]) + if not prev_size or len(prev_deque) <= 1: + return iter([]) + + prev_deque.popleft() + return iter(prev_deque) return itertools.chain(prev_deque, iter_) diff --git a/pywb/utils/merge.py b/pywb/utils/merge.py new file mode 100644 index 00000000..78c18d3f --- /dev/null +++ b/pywb/utils/merge.py @@ -0,0 +1,112 @@ +import sys + +if sys.version_info >= (3, 5): #pragma: no cover + from heapq import merge +else: #pragma: no cover + # ported from python 3.5 heapq merge with reverse=True support + + from heapq import heapify, heappop, heapreplace + from heapq import _heapify_max, _siftup_max + + def _heappop_max(heap): + """Maxheap version of a heappop.""" + lastelt = heap.pop() # raises appropriate IndexError if heap is empty + if heap: + returnitem = heap[0] + heap[0] = lastelt + _siftup_max(heap, 0) + return returnitem + return lastelt + + def _heapreplace_max(heap, item): + """Maxheap version of a heappop followed by a heappush.""" + returnitem = heap[0] # raises appropriate IndexError if heap is empty + heap[0] = item + _siftup_max(heap, 0) + return returnitem + + def _get_next_iter(it): + return it.__next__ if hasattr(it, '__next__') else it.next + + def merge(*iterables, **kwargs): + '''Merge multiple sorted inputs into a single sorted output. + Similar to sorted(itertools.chain(*iterables)) but returns a generator, + does not pull the data into memory all at once, and assumes that each of + the input streams is already sorted (smallest to largest). + >>> list(merge([1,3,5,7], [0,2,4,8], [5,10,15,20], [], [25])) + [0, 1, 2, 3, 4, 5, 5, 7, 8, 10, 15, 20, 25] + If *key* is not None, applies a key function to each element to determine + its sort order. + >>> list(merge(['dog', 'horse'], ['cat', 'fish', 'kangaroo'], key=len)) + ['dog', 'cat', 'fish', 'horse', 'kangaroo'] + ''' + + key = kwargs.get('key', None) + reverse = kwargs.get('reverse', False) + + h = [] + h_append = h.append + + if reverse: + _heapify = _heapify_max + _heappop = _heappop_max + _heapreplace = _heapreplace_max + direction = -1 + else: + _heapify = heapify + _heappop = heappop + _heapreplace = heapreplace + direction = 1 + + if key is None: + for order, it in enumerate(map(iter, iterables)): + try: + next = _get_next_iter(it) + h_append([next(), order * direction, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + value, order, next = s = h[0] + yield value + s[0] = next() # raises StopIteration when exhausted + _heapreplace(h, s) # restore heap condition + except StopIteration: + _heappop(h) # remove empty iterator + if h: + # fast case when only a single iterator remains + value, order, next = h[0] + yield value + for v in next.__self__: + yield v + + return + + for order, it in enumerate(map(iter, iterables)): + try: + next = _get_next_iter(it) + value = next() + h_append([key(value), order * direction, value, next]) + except StopIteration: + pass + _heapify(h) + while len(h) > 1: + try: + while True: + key_value, order, value, next = s = h[0] + yield value + value = next() + s[0] = key(value) + s[2] = value + _heapreplace(h, s) + except StopIteration: + _heappop(h) + if h: + key_value, order, value, next = h[0] + yield value + for v in next.__self__: + yield v + + diff --git a/pywb/utils/test/test_binsearch.py b/pywb/utils/test/test_binsearch.py index 4666b497..18e1cb19 100644 --- a/pywb/utils/test/test_binsearch.py +++ b/pywb/utils/test/test_binsearch.py @@ -23,6 +23,9 @@ org,iana)/domains/root/db 20140126200928 http://www.iana.org/domains/root/db tex >>> print_binsearch_results('org,iana)/time-zones', iter_exact) org,iana)/time-zones 20140126200737 http://www.iana.org/time-zones text/html 200 4Z27MYWOSXY2XDRAJRW7WRMT56LXDD4R - - 2449 569675 iana.warc.gz +>>> print_binsearch_results_range('org,iana)/time-zones', 'org,iana)/time-zones!', iter_range) +org,iana)/time-zones 20140126200737 http://www.iana.org/time-zones text/html 200 4Z27MYWOSXY2XDRAJRW7WRMT56LXDD4R - - 2449 569675 iana.warc.gz + # Exact search -- no matches >>> print_binsearch_results('org,iaana)/', iter_exact) >>> print_binsearch_results('org,ibna)/', iter_exact) @@ -74,6 +77,7 @@ org,iana)/time-zones 20140126200737 http://www.iana.org/time-zones text/html 200 #================================================================= import os from pywb.utils.binsearch import iter_prefix, iter_exact, iter_range +from pywb.utils.merge import merge from pywb import get_test_dir @@ -91,6 +95,22 @@ def print_binsearch_results_range(key, end_key, iter_func, prev_size=0): print(line.decode('utf-8')) + +def test_rev_merge(): + with open(test_cdx_dir + 'iana.cdx', 'rb') as cdx: + lines1 = cdx.readlines() + + with open(test_cdx_dir + 'dupes.cdx', 'rb') as cdx: + lines2 = cdx.readlines() + + + # check reverse merge: verify merging of lists, than reversing + # eqauls merging with reverse=True of reversed lists + assert (list(reversed(list(merge(lines1, lines2)))) == + list(merge(reversed(lines1), reversed(lines2), reverse=True))) + + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/pywb/utils/wbexception.py b/pywb/utils/wbexception.py index b7cd00da..40648bbf 100644 --- a/pywb/utils/wbexception.py +++ b/pywb/utils/wbexception.py @@ -1,3 +1,4 @@ +from werkzeug.http import HTTP_STATUS_CODES #================================================================= @@ -7,6 +8,13 @@ class WbException(Exception): self.msg = msg self.url = url + @property + def status_code(self): + return 500 + + def status(self): + return str(self.status_code) + ' ' + HTTP_STATUS_CODES.get(self.status_code, 'Unknown Error') + def __repr__(self): return "{0}('{1}',)".format(self.__class__.__name__, self.msg) @@ -17,25 +25,28 @@ class WbException(Exception): #================================================================= class AccessException(WbException): - def status(self): - return '403 Access Denied' + @property + def status_code(self): + return 451 #================================================================= class BadRequestException(WbException): - def status(self): - return '400 Bad Request' + @property + def status_code(self): + return 400 #================================================================= class NotFoundException(WbException): - def status(self): - return '404 Not Found' + @property + def status_code(self): + return 404 #================================================================= class LiveResourceException(WbException): - def status(self): - return '400 Bad Live Resource' - + @property + def status_code(self): + return 400 diff --git a/pywb/warcserver/access_checker.py b/pywb/warcserver/access_checker.py new file mode 100644 index 00000000..932ccb6b --- /dev/null +++ b/pywb/warcserver/access_checker.py @@ -0,0 +1,91 @@ +from pywb.warcserver.index.indexsource import FileIndexSource +from pywb.warcserver.index.aggregator import DirectoryIndexSource, CacheDirectoryMixin + +from pywb.utils.binsearch import search +from pywb.utils.merge import merge + +import os + + +# ============================================================================ +class FileAccessIndexSource(FileIndexSource): + @staticmethod + def rev_cmp(a, b): + return (a < b) - (a > b) + + def _get_gen(self, fh, params): + return search(fh, params['key'], prev_size=1, compare_func=self.rev_cmp) + + +# ============================================================================ +class DirectoryAccessSource(DirectoryIndexSource): + INDEX_SOURCES = [('.aclj', FileAccessIndexSource)] + + def _merge(self, iter_list): + return merge(*(iter_list), reverse=True) + + +# ============================================================================ +class CacheDirectoryAccessSource(CacheDirectoryMixin, DirectoryAccessSource): + pass + + +# ============================================================================ +class AccessChecker(object): + def __init__(self, access_source_file, default_access='allow'): + if isinstance(access_source_file, str): + self.aggregator = self.create_access_aggregator(access_source_file) + else: + self.aggregator = access_source_file + + self.default_rule = {'urlkey': '', 'access': default_access} + + def create_access_aggregator(self, filename): + if os.path.isdir(filename): + return CacheDirectoryAccessSource(filename) + + elif os.path.isfile(filename): + return FileAccessIndexSource(filename) + + else: + raise Exception('Invalid Access Source: ' + filename) + + def find_access_rule(self, url, ts=None, urlkey=None): + params = {'url': url, 'urlkey': urlkey} + cdx_iter, errs = self.aggregator(params) + if errs: + print(errs) + + key = params['key'].decode('utf-8') + + for cdx in cdx_iter: + if 'urlkey' not in cdx: + continue + + if key.startswith(cdx['urlkey']): + return cdx + + return self.default_rule + + def __call__(self, res): + cdx_iter, errs = res + return self.wrap_iter(cdx_iter), errs + + def wrap_iter(self, cdx_iter): + last_rule = None + last_url = None + + for cdx in cdx_iter: + url = cdx.get('url') + # if no url, possible idx or other object, don't apply any checks and pass through + if not url: + yield cdx + continue + + rule = self.find_access_rule(url, cdx.get('timestamp'), cdx.get('urlkey')) + access = rule.get('access', 'exclude') + if access == 'exclude': + continue + + cdx['access'] = access + yield cdx diff --git a/pywb/warcserver/basewarcserver.py b/pywb/warcserver/basewarcserver.py index a5082e29..0ef5f448 100644 --- a/pywb/warcserver/basewarcserver.py +++ b/pywb/warcserver/basewarcserver.py @@ -1,6 +1,8 @@ from pywb.warcserver.inputrequest import DirectWSGIInputRequest, POSTInputRequest from pywb.utils.format import query_to_dict +from pywb.utils.wbexception import AccessException + from werkzeug.routing import Map, Rule from werkzeug.exceptions import HTTPException @@ -90,6 +92,12 @@ class BaseWarcServer(object): start_response('200 OK', list(out_headers.items())) return res + except AccessException as ae: + out_headers = {} + res = self.json_encode(ae.msg, out_headers) + start_response(ae.status(), list(out_headers.items())) + return res + except Exception as e: if self.debug: traceback.print_exc() @@ -107,6 +115,7 @@ class BaseWarcServer(object): def send_error(self, errs, start_response, message='No Resource Found', status=404): + last_exc = errs.pop('last_exc', None) if last_exc: if self.debug: diff --git a/pywb/warcserver/handlers.py b/pywb/warcserver/handlers.py index cf1bdf4a..6cd2f4b2 100644 --- a/pywb/warcserver/handlers.py +++ b/pywb/warcserver/handlers.py @@ -1,4 +1,4 @@ -from pywb.utils.wbexception import BadRequestException, WbException +from pywb.utils.wbexception import BadRequestException, WbException, AccessException from pywb.utils.wbexception import NotFoundException from pywb.utils.memento import MementoUtils @@ -48,6 +48,7 @@ class IndexHandler(object): self.index_source = index_source self.opts = opts or {} self.fuzzy = FuzzyMatcher(kwargs.get('rules_file')) + self.access_checker = kwargs.get('access_checker') def get_supported_modes(self): return dict(modes=['list_sources', 'index']) @@ -62,7 +63,12 @@ class IndexHandler(object): if input_req: params['alt_url'] = input_req.include_method_query(url) - return self.fuzzy(self.index_source, params) + cdx_iter = self.fuzzy(self.index_source, params) + + if self.access_checker: + cdx_iter = self.access_checker(cdx_iter) + + return cdx_iter def __call__(self, params): mode = params.get('mode', 'index') @@ -101,8 +107,8 @@ class IndexHandler(object): #============================================================================= class ResourceHandler(IndexHandler): - def __init__(self, index_source, resource_loaders, rules_file=None): - super(ResourceHandler, self).__init__(index_source, rules_file=rules_file) + def __init__(self, index_source, resource_loaders, **kwargs): + super(ResourceHandler, self).__init__(index_source, **kwargs) self.resource_loaders = resource_loaders def get_supported_modes(self): @@ -121,6 +127,11 @@ class ResourceHandler(IndexHandler): last_exc = None for cdx in cdx_iter: + if cdx.get('access', 'allow') != 'allow': + raise AccessException(msg={'access': cdx['access'], + 'access_status': cdx.get('access_status', 451)}, + url=cdx['url']) + for loader in self.resource_loaders: try: out_headers, resp = loader(cdx, params) @@ -141,13 +152,12 @@ class ResourceHandler(IndexHandler): #============================================================================= class DefaultResourceHandler(ResourceHandler): def __init__(self, index_source, warc_paths='', forward_proxy_prefix='', - rules_file=''): + **kwargs): loaders = [WARCPathLoader(warc_paths, index_source), LiveWebLoader(forward_proxy_prefix), VideoLoader() ] - super(DefaultResourceHandler, self).__init__(index_source, loaders, - rules_file=rules_file) + super(DefaultResourceHandler, self).__init__(index_source, loaders, **kwargs) #============================================================================= diff --git a/pywb/warcserver/index/aggregator.py b/pywb/warcserver/index/aggregator.py index 16cd25e9..032ad1d0 100644 --- a/pywb/warcserver/index/aggregator.py +++ b/pywb/warcserver/index/aggregator.py @@ -90,10 +90,13 @@ class BaseAggregator(object): if len(iter_list) <= 1: cdx_iter = iter_list[0] if iter_list else iter([]) else: - cdx_iter = merge(*(iter_list)) + cdx_iter = self._merge(iter_list) return cdx_iter, err_list + def _merge(self, iter_list): + return merge(*(iter_list)) + def _on_source_error(self, name): #pragma: no cover pass @@ -257,6 +260,11 @@ class GeventTimeoutAggregator(TimeoutMixin, GeventMixin, BaseSourceListAggregato #============================================================================= class BaseDirectoryIndexSource(BaseAggregator): + INDEX_SOURCES = [ + (FileIndexSource.CDX_EXT, FileIndexSource), + (ZipNumIndexSource.IDX_EXT, ZipNumIndexSource) + ] + def __init__(self, base_prefix, base_dir='', name='', config=None): self.base_prefix = base_prefix self.base_dir = base_dir @@ -280,13 +288,13 @@ class BaseDirectoryIndexSource(BaseAggregator): def _load_files_single_dir(self, the_dir): for name in os.listdir(the_dir): - filename = os.path.join(the_dir, name) + for ext, cls in self.INDEX_SOURCES: + if not name.endswith(ext): + continue - is_cdx = filename.endswith(FileIndexSource.CDX_EXT) - is_zip = filename.endswith(ZipNumIndexSource.IDX_EXT) + filename = os.path.join(the_dir, name) - if is_cdx or is_zip: - #print('Adding ' + filename) + #print('Adding ' + filename) rel_path = os.path.relpath(the_dir, self.base_prefix) if rel_path == '.': full_name = name @@ -296,10 +304,7 @@ class BaseDirectoryIndexSource(BaseAggregator): if self.name: full_name = self.name + ':' + full_name - if is_cdx: - index_src = FileIndexSource(filename) - else: - index_src = ZipNumIndexSource(filename, self.config) + index_src = cls(filename, self.config) yield full_name, index_src @@ -341,9 +346,9 @@ class DirectoryIndexSource(SeqAggMixin, BaseDirectoryIndexSource): #============================================================================= -class CacheDirectoryIndexSource(DirectoryIndexSource): +class CacheDirectoryMixin(object): def __init__(self, *args, **kwargs): - super(CacheDirectoryIndexSource, self).__init__(*args, **kwargs) + super(CacheDirectoryMixin, self).__init__(*args, **kwargs) self.cached_file_list = {} def _load_files_single_dir(self, the_dir): @@ -360,12 +365,17 @@ class CacheDirectoryIndexSource(DirectoryIndexSource): print('Dir {0} unchanged'.format(the_dir)) return files - files = super(CacheDirectoryIndexSource, self)._load_files_single_dir(the_dir) + files = super(CacheDirectoryMixin, self)._load_files_single_dir(the_dir) files = list(files) self.cached_file_list[the_dir] = (stat, files) return files +#============================================================================= +class CacheDirectoryIndexSource(CacheDirectoryMixin, DirectoryIndexSource): + pass + + #============================================================================= class BaseRedisMultiKeyIndexSource(BaseAggregator, RedisIndexSource): def _iter_sources(self, params): diff --git a/pywb/warcserver/index/cdxobject.py b/pywb/warcserver/index/cdxobject.py index ed5b79c3..8f050cb1 100644 --- a/pywb/warcserver/index/cdxobject.py +++ b/pywb/warcserver/index/cdxobject.py @@ -36,8 +36,9 @@ ORIG_FILENAME = 'orig.filename' #================================================================= class CDXException(WbException): - def status(self): - return '400 Bad Request' + @property + def status_code(self): + return 400 #================================================================= @@ -132,7 +133,7 @@ class CDXObject(OrderedDict): v = quote(v.encode('utf-8'), safe=':/') if n != 'filename': - v = to_native_str(v, 'utf-8') + v = to_native_str(v, 'utf-8') or v self[n] = v diff --git a/pywb/warcserver/index/indexsource.py b/pywb/warcserver/index/indexsource.py index efd5e1c5..55038616 100644 --- a/pywb/warcserver/index/indexsource.py +++ b/pywb/warcserver/index/indexsource.py @@ -67,21 +67,26 @@ class BaseIndexSource(object): class FileIndexSource(BaseIndexSource): CDX_EXT = ('.cdx', '.cdxj') - def __init__(self, filename): + def __init__(self, filename, config=None): self.filename_template = filename + def _do_open(self, filename): + try: + return open(filename, 'rb') + except IOError: + raise NotFoundException(filename) + + def _get_gen(self, fh, params): + return iter_range(fh, params['key'], params['end_key']) + def load_index(self, params): filename = res_template(self.filename_template, params) - try: - fh = open(filename, 'rb') - except IOError: - raise NotFoundException(filename) + fh = self._do_open(filename) def do_load(fh): with fh: - gen = iter_range(fh, params['key'], params['end_key']) - for line in gen: + for line in self._get_gen(fh, params): yield CDXObject(line) return do_load(fh) diff --git a/pywb/warcserver/test/test_access.py b/pywb/warcserver/test/test_access.py new file mode 100644 index 00000000..db2bcab1 --- /dev/null +++ b/pywb/warcserver/test/test_access.py @@ -0,0 +1,116 @@ +from mock import patch +import shutil +import os + +from pywb.warcserver.index.aggregator import SimpleAggregator +from pywb.warcserver.access_checker import FileAccessIndexSource, AccessChecker, DirectoryAccessSource + +from pywb.warcserver.test.testutils import to_path, TempDirTests, BaseTestClass +from pywb import get_test_dir + +TEST_EXCL_PATH = to_path(get_test_dir() + '/access/') + + +# ============================================================================ +class TestAccess(TempDirTests, BaseTestClass): + def test_allows_only_default_block(self): + agg = SimpleAggregator({'source': FileAccessIndexSource(TEST_EXCL_PATH + 'allows.aclj')}) + access = AccessChecker(agg, default_access='block') + + edx = access.find_access_rule('http://example.net') + assert edx['urlkey'] == 'net,' + + edx = access.find_access_rule('http://foo.example.net/abc') + assert edx['urlkey'] == 'net,' + + edx = access.find_access_rule('https://example.net/test/') + assert edx['urlkey'] == 'net,example)/test' + + edx = access.find_access_rule('https://example.org/') + assert edx['urlkey'] == '' + assert edx['access'] == 'block' + + edx = access.find_access_rule('https://abc.domain.net/path') + assert edx['urlkey'] == 'net,domain,' + + edx = access.find_access_rule('https://domain.neta/path') + assert edx['urlkey'] == '' + assert edx['access'] == 'block' + + def test_blocks_only(self): + agg = SimpleAggregator({'source': FileAccessIndexSource(TEST_EXCL_PATH + 'blocks.aclj')}) + access = AccessChecker(agg) + + edx = access.find_access_rule('https://example.com/foo') + assert edx['urlkey'] == 'com,example)/foo' + assert edx['access'] == 'exclude' + + edx = access.find_access_rule('https://example.com/food') + assert edx['urlkey'] == 'com,example)/foo' + assert edx['access'] == 'exclude' + + edx = access.find_access_rule('https://example.com/foo/path') + assert edx['urlkey'] == 'com,example)/foo' + assert edx['access'] == 'exclude' + + edx = access.find_access_rule('https://example.net/abc/path/other') + assert edx['urlkey'] == 'net,example)/abc/path' + assert edx['access'] == 'block' + + edx = access.find_access_rule('https://example.net/fo') + assert edx['urlkey'] == '' + assert edx['access'] == 'allow' + + def test_single_file_combined(self): + agg = SimpleAggregator({'source': FileAccessIndexSource(TEST_EXCL_PATH + 'list1.aclj')}) + access = AccessChecker(agg, default_access='block') + + edx = access.find_access_rule('http://example.com/abc/page.html') + assert edx['urlkey'] == 'com,example)/abc/page.html' + assert edx['access'] == 'allow' + + edx = access.find_access_rule('http://example.com/abc/page.htm') + assert edx['urlkey'] == 'com,example)/abc' + assert edx['access'] == 'block' + + edx = access.find_access_rule('http://example.com/abc/') + assert edx['urlkey'] == 'com,example)/abc' + assert edx['access'] == 'block' + + edx = access.find_access_rule('http://foo.example.com/') + assert edx['urlkey'] == 'com,example,' + assert edx['access'] == 'exclude' + + edx = access.find_access_rule('http://example.com/') + assert edx['urlkey'] == 'com,' + assert edx['access'] == 'allow' + + edx = access.find_access_rule('foo.net') + assert edx['urlkey'] == '' + assert edx['access'] == 'block' + + edx = access.find_access_rule('https://example.net/abc/path/other') + assert edx['urlkey'] == '' + assert edx['access'] == 'block' + + def test_excludes_dir(self): + agg = DirectoryAccessSource(TEST_EXCL_PATH) + + access = AccessChecker(agg, default_access='block') + + edx = access.find_access_rule('http://example.com/') + assert edx['urlkey'] == 'com,example)/' + assert edx['access'] == 'allow' + + edx = access.find_access_rule('http://example.bo') + assert edx['urlkey'] == 'bo,example)/' + assert edx['access'] == 'exclude' + + edx = access.find_access_rule('https://example.com/foo/path') + assert edx['urlkey'] == 'com,example)/foo' + assert edx['access'] == 'exclude' + + edx = access.find_access_rule('https://example.net/abc/path/other') + assert edx['urlkey'] == 'net,example)/abc/path' + assert edx['access'] == 'block' + diff --git a/pywb/warcserver/test/test_warcserver_config.yaml b/pywb/warcserver/test/test_warcserver_config.yaml index c6b0dfb6..a46adfb7 100644 --- a/pywb/warcserver/test/test_warcserver_config.yaml +++ b/pywb/warcserver/test/test_warcserver_config.yaml @@ -1,3 +1,5 @@ +debug: true + collections: # Live Index diff --git a/pywb/warcserver/warcserver.py b/pywb/warcserver/warcserver.py index c3ed1ef8..3c346a53 100644 --- a/pywb/warcserver/warcserver.py +++ b/pywb/warcserver/warcserver.py @@ -14,6 +14,8 @@ from pywb.warcserver.index.indexsource import XmlQueryIndexSource from pywb.warcserver.index.zipnum import ZipNumIndexSource +from pywb.warcserver.access_checker import AccessChecker, CacheDirectoryAccessSource + from pywb import DEFAULT_CONFIG from six import iteritems, iterkeys, itervalues @@ -60,6 +62,9 @@ class WarcServer(BaseWarcServer): self.root_dir = self.config.get('collections_root', '') self.index_paths = self.init_paths('index_paths') self.archive_paths = self.init_paths('archive_paths', self.root_dir) + self.acl_paths = self.init_paths('acl_paths') + + self.default_access = self.config.get('default_access') self.rules_file = self.config.get('rules_file', '') @@ -103,8 +108,12 @@ class WarcServer(BaseWarcServer): base_dir=self.index_paths, config=self.config) + access_checker = AccessChecker(CacheDirectoryAccessSource(self.acl_paths), + self.default_access) + return DefaultResourceHandler(dir_source, self.archive_paths, - rules_file=self.rules_file) + rules_file=self.rules_file, + access_checker=access_checker) def list_fixed_routes(self): return list(self.fixed_routes.keys()) @@ -156,11 +165,15 @@ class WarcServer(BaseWarcServer): if isinstance(coll_config, str): index = coll_config archive_paths = None + acl_paths = None + default_access = self.default_access elif isinstance(coll_config, dict): index = coll_config.get('index') if not index: index = coll_config.get('index_paths') archive_paths = coll_config.get('archive_paths') + acl_paths = coll_config.get('acl_paths') + default_access = coll_config.get('default_access', self.default_access) else: raise Exception('collection config must be string or dict') @@ -186,8 +199,13 @@ class WarcServer(BaseWarcServer): if not archive_paths: archive_paths = self.config.get('archive_paths') + access_checker = None + if acl_paths: + access_checker = AccessChecker(acl_paths, default_access) + return DefaultResourceHandler(agg, archive_paths, - rules_file=self.rules_file) + rules_file=self.rules_file, + access_checker=access_checker) def init_sequence(self, coll_name, seq_config): if not isinstance(seq_config, list): diff --git a/sample_archive/access/allows.aclj b/sample_archive/access/allows.aclj new file mode 100644 index 00000000..aa8260fc --- /dev/null +++ b/sample_archive/access/allows.aclj @@ -0,0 +1,3 @@ +net,example)/test - {"access": "allow"} +net,domain, - {"access": "allow"} +net, - {"access": "allow"} diff --git a/sample_archive/access/blocks.aclj b/sample_archive/access/blocks.aclj new file mode 100644 index 00000000..41102212 --- /dev/null +++ b/sample_archive/access/blocks.aclj @@ -0,0 +1,3 @@ +net,example)/abc/path - {"access": "block"} +com,example)/foo - {"access": "exclude"} + diff --git a/sample_archive/access/list1.aclj b/sample_archive/access/list1.aclj new file mode 100644 index 00000000..fcfd17a7 --- /dev/null +++ b/sample_archive/access/list1.aclj @@ -0,0 +1,9 @@ +com,example, - {"access": "exclude"} +com,example)/abc/page.html - {"access": "allow"} +com,example)/abc/ef - {"access": "block"} +com,example)/abc/cd - {"access": "block"} +com,example)/abc/ab - {"access": "block"} +com,example)/abc - {"access": "block"} +com,exampke)/ - {"access": "allow"} +com,ex)/ - {"access": "exclude"} +com, - {"access": "allow"} diff --git a/sample_archive/access/list2.aclj b/sample_archive/access/list2.aclj new file mode 100644 index 00000000..249aa7bb --- /dev/null +++ b/sample_archive/access/list2.aclj @@ -0,0 +1,2 @@ +com,example)/ - {"access": "allow"} +bo,example)/ - {"access": "exclude"} diff --git a/sample_archive/access/pywb.aclj b/sample_archive/access/pywb.aclj new file mode 100644 index 00000000..b8895914 --- /dev/null +++ b/sample_archive/access/pywb.aclj @@ -0,0 +1,5 @@ +org,iana)/about - {"access": "block"} +org,iana)/_css/2013.1/fonts/opensans-semibold.ttf - {"access": "allow"} +org,iana)/_css - {"access": "exclude"} +org,example)/?example=1 - {"access": "block"} +org,iana)/ - {"access": "exclude"} diff --git a/tests/config_test_access.yaml b/tests/config_test_access.yaml new file mode 100644 index 00000000..ed6e6d04 --- /dev/null +++ b/tests/config_test_access.yaml @@ -0,0 +1,9 @@ +debug: true + +collections: + pywb: + index_paths: ./sample_archive/cdx/ + archive_paths: ./sample_archive/warcs/ + acl_paths: ./sample_archive/access/ + default_access: block + diff --git a/tests/test_acl.py b/tests/test_acl.py new file mode 100644 index 00000000..09d2d630 --- /dev/null +++ b/tests/test_acl.py @@ -0,0 +1,55 @@ +from .base_config_test import BaseConfigTest, fmod + +import webtest +import os + +from six.moves.urllib.parse import urlencode + + +# ============================================================================ +class TestACLApp(BaseConfigTest): + @classmethod + def setup_class(cls): + super(TestACLApp, cls).setup_class('config_test_access.yaml') + + def query(self, url, is_error=False, **params): + params['url'] = url + return self.testapp.get('/pywb/cdx?' + urlencode(params, doseq=1), expect_errors=is_error) + + def test_excluded_url(self): + resp = self.query('http://www.iana.org/') + + assert len(resp.text.splitlines()) == 0 + + self.testapp.get('/pywb/mp_/http://www.iana.org/', status=404) + + def test_blocked_url(self): + resp = self.query('http://www.iana.org/about/') + + assert len(resp.text.splitlines()) == 1 + + resp = self.testapp.get('/pywb/mp_/http://www.iana.org/about/', status=451) + + assert 'Access Blocked' in resp.text + + def test_allowed_more_specific(self): + resp = self.query('http://www.iana.org/_css/2013.1/fonts/opensans-semibold.ttf') + + assert resp.status_code == 200 + + assert len(resp.text.splitlines()) > 0 + + resp = self.testapp.get('/pywb/mp_/http://www.iana.org/_css/2013.1/fonts/opensans-semibold.ttf', status=200) + + assert resp.content_type == 'application/octet-stream' + + def test_default_rule_blocked(self): + resp = self.query('http://httpbin.org/anything/resource.json') + + assert len(resp.text.splitlines()) > 0 + + resp = self.testapp.get('/pywb/mp_/http://httpbin.org/anything/resource.json', status=451) + + assert 'Access Blocked' in resp.text + +