diff --git a/tests/test_warcprox.py b/tests/test_warcprox.py index 95532d4..0c250e7 100755 --- a/tests/test_warcprox.py +++ b/tests/test_warcprox.py @@ -96,6 +96,14 @@ logging.getLogger("requests.packages.urllib3").setLevel(logging.WARN) warnings.simplefilter("ignore", category=requests.packages.urllib3.exceptions.InsecureRequestWarning) warnings.simplefilter("ignore", category=requests.packages.urllib3.exceptions.InsecurePlatformWarning) +def wait(callback, timeout=10): + start = time.time() + while time.time() - start < timeout: + if callback(): + return + time.sleep(0.1) + raise Exception('timed out waiting for %s to return truthy' % callback) + # monkey patch dns lookup so we can test domain inheritance on localhost orig_getaddrinfo = socket.getaddrinfo orig_gethostbyname = socket.gethostbyname @@ -339,9 +347,6 @@ def warcprox_(request): logging.info('changing to working directory %r', work_dir) os.chdir(work_dir) - # we can't wait around all day in the tests - warcprox.BaseBatchPostfetchProcessor.MAX_BATCH_SEC = 0.5 - argv = ['warcprox', '--method-filter=GET', '--method-filter=POST', @@ -437,17 +442,9 @@ def test_httpds_no_proxy(http_daemon, https_daemon): assert response.headers['warcprox-test-header'] == 'c!' assert response.content == b'I am the warcprox test payload! dddddddddd!\n' -def _poll_playback_until(playback_proxies, url, status, timeout_sec): - start = time.time() - # check playback (warc writing is asynchronous, give it up to 10 sec) - while time.time() - start < timeout_sec: - response = requests.get(url, proxies=playback_proxies, verify=False) - if response.status_code == status: - break - time.sleep(0.5) - return response +def test_archive_and_playback_http_url(http_daemon, archiving_proxies, playback_proxies, warcprox_): + urls_before = warcprox_.proxy.running_stats.urls -def test_archive_and_playback_http_url(http_daemon, archiving_proxies, playback_proxies): url = 'http://localhost:{}/a/b'.format(http_daemon.server_port) # ensure playback fails before archiving @@ -461,12 +458,17 @@ def test_archive_and_playback_http_url(http_daemon, archiving_proxies, playback_ assert response.headers['warcprox-test-header'] == 'a!' assert response.content == b'I am the warcprox test payload! bbbbbbbbbb!\n' - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.headers['warcprox-test-header'] == 'a!' assert response.content == b'I am the warcprox test payload! bbbbbbbbbb!\n' -def test_archive_and_playback_https_url(https_daemon, archiving_proxies, playback_proxies): +def test_archive_and_playback_https_url(https_daemon, archiving_proxies, playback_proxies, warcprox_): + urls_before = warcprox_.proxy.running_stats.urls + url = 'https://localhost:{}/c/d'.format(https_daemon.server_port) # ensure playback fails before archiving @@ -480,14 +482,19 @@ def test_archive_and_playback_https_url(https_daemon, archiving_proxies, playbac assert response.headers['warcprox-test-header'] == 'c!' assert response.content == b'I am the warcprox test payload! 
dddddddddd!\n' + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + # test playback - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.headers['warcprox-test-header'] == 'c!' assert response.content == b'I am the warcprox test payload! dddddddddd!\n' # test dedup of same http url with same payload def test_dedup_http(http_daemon, warcprox_, archiving_proxies, playback_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url = 'http://localhost:{}/e/f'.format(http_daemon.server_port) # ensure playback fails before archiving @@ -506,18 +513,14 @@ def test_dedup_http(http_daemon, warcprox_, archiving_proxies, playback_proxies) assert response.headers['warcprox-test-header'] == 'e!' assert response.content == b'I am the warcprox test payload! ffffffffff!\n' + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) # test playback - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.headers['warcprox-test-header'] == 'e!' assert response.content == b'I am the warcprox test payload! ffffffffff!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) - # check in dedup db # {u'id': u'', u'url': u'https://localhost:62841/c/d', u'date': u'2013-11-22T00:14:37Z'} dedup_lookup = warcprox_.dedup_db.lookup( @@ -531,7 +534,7 @@ def test_dedup_http(http_daemon, warcprox_, archiving_proxies, playback_proxies) # need revisit to have a later timestamp than original, else playing # back the latest record might not hit the revisit - time.sleep(1.5) + time.sleep(1.1) # fetch & archive revisit response = requests.get(url, proxies=archiving_proxies, verify=False) @@ -539,11 +542,8 @@ def test_dedup_http(http_daemon, warcprox_, archiving_proxies, playback_proxies) assert response.headers['warcprox-test-header'] == 'e!' assert response.content == b'I am the warcprox test payload! ffffffffff!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) # check in dedup db (no change from prev) dedup_lookup = warcprox_.dedup_db.lookup( @@ -554,7 +554,7 @@ def test_dedup_http(http_daemon, warcprox_, archiving_proxies, playback_proxies) # test playback logging.debug('testing playback of revisit of {}'.format(url)) - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.headers['warcprox-test-header'] == 'e!' assert response.content == b'I am the warcprox test payload! 
ffffffffff!\n' @@ -562,6 +562,8 @@ def test_dedup_http(http_daemon, warcprox_, archiving_proxies, playback_proxies) # test dedup of same https url with same payload def test_dedup_https(https_daemon, warcprox_, archiving_proxies, playback_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url = 'https://localhost:{}/g/h'.format(https_daemon.server_port) # ensure playback fails before archiving @@ -580,18 +582,15 @@ def test_dedup_https(https_daemon, warcprox_, archiving_proxies, playback_proxie assert response.headers['warcprox-test-header'] == 'g!' assert response.content == b'I am the warcprox test payload! hhhhhhhhhh!\n' + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + # test playback - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.headers['warcprox-test-header'] == 'g!' assert response.content == b'I am the warcprox test payload! hhhhhhhhhh!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) - # check in dedup db # {u'id': u'', u'url': u'https://localhost:62841/c/d', u'date': u'2013-11-22T00:14:37Z'} dedup_lookup = warcprox_.dedup_db.lookup( @@ -605,7 +604,7 @@ def test_dedup_https(https_daemon, warcprox_, archiving_proxies, playback_proxie # need revisit to have a later timestamp than original, else playing # back the latest record might not hit the revisit - time.sleep(1.5) + time.sleep(1.1) # fetch & archive revisit response = requests.get(url, proxies=archiving_proxies, verify=False) @@ -613,11 +612,8 @@ def test_dedup_https(https_daemon, warcprox_, archiving_proxies, playback_proxie assert response.headers['warcprox-test-header'] == 'g!' assert response.content == b'I am the warcprox test payload! hhhhhhhhhh!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) # check in dedup db (no change from prev) dedup_lookup = warcprox_.dedup_db.lookup( @@ -628,13 +624,15 @@ def test_dedup_https(https_daemon, warcprox_, archiving_proxies, playback_proxie # test playback logging.debug('testing playback of revisit of {}'.format(url)) - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.headers['warcprox-test-header'] == 'g!' assert response.content == b'I am the warcprox test payload! hhhhhhhhhh!\n' # XXX how to check dedup was used? def test_limits(http_daemon, warcprox_, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url = 'http://localhost:{}/i/j'.format(http_daemon.server_port) request_meta = {"stats":{"buckets":["test_limits_bucket"]},"limits":{"test_limits_bucket/total/urls":10}} headers = {"Warcprox-Meta": json.dumps(request_meta)} @@ -644,11 +642,8 @@ def test_limits(http_daemon, warcprox_, archiving_proxies): assert response.headers['warcprox-test-header'] == 'i!' assert response.content == b'I am the warcprox test payload! 
jjjjjjjjjj!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) for i in range(9): response = requests.get(url, proxies=archiving_proxies, headers=headers, stream=True) @@ -656,11 +651,8 @@ def test_limits(http_daemon, warcprox_, archiving_proxies): assert response.headers['warcprox-test-header'] == 'i!' assert response.content == b'I am the warcprox test payload! jjjjjjjjjj!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(2.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 10) response = requests.get(url, proxies=archiving_proxies, headers=headers, stream=True) assert response.status_code == 420 @@ -671,6 +663,8 @@ def test_limits(http_daemon, warcprox_, archiving_proxies): assert response.raw.data == b"request rejected by warcprox: reached limit test_limits_bucket/total/urls=10\n" def test_return_capture_timestamp(http_daemon, warcprox_, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url = 'http://localhost:{}/i/j'.format(http_daemon.server_port) request_meta = {"accept": ["capture-metadata"]} headers = {"Warcprox-Meta": json.dumps(request_meta)} @@ -686,7 +680,12 @@ def test_return_capture_timestamp(http_daemon, warcprox_, archiving_proxies): except ValueError: pytest.fail('Invalid capture-timestamp format %s', data['capture-timestamp']) + # wait for postfetch chain (or subsequent test could fail) + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies, playback_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url1 = 'http://localhost:{}/k/l'.format(http_daemon.server_port) url2 = 'https://localhost:{}/k/l'.format(https_daemon.server_port) @@ -697,15 +696,14 @@ def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies, assert response.headers['warcprox-test-header'] == 'k!' assert response.content == b'I am the warcprox test payload! llllllllll!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) # check url1 in dedup db bucket_a + # logging.info('looking up sha1:bc3fac8847c9412f49d955e626fb58a76befbf81 in bucket_a') dedup_lookup = warcprox_.dedup_db.lookup( b'sha1:bc3fac8847c9412f49d955e626fb58a76befbf81', bucket="bucket_a") + assert dedup_lookup assert dedup_lookup['url'] == url1.encode('ascii') assert re.match(br'^$', dedup_lookup['id']) assert re.match(br'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z$', dedup_lookup['date']) @@ -724,11 +722,8 @@ def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies, assert response.headers['warcprox-test-header'] == 'k!' assert response.content == b'I am the warcprox test payload! 
llllllllll!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) # check url2 in dedup db bucket_b dedup_lookup = warcprox_.dedup_db.lookup( @@ -746,11 +741,8 @@ def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies, assert response.headers['warcprox-test-header'] == 'k!' assert response.content == b'I am the warcprox test payload! llllllllll!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 3) # archive url1 bucket_b headers = {"Warcprox-Meta": json.dumps({"warc-prefix":"test_dedup_buckets","captures-bucket":"bucket_b"})} @@ -759,11 +751,8 @@ def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies, assert response.headers['warcprox-test-header'] == 'k!' assert response.content == b'I am the warcprox test payload! llllllllll!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 4) # close the warc assert warcprox_.warc_writer_thread.writer_pool.warc_writers["test_dedup_buckets"] @@ -827,6 +816,8 @@ def test_dedup_buckets(https_daemon, http_daemon, warcprox_, archiving_proxies, fh.close() def test_block_rules(http_daemon, https_daemon, warcprox_, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + rules = [ { "domain": "localhost", @@ -863,6 +854,9 @@ def test_block_rules(http_daemon, https_daemon, warcprox_, archiving_proxies): url, proxies=archiving_proxies, headers=headers, stream=True) assert response.status_code == 200 + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + # blocked by SURT_MATCH url = 'http://localhost:{}/fuh/guh'.format(http_daemon.server_port) response = requests.get( @@ -878,6 +872,9 @@ def test_block_rules(http_daemon, https_daemon, warcprox_, archiving_proxies): # 404 because server set up at the top of this file doesn't handle this url assert response.status_code == 404 + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) + # not blocked because surt scheme does not match (differs from heritrix # behavior where https urls are coerced to http surt form) url = 'https://localhost:{}/fuh/guh'.format(https_daemon.server_port) @@ -886,6 +883,9 @@ def test_block_rules(http_daemon, https_daemon, warcprox_, archiving_proxies): verify=False) assert response.status_code == 200 + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 3) + # blocked by blanket domain block url = 'http://bad.domain.com/' response = requests.get( @@ -938,6 +938,8 @@ def test_block_rules(http_daemon, https_daemon, warcprox_, archiving_proxies): def test_domain_doc_soft_limit( http_daemon, https_daemon, warcprox_, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + request_meta = { "stats": {"buckets": [{"bucket":"test_domain_doc_limit_bucket","tally-domains":["foo.localhost"]}]}, "soft-limits": {"test_domain_doc_limit_bucket:foo.localhost/total/urls":10}, @@ -952,11 +954,8 @@ def test_domain_doc_soft_limit( assert 
response.headers['warcprox-test-header'] == 'o!' assert response.content == b'I am the warcprox test payload! pppppppppp!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) # make sure stats from different domain don't count url = 'http://bar.localhost:{}/o/p'.format(http_daemon.server_port) @@ -967,15 +966,10 @@ def test_domain_doc_soft_limit( assert response.headers['warcprox-test-header'] == 'o!' assert response.content == b'I am the warcprox test payload! pppppppppp!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - # rethinkdb stats db update cycle is 2 seconds (at the moment anyway) - time.sleep(2.0) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 11) # (2) same host but different scheme and port: domain limit applies - # url = 'https://foo.localhost:{}/o/p'.format(https_daemon.server_port) response = requests.get( url, proxies=archiving_proxies, headers=headers, stream=True, @@ -994,12 +988,12 @@ def test_domain_doc_soft_limit( assert response.headers['warcprox-test-header'] == 'o!' assert response.content == b'I am the warcprox test payload! pppppppppp!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - # rethinkdb stats db update cycle is 2 seconds (at the moment anyway) - time.sleep(2.0) + # wait for postfetch chain + time.sleep(3) + logging.info( + 'warcprox_.proxy.running_stats.urls - urls_before = %s', + warcprox_.proxy.running_stats.urls - urls_before) + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 19) # (10) response = requests.get( @@ -1009,12 +1003,8 @@ def test_domain_doc_soft_limit( assert response.headers['warcprox-test-header'] == 'o!' assert response.content == b'I am the warcprox test payload! pppppppppp!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - # rethinkdb stats db update cycle is 2 seconds (at the moment anyway) - time.sleep(2.0) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 20) # (11) back to http, and this is the 11th request url = 'http://zuh.foo.localhost:{}/o/p'.format(http_daemon.server_port) @@ -1036,6 +1026,9 @@ def test_domain_doc_soft_limit( assert response.headers['warcprox-test-header'] == 'o!' assert response.content == b'I am the warcprox test payload! pppppppppp!\n' + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 21) + # https also blocked url = 'https://zuh.foo.localhost:{}/o/p'.format(https_daemon.server_port) response = requests.get( @@ -1062,6 +1055,8 @@ def test_domain_doc_soft_limit( def test_domain_data_soft_limit( http_daemon, https_daemon, warcprox_, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + # using idn request_meta = { "stats": {"buckets": [{"bucket":"test_domain_data_limit_bucket","tally-domains":['ÞzZ.LOCALhost']}]}, @@ -1077,12 +1072,8 @@ def test_domain_data_soft_limit( assert response.headers['warcprox-test-header'] == 'y!' assert response.content == b'I am the warcprox test payload! 
zzzzzzzzzz!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - # rethinkdb stats db update cycle is 2 seconds (at the moment anyway) - time.sleep(2.0) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) # duplicate, does not count toward limit url = 'https://baz.Þzz.localhost:{}/y/z'.format(https_daemon.server_port) @@ -1093,12 +1084,8 @@ def test_domain_data_soft_limit( assert response.headers['warcprox-test-header'] == 'y!' assert response.content == b'I am the warcprox test payload! zzzzzzzzzz!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - # rethinkdb stats db update cycle is 2 seconds (at the moment anyway) - time.sleep(2.0) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) # novel, pushes stats over the limit url = 'https://muh.XN--Zz-2Ka.locALHOst:{}/z/~'.format(https_daemon.server_port) @@ -1109,12 +1096,8 @@ def test_domain_data_soft_limit( assert response.headers['warcprox-test-header'] == 'z!' assert response.content == b'I am the warcprox test payload! ~~~~~~~~~~!\n' - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - # rethinkdb stats db update cycle is 2 seconds (at the moment anyway) - time.sleep(2.0) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 3) # make sure limit doesn't get applied to a different host url = 'http://baz.localhost:{}/z/~'.format(http_daemon.server_port) @@ -1124,6 +1107,9 @@ def test_domain_data_soft_limit( assert response.headers['warcprox-test-header'] == 'z!' assert response.content == b'I am the warcprox test payload! 
~~~~~~~~~~!\n' + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 4) + # blocked because we're over the limit now url = 'http://lOl.wHut.ÞZZ.lOcALHOst:{}/y/z'.format(http_daemon.server_port) response = requests.get( @@ -1155,7 +1141,9 @@ def test_domain_data_soft_limit( # connection to the internet, and relies on a third party site (facebook) being # up and behaving a certain way @pytest.mark.xfail -def test_tor_onion(archiving_proxies): +def test_tor_onion(archiving_proxies, warcprox_): + urls_before = warcprox_.proxy.running_stats.urls + response = requests.get('http://www.facebookcorewwwi.onion/', proxies=archiving_proxies, verify=False, allow_redirects=False) assert response.status_code == 302 @@ -1164,7 +1152,12 @@ def test_tor_onion(archiving_proxies): proxies=archiving_proxies, verify=False, allow_redirects=False) assert response.status_code == 200 -def test_missing_content_length(archiving_proxies, http_daemon, https_daemon): + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) + +def test_missing_content_length(archiving_proxies, http_daemon, https_daemon, warcprox_): + urls_before = warcprox_.proxy.running_stats.urls + # double-check that our test http server is responding as expected url = 'http://localhost:%s/missing-content-length' % ( http_daemon.server_port) @@ -1201,8 +1194,14 @@ def test_missing_content_length(archiving_proxies, http_daemon, https_daemon): b'This response is missing a Content-Length http header.') assert not 'content-length' in response.headers + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) + def test_method_filter( - https_daemon, http_daemon, archiving_proxies, playback_proxies): + warcprox_, https_daemon, http_daemon, archiving_proxies, + playback_proxies): + urls_before = warcprox_.proxy.running_stats.urls + # we've configured warcprox with method_filters=['GET','POST'] so HEAD # requests should not be archived @@ -1213,7 +1212,10 @@ def test_method_filter( assert response.headers['warcprox-test-header'] == 'z!' assert response.content == b'' - response = _poll_playback_until(playback_proxies, url, status=200, timeout_sec=10) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 404 assert response.content == b'404 Not in Archive\n' @@ -1230,13 +1232,17 @@ def test_method_filter( headers=headers, proxies=archiving_proxies) assert response.status_code == 204 - response = _poll_playback_until( - playback_proxies, url, status=200, timeout_sec=10) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) + + response = requests.get(url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert response.content == payload def test_dedup_ok_flag( https_daemon, http_daemon, warcprox_, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + if not warcprox_.options.rethinkdb_big_table: # this feature is n/a unless using rethinkdb big table return @@ -1258,10 +1264,8 @@ def test_dedup_ok_flag( assert response.headers['warcprox-test-header'] == 'z!' assert response.content == b'I am the warcprox test payload! 
bbbbbbbbbb!\n' - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) # check that dedup db doesn't give us anything for this dedup_lookup = warcprox_.dedup_db.lookup( @@ -1279,10 +1283,8 @@ def test_dedup_ok_flag( assert response.headers['warcprox-test-header'] == 'z!' assert response.content == b'I am the warcprox test payload! bbbbbbbbbb!\n' - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 2) # check that dedup db gives us something for this dedup_lookup = warcprox_.dedup_db.lookup( @@ -1316,7 +1318,8 @@ def test_status_api(warcprox_): 'role', 'version', 'host', 'address', 'port', 'pid', 'load', 'queued_urls', 'queue_max_size', 'seconds_behind', 'threads', 'rates_5min', 'rates_1min', 'unaccepted_requests', 'rates_15min', - 'active_requests',} + 'active_requests','start_time','urls_processed', + 'warc_bytes_written'} assert status['role'] == 'warcprox' assert status['version'] == warcprox.__version__ assert status['port'] == warcprox_.proxy.server_port @@ -1337,7 +1340,8 @@ def test_svcreg_status(warcprox_): 'queued_urls', 'queue_max_size', 'seconds_behind', 'first_heartbeat', 'ttl', 'last_heartbeat', 'threads', 'rates_5min', 'rates_1min', 'unaccepted_requests', - 'rates_15min', 'active_requests',} + 'rates_15min', 'active_requests','start_time','urls_processed', + 'warc_bytes_written',} assert status['role'] == 'warcprox' assert status['version'] == warcprox.__version__ assert status['port'] == warcprox_.proxy.server_port @@ -1426,12 +1430,17 @@ def test_choose_a_port_for_me(warcprox_): th.join() def test_via_response_header(warcprox_, http_daemon, archiving_proxies, playback_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url = 'http://localhost:%s/a/z' % http_daemon.server_port response = requests.get(url, proxies=archiving_proxies) assert response.headers['via'] == '1.1 warcprox' - playback_response = _poll_playback_until( - playback_proxies, url, status=200, timeout_sec=10) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) + + playback_response = requests.get( + url, proxies=playback_proxies, verify=False) assert response.status_code == 200 assert not 'via' in playback_response @@ -1458,15 +1467,19 @@ def test_slash_in_warc_prefix(warcprox_, http_daemon, archiving_proxies): assert response.reason == 'request rejected by warcprox: slash and backslash are not permitted in warc-prefix' def test_crawl_log(warcprox_, http_daemon, archiving_proxies): + urls_before = warcprox_.proxy.running_stats.urls + try: os.unlink(os.path.join(warcprox_.options.crawl_log_dir, 'crawl.log')) except: pass + # should go to default crawl log url = 'http://localhost:%s/b/aa' % http_daemon.server_port response = requests.get(url, proxies=archiving_proxies) assert response.status_code == 200 + # should go to test_crawl_log_1.log url = 'http://localhost:%s/b/bb' % http_daemon.server_port headers = { "Warcprox-Meta": json.dumps({"warc-prefix":"test_crawl_log_1"}), @@ -1475,13 +1488,12 @@ def test_crawl_log(warcprox_, http_daemon, archiving_proxies): response = requests.get(url, proxies=archiving_proxies, headers=headers) assert response.status_code == 200 - start = time.time() + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 
2) + file = os.path.join(warcprox_.options.crawl_log_dir, 'test_crawl_log_1.log') - while time.time() - start < 10: - if os.path.exists(file) and os.stat(file).st_size > 0: - break - time.sleep(0.5) assert os.path.exists(file) + assert os.stat(file).st_size > 0 assert os.path.exists(os.path.join( warcprox_.options.crawl_log_dir, 'crawl.log')) @@ -1536,13 +1548,12 @@ def test_crawl_log(warcprox_, http_daemon, archiving_proxies): response = requests.get(url, proxies=archiving_proxies, headers=headers) assert response.status_code == 200 - start = time.time() + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 3) + file = os.path.join(warcprox_.options.crawl_log_dir, 'test_crawl_log_2.log') - while time.time() - start < 10: - if os.path.exists(file) and os.stat(file).st_size > 0: - break - time.sleep(0.5) assert os.path.exists(file) + assert os.stat(file).st_size > 0 crawl_log_2 = open(file, 'rb').read() @@ -1566,17 +1577,14 @@ def test_crawl_log(warcprox_, http_daemon, archiving_proxies): assert extra_info['contentSize'] == 145 # a request that is not saved to a warc (because of --method-filter) - # currently not logged at all (XXX maybe it should be) url = 'http://localhost:%s/b/cc' % http_daemon.server_port headers = {'Warcprox-Meta': json.dumps({'warc-prefix': 'test_crawl_log_3'})} response = requests.head(url, proxies=archiving_proxies, headers=headers) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 4) + file = os.path.join(warcprox_.options.crawl_log_dir, 'test_crawl_log_3.log') - start = time.time() - while time.time() - start < 10: - if os.path.exists(file) and os.stat(file).st_size > 0: - break - time.sleep(0.5) assert os.path.exists(file) crawl_log_3 = open(file, 'rb').read() @@ -1611,13 +1619,10 @@ def test_crawl_log(warcprox_, http_daemon, archiving_proxies): headers=headers, proxies=archiving_proxies) assert response.status_code == 204 - start = time.time() - file = os.path.join(warcprox_.options.crawl_log_dir, 'test_crawl_log_4.log') - while time.time() - start < 10: - if os.path.exists(file) and os.stat(file).st_size > 0: - break - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 5) + file = os.path.join(warcprox_.options.crawl_log_dir, 'test_crawl_log_4.log') assert os.path.exists(file) crawl_log_4 = open(file, 'rb').read() @@ -1642,6 +1647,8 @@ def test_crawl_log(warcprox_, http_daemon, archiving_proxies): def test_long_warcprox_meta( warcprox_, http_daemon, archiving_proxies, playback_proxies): + urls_before = warcprox_.proxy.running_stats.urls + url = 'http://localhost:%s/b/g' % http_daemon.server_port # create a very long warcprox-meta header @@ -1651,11 +1658,8 @@ def test_long_warcprox_meta( url, proxies=archiving_proxies, headers=headers, verify=False) assert response.status_code == 200 - # wait for writer thread to process - time.sleep(0.5) - while warcprox_.postfetch_chain_busy(): - time.sleep(0.5) - time.sleep(0.5) + # wait for postfetch chain + wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1) # check that warcprox-meta was parsed and honored ("warc-prefix" param) assert warcprox_.warc_writer_thread.writer_pool.warc_writers["test_long_warcprox_meta"] @@ -1681,7 +1685,6 @@ def test_long_warcprox_meta( def test_empty_response( warcprox_, http_daemon, https_daemon, archiving_proxies, playback_proxies): - url = 'http://localhost:%s/empty-response' % http_daemon.server_port response = 
requests.get(url, proxies=archiving_proxies, verify=False) assert response.status_code == 502 diff --git a/warcprox/__init__.py b/warcprox/__init__.py index bc1365c..2cd62cd 100644 --- a/warcprox/__init__.py +++ b/warcprox/__init__.py @@ -99,12 +99,13 @@ class RequestBlockedByRule(Exception): class BasePostfetchProcessor(threading.Thread): logger = logging.getLogger("warcprox.BasePostfetchProcessor") - def __init__(self, inq, outq, options=Options()): + def __init__(self, options=Options()): threading.Thread.__init__(self, name=self.__class__.__name__) - self.inq = inq - self.outq = outq self.options = options self.stop = threading.Event() + # these should be set before thread is started + self.inq = None + self.outq = None def run(self): if self.options.profile: @@ -128,6 +129,7 @@ class BasePostfetchProcessor(threading.Thread): raise Exception('not implemented') def _run(self): + self._startup() while not self.stop.is_set(): try: while True: @@ -152,6 +154,9 @@ class BasePostfetchProcessor(threading.Thread): self.name, exc_info=True) time.sleep(0.5) + def _startup(self): + pass + def _shutdown(self): pass @@ -175,6 +180,13 @@ class BaseBatchPostfetchProcessor(BasePostfetchProcessor): start = time.time() while True: + try: + batch.append(self.inq.get(block=True, timeout=0.5)) + except queue.Empty: + if self.stop.is_set(): + break + # else maybe keep adding to the batch + if len(batch) >= self.MAX_BATCH_SIZE: break # full batch @@ -186,18 +198,11 @@ class BaseBatchPostfetchProcessor(BasePostfetchProcessor): and len(self.outq.queue) == 0): break # next processor is waiting on us - try: - batch.append(self.inq.get(block=True, timeout=0.5)) - except queue.Empty: - if self.stop.is_set(): - break - # else keep adding to the batch - if not batch: raise queue.Empty self.logger.info( - 'gathered batch of %s in %0.1f sec', + 'gathered batch of %s in %0.2f sec', len(batch), time.time() - start) self._process_batch(batch) @@ -209,8 +214,8 @@ class BaseBatchPostfetchProcessor(BasePostfetchProcessor): raise Exception('not implemented') class ListenerPostfetchProcessor(BaseStandardPostfetchProcessor): - def __init__(self, listener, inq, outq, options=Options()): - BaseStandardPostfetchProcessor.__init__(self, inq, outq, options) + def __init__(self, listener, options=Options()): + BaseStandardPostfetchProcessor.__init__(self, options) self.listener = listener self.name = listener.__class__.__name__ diff --git a/warcprox/controller.py b/warcprox/controller.py index dabbf6a..9902cb5 100644 --- a/warcprox/controller.py +++ b/warcprox/controller.py @@ -55,20 +55,20 @@ class Factory: return dedup_db @staticmethod - def stats_db(options): + def stats_processor(options): + # return warcprox.stats.StatsProcessor(options) if options.rethinkdb_stats_url: - stats_db = warcprox.stats.RethinkStatsDb(options=options) + stats_processor = warcprox.stats.RethinkStatsProcessor(options) elif options.stats_db_file in (None, '', '/dev/null'): logging.info('statistics tracking disabled') - stats_db = None + stats_processor = None else: - stats_db = warcprox.stats.StatsDb( - options.stats_db_file, options=options) - return stats_db + stats_processor = warcprox.stats.StatsProcessor(options) + return stats_processor @staticmethod - def warc_writer(inq, outq, options): - return warcprox.writerthread.WarcWriterThread(inq, outq, options) + def warc_writer(options): + return warcprox.writerthread.WarcWriterThread(options) @staticmethod def playback_proxy(ca, options): @@ -130,9 +130,9 @@ class WarcproxController(object): self.stop 
= threading.Event() self._start_stop_lock = threading.Lock() - self.stats_db = Factory.stats_db(self.options) + self.stats_processor = Factory.stats_processor(self.options) - self.proxy = warcprox.warcproxy.WarcProxy(self.stats_db, options) + self.proxy = warcprox.warcproxy.WarcProxy(self.stats_processor, options) self.playback_proxy = Factory.playback_proxy( self.proxy.ca, self.options) @@ -140,59 +140,52 @@ class WarcproxController(object): self.service_registry = Factory.service_registry(options) - def postfetch_chain_busy(self): - for processor in self._postfetch_chain: - if processor.inq.qsize() > 0: - return True - return False + def chain(self, processor0, processor1): + assert not processor0.outq + assert not processor1.inq + q = warcprox.TimestampedQueue(maxsize=self.options.queue_size) + processor0.outq = q + processor1.inq = q def build_postfetch_chain(self, inq): - constructors = [] + self._postfetch_chain = [] self.dedup_db = Factory.dedup_db(self.options) if self.dedup_db: - constructors.append(self.dedup_db.loader) + self._postfetch_chain.append(self.dedup_db.loader()) - constructors.append(Factory.warc_writer) + self.warc_writer_thread = Factory.warc_writer(self.options) + self._postfetch_chain.append(self.warc_writer_thread) if self.dedup_db: - constructors.append(self.dedup_db.storer) + self._postfetch_chain.append(self.dedup_db.storer()) - if self.stats_db: - constructors.append(functools.partial( - warcprox.ListenerPostfetchProcessor, self.stats_db)) + if self.stats_processor: + self._postfetch_chain.append(self.stats_processor) if self.playback_proxy: - constructors.append(functools.partial( - warcprox.ListenerPostfetchProcessor, - self.playback_proxy.playback_index_db)) + self._postfetch_chain.append( + warcprox.ListenerPostfetchProcessor( + self.playback_proxy.playback_index_db)) crawl_logger = Factory.crawl_logger(self.options) if crawl_logger: - constructors.append(functools.partial( - warcprox.ListenerPostfetchProcessor, crawl_logger)) + self._postfetch_chain.append( + warcprox.ListenerPostfetchProcessor(crawl_logger)) - constructors.append(functools.partial( - warcprox.ListenerPostfetchProcessor, self.proxy.running_stats)) + self._postfetch_chain.append( + warcprox.ListenerPostfetchProcessor(self.proxy.running_stats)) for qualname in self.options.plugins or []: plugin = Factory.plugin(qualname) - constructors.append(functools.partial( - warcprox.ListenerPostfetchProcessor, plugin)) + self._postfetch_chain.append( + warcprox.ListenerPostfetchProcessor(plugin)) - self._postfetch_chain = [] - for i, constructor in enumerate(constructors): - if i != len(constructors) - 1: - outq = warcprox.TimestampedQueue( - maxsize=self.options.queue_size) - else: - outq = None - processor = constructor(inq, outq, self.options) - if isinstance(processor, warcprox.writerthread.WarcWriterThread): - self.warc_writer_thread = processor # ugly - self._postfetch_chain.append(processor) - inq = outq + # chain them all up + self._postfetch_chain[0].inq = inq + for i in range(1, len(self._postfetch_chain)): + self.chain(self._postfetch_chain[i-1], self._postfetch_chain[i]) def debug_mem(self): self.logger.info("self.proxy.recorded_url_q.qsize()=%s", self.proxy.recorded_url_q.qsize()) @@ -314,9 +307,6 @@ class WarcproxController(object): for processor in self._postfetch_chain: processor.join() - if self.stats_db: - self.stats_db.stop() - self.proxy_thread.join() if self.playback_proxy is not None: self.playback_proxy_thread.join() diff --git a/warcprox/dedup.py b/warcprox/dedup.py index 
8931db2..0b52ffb 100644 --- a/warcprox/dedup.py +++ b/warcprox/dedup.py @@ -37,10 +37,10 @@ import collections urllib3.disable_warnings() class DedupLoader(warcprox.BaseStandardPostfetchProcessor): - def __init__(self, dedup_db, inq, outq, options=warcprox.Options()): - warcprox.BaseStandardPostfetchProcessor.__init__( - self, inq, outq, options) + def __init__(self, dedup_db, options=warcprox.Options()): + warcprox.BaseStandardPostfetchProcessor.__init__(self, options=options) self.dedup_db = dedup_db + def _process_url(self, recorded_url): decorate_with_dedup_info( self.dedup_db, recorded_url, self.options.base32) @@ -71,12 +71,11 @@ class DedupDb(object): conn.commit() conn.close() - def loader(self, inq, outq, *args, **kwargs): - return DedupLoader(self, inq, outq, self.options) + def loader(self, *args, **kwargs): + return DedupLoader(self, self.options) - def storer(self, inq, outq, *args, **kwargs): - return warcprox.ListenerPostfetchProcessor( - self, inq, outq, self.options) + def storer(self, *args, **kwargs): + return warcprox.ListenerPostfetchProcessor(self, self.options) def save(self, digest_key, response_record, bucket=""): record_id = response_record.get_header(warctools.WarcRecord.ID).decode('latin1') @@ -262,8 +261,8 @@ class CdxServerDedup(DedupDb): pass class BatchTroughLoader(warcprox.BaseBatchPostfetchProcessor): - def __init__(self, trough_dedup_db, inq, outq, options=warcprox.Options()): - warcprox.BaseBatchPostfetchProcessor.__init__(self, inq, outq, options) + def __init__(self, trough_dedup_db, options=warcprox.Options()): + warcprox.BaseBatchPostfetchProcessor.__init__(self, options) self.trough_dedup_db = trough_dedup_db def _filter_and_bucketize(self, batch): @@ -324,8 +323,8 @@ class TroughDedupDb(DedupDb): self._trough_cli = warcprox.trough.TroughClient( options.rethinkdb_trough_db_url, promotion_interval=60*60) - def loader(self, inq, outq, options=warcprox.Options()): - return BatchTroughLoader(self, inq, outq, options) + def loader(self, options=warcprox.Options()): + return BatchTroughLoader(self, options) def start(self): self._trough_cli.register_schema(self.SCHEMA_ID, self.SCHEMA_SQL) diff --git a/warcprox/stats.py b/warcprox/stats.py index 6047443..db2493c 100644 --- a/warcprox/stats.py +++ b/warcprox/stats.py @@ -53,45 +53,88 @@ def _empty_bucket(bucket): }, } -class StatsDb: - logger = logging.getLogger("warcprox.stats.StatsDb") +class StatsProcessor(warcprox.BaseBatchPostfetchProcessor): + logger = logging.getLogger("warcprox.stats.StatsProcessor") - def __init__(self, file='./warcprox.sqlite', options=warcprox.Options()): - self.file = file - self.options = options - self._lock = threading.RLock() + def _startup(self): + if os.path.exists(self.options.stats_db_file): + self.logger.info( + 'opening existing stats database %s', + self.options.stats_db_file) + else: + self.logger.info( + 'creating new stats database %s', + self.options.stats_db_file) - def start(self): - with self._lock: - if os.path.exists(self.file): - self.logger.info( - 'opening existing stats database %s', self.file) - else: - self.logger.info( - 'creating new stats database %s', self.file) + conn = sqlite3.connect(self.options.stats_db_file) + conn.execute( + 'create table if not exists buckets_of_stats (' + ' bucket varchar(300) primary key,' + ' stats varchar(4000)' + ');') + conn.commit() + conn.close() - conn = sqlite3.connect(self.file) + self.logger.info( + 'created table buckets_of_stats in %s', + self.options.stats_db_file) + + def _process_batch(self, batch): + 
batch_buckets = self._tally_batch(batch) + self._update_db(batch_buckets) + logging.trace('updated stats from batch of %s', len(batch)) + + def _update_db(self, batch_buckets): + conn = sqlite3.connect(self.options.stats_db_file) + for bucket in batch_buckets: + bucket_stats = batch_buckets[bucket] + + cursor = conn.execute( + 'select stats from buckets_of_stats where bucket=?', + (bucket,)) + result_tuple = cursor.fetchone() + cursor.close() + + if result_tuple: + old_bucket_stats = json.loads(result_tuple[0]) + + bucket_stats['total']['urls'] += old_bucket_stats['total']['urls'] + bucket_stats['total']['wire_bytes'] += old_bucket_stats['total']['wire_bytes'] + bucket_stats['revisit']['urls'] += old_bucket_stats['revisit']['urls'] + bucket_stats['revisit']['wire_bytes'] += old_bucket_stats['revisit']['wire_bytes'] + bucket_stats['new']['urls'] += old_bucket_stats['new']['urls'] + bucket_stats['new']['wire_bytes'] += old_bucket_stats['new']['wire_bytes'] + + json_value = json.dumps(bucket_stats, separators=(',',':')) conn.execute( - 'create table if not exists buckets_of_stats (' - ' bucket varchar(300) primary key,' - ' stats varchar(4000)' - ');') + 'insert or replace into buckets_of_stats ' + '(bucket, stats) values (?, ?)', (bucket, json_value)) conn.commit() - conn.close() + conn.close() - self.logger.info('created table buckets_of_stats in %s', self.file) + def _tally_batch(self, batch): + batch_buckets = {} + for recorded_url in batch: + for bucket in self.buckets(recorded_url): + bucket_stats = batch_buckets.get(bucket) + if not bucket_stats: + bucket_stats = _empty_bucket(bucket) + batch_buckets[bucket] = bucket_stats - def stop(self): - pass + bucket_stats["total"]["urls"] += 1 + bucket_stats["total"]["wire_bytes"] += recorded_url.size - def close(self): - pass - - def sync(self): - pass + if recorded_url.warc_records: + if recorded_url.warc_records[0].type == b'revisit': + bucket_stats["revisit"]["urls"] += 1 + bucket_stats["revisit"]["wire_bytes"] += recorded_url.size + else: + bucket_stats["new"]["urls"] += 1 + bucket_stats["new"]["wire_bytes"] += recorded_url.size + return batch_buckets def value(self, bucket0="__all__", bucket1=None, bucket2=None): - conn = sqlite3.connect(self.file) + conn = sqlite3.connect(self.options.stats_db_file) cursor = conn.execute( 'select stats from buckets_of_stats where bucket = ?', (bucket0,)) @@ -109,9 +152,6 @@ class StatsDb: else: return None - def notify(self, recorded_url, records): - self.tally(recorded_url, records) - def buckets(self, recorded_url): ''' Unravels bucket definitions in Warcprox-Meta header. 
Each bucket @@ -154,117 +194,20 @@ class StatsDb: return buckets - def tally(self, recorded_url, records): - with self._lock: - conn = sqlite3.connect(self.file) - - for bucket in self.buckets(recorded_url): - cursor = conn.execute( - 'select stats from buckets_of_stats where bucket=?', - (bucket,)) - - result_tuple = cursor.fetchone() - cursor.close() - if result_tuple: - bucket_stats = json.loads(result_tuple[0]) - else: - bucket_stats = _empty_bucket(bucket) - - bucket_stats["total"]["urls"] += 1 - bucket_stats["total"]["wire_bytes"] += recorded_url.size - - if records: - if records[0].type == b'revisit': - bucket_stats["revisit"]["urls"] += 1 - bucket_stats["revisit"]["wire_bytes"] += recorded_url.size - else: - bucket_stats["new"]["urls"] += 1 - bucket_stats["new"]["wire_bytes"] += recorded_url.size - - json_value = json.dumps(bucket_stats, separators=(',',':')) - conn.execute( - 'insert or replace into buckets_of_stats ' - '(bucket, stats) values (?, ?)', (bucket, json_value)) - conn.commit() - - conn.close() - -class RethinkStatsDb(StatsDb): - """Updates database in batch every 2.0 seconds""" - logger = logging.getLogger("warcprox.stats.RethinkStatsDb") +class RethinkStatsProcessor(StatsProcessor): + logger = logging.getLogger("warcprox.stats.RethinkStatsProcessor") def __init__(self, options=warcprox.Options()): + StatsProcessor.__init__(self, options) + parsed = doublethink.parse_rethinkdb_url(options.rethinkdb_stats_url) self.rr = doublethink.Rethinker( servers=parsed.hosts, db=parsed.database) self.table = parsed.table self.replicas = min(3, len(self.rr.servers)) + + def _startup(self): self._ensure_db_table() - self.options = options - - self._stop = threading.Event() - self._batch_lock = threading.RLock() - with self._batch_lock: - self._batch = {} - self._timer = None - - def start(self): - """Starts batch update repeating timer.""" - self._update_batch() # starts repeating timer - - def _bucket_batch_update_reql(self, bucket, batch): - return self.rr.table(self.table).get(bucket).replace( - lambda old: r.branch( - old.eq(None), batch[bucket], old.merge({ - "total": { - "urls": old["total"]["urls"].add( - batch[bucket]["total"]["urls"]), - "wire_bytes": old["total"]["wire_bytes"].add( - batch[bucket]["total"]["wire_bytes"]), - }, - "new": { - "urls": old["new"]["urls"].add( - batch[bucket]["new"]["urls"]), - "wire_bytes": old["new"]["wire_bytes"].add( - batch[bucket]["new"]["wire_bytes"]), - }, - "revisit": { - "urls": old["revisit"]["urls"].add( - batch[bucket]["revisit"]["urls"]), - "wire_bytes": old["revisit"]["wire_bytes"].add( - batch[bucket]["revisit"]["wire_bytes"]), - }, - }))) - - def _update_batch(self): - with self._batch_lock: - batch_copy = copy.deepcopy(self._batch) - self._batch = {} - try: - if len(batch_copy) > 0: - # XXX can all the buckets be done in one query? 
- for bucket in batch_copy: - result = self._bucket_batch_update_reql( - bucket, batch_copy).run() - if (not result["inserted"] and not result["replaced"] - or sorted(result.values()) != [0,0,0,0,0,1]): - raise Exception( - "unexpected result %s updating stats %s" % ( - result, batch_copy[bucket])) - except Exception as e: - self.logger.error("problem updating stats", exc_info=True) - # now we need to restore the stats that didn't get saved to the - # batch so that they are saved in the next call to _update_batch() - with self._batch_lock: - self._add_to_batch(batch_copy) - finally: - if not self._stop.is_set(): - self._timer = threading.Timer(2.0, self._update_batch) - self._timer.name = "RethinkStats-batch-update-timer-%s" % ( - datetime.datetime.utcnow().isoformat()) - self._timer.start() - else: - self.logger.info("finished") def _ensure_db_table(self): dbs = self.rr.db_list().run() @@ -282,17 +225,38 @@ class RethinkStatsDb(StatsDb): self.table, primary_key="bucket", shards=1, replicas=self.replicas).run() - def close(self): - self.stop() + def _update_db(self, batch_buckets): + # XXX can all the buckets be done in one query? + for bucket in batch_buckets: + result = self._bucket_batch_update_reql( + bucket, batch_buckets[bucket]).run() + if (not result['inserted'] and not result['replaced'] + or sorted(result.values()) != [0,0,0,0,0,1]): + self.logger.error( + 'unexpected result %s updating stats %s' % ( + result, batch_buckets[bucket])) - def stop(self): - self.logger.info("stopping rethinkdb stats table batch updates") - self._stop.set() - if self._timer: - self._timer.join() - - def sync(self): - pass + def _bucket_batch_update_reql(self, bucket, new): + return self.rr.table(self.table).get(bucket).replace( + lambda old: r.branch( + old.eq(None), new, old.merge({ + 'total': { + 'urls': old['total']['urls'].add(new['total']['urls']), + 'wire_bytes': old['total']['wire_bytes'].add( + new['total']['wire_bytes']), + }, + 'new': { + 'urls': old['new']['urls'].add(new['new']['urls']), + 'wire_bytes': old['new']['wire_bytes'].add( + new['new']['wire_bytes']), + }, + 'revisit': { + 'urls': old['revisit']['urls'].add( + new['revisit']['urls']), + 'wire_bytes': old['revisit']['wire_bytes'].add( + new['revisit']['wire_bytes']), + }, + }))) def value(self, bucket0="__all__", bucket1=None, bucket2=None): bucket0_stats = self.rr.table(self.table).get(bucket0).run() @@ -307,39 +271,6 @@ class RethinkStatsDb(StatsDb): return bucket0_stats[bucket1] return bucket0_stats - def tally(self, recorded_url, records): - buckets = self.buckets(recorded_url) - with self._batch_lock: - for bucket in buckets: - bucket_stats = self._batch.setdefault( - bucket, _empty_bucket(bucket)) - - bucket_stats["total"]["urls"] += 1 - bucket_stats["total"]["wire_bytes"] += recorded_url.size - - if records: - if records[0].type == b'revisit': - bucket_stats["revisit"]["urls"] += 1 - bucket_stats["revisit"]["wire_bytes"] += recorded_url.size - else: - bucket_stats["new"]["urls"] += 1 - bucket_stats["new"]["wire_bytes"] += recorded_url.size - - def _add_to_batch(self, add_me): - with self._batch_lock: - for bucket in add_me: - bucket_stats = self._batch.setdefault( - bucket, _empty_bucket(bucket)) - bucket_stats["total"]["urls"] += add_me[bucket]["total"]["urls"] - bucket_stats["total"]["wire_bytes"] += add_me[bucket]["total"]["wire_bytes"] - bucket_stats["revisit"]["urls"] += add_me[bucket]["revisit"]["urls"] - bucket_stats["revisit"]["wire_bytes"] += add_me[bucket]["revisit"]["wire_bytes"] - bucket_stats["new"]["urls"] += 
add_me[bucket]["new"]["urls"] - bucket_stats["new"]["wire_bytes"] += add_me[bucket]["new"]["wire_bytes"] - - def notify(self, recorded_url, records): - self.tally(recorded_url, records) - class RunningStats: ''' In-memory stats for measuring overall warcprox performance. diff --git a/warcprox/writerthread.py b/warcprox/writerthread.py index 5747428..f823cc6 100644 --- a/warcprox/writerthread.py +++ b/warcprox/writerthread.py @@ -36,9 +36,8 @@ class WarcWriterThread(warcprox.BaseStandardPostfetchProcessor): _ALWAYS_ACCEPT = {'WARCPROX_WRITE_RECORD'} - def __init__(self, inq, outq, options=warcprox.Options()): - warcprox.BaseStandardPostfetchProcessor.__init__( - self, inq, outq, options=options) + def __init__(self, options=warcprox.Options()): + warcprox.BaseStandardPostfetchProcessor.__init__(self, options=options) self.options = options self.writer_pool = warcprox.writer.WarcWriterPool(options) self.method_filter = set(method.upper() for method in self.options.method_filter or [])
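
Not part of the patch itself: a minimal sketch of the synchronization pattern this change introduces in the test suite, for readers skimming the diff. Instead of sleeping and polling postfetch_chain_busy(), tests now snapshot warcprox_.proxy.running_stats.urls and use the new wait() helper to block until the expected number of urls has flowed through the whole postfetch chain (running_stats is updated by a ListenerPostfetchProcessor appended near the end of the chain in controller.py). The test name and the /x/y path below are made up for illustration; the fixtures and the wait() helper are the ones defined in tests/test_warcprox.py above.

    import time
    import requests

    def wait(callback, timeout=10):
        # poll until callback() returns something truthy (same helper as above)
        start = time.time()
        while time.time() - start < timeout:
            if callback():
                return
            time.sleep(0.1)
        raise Exception('timed out waiting for %s to return truthy' % callback)

    def test_example(http_daemon, warcprox_, archiving_proxies):
        # snapshot the running url count before fetching anything
        urls_before = warcprox_.proxy.running_stats.urls

        url = 'http://localhost:%s/x/y' % http_daemon.server_port
        response = requests.get(url, proxies=archiving_proxies)
        assert response.status_code == 200

        # wait for postfetch chain: once the count goes up, the url has been
        # through dedup lookup, warc writing, dedup storage, stats, etc., so
        # follow-on assertions (dedup db, playback, crawl log) are safe
        wait(lambda: warcprox_.proxy.running_stats.urls - urls_before == 1)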