rewrite run-benchmarks.py for aiohttp2

Noah Levitt 2017-05-08 20:56:32 -07:00
parent c87ff90bc1
commit eea582c6db
9 changed files with 269 additions and 155 deletions

View File

@@ -1 +1 @@
-aiohttp
+aiohttp==2.0.7

View File

@@ -1,42 +1,57 @@
 #!/usr/bin/env python
-#
-# run-benchmarks.py - some benchmarking code for warcprox
-#
-# Copyright (C) 2015-2016 Internet Archive
-#
-# This program is free software; you can redistribute it and/or
-# modify it under the terms of the GNU General Public License
-# as published by the Free Software Foundation; either version 2
-# of the License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program; if not, write to the Free Software
-# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
-# USA.
-#
-
-import sys
-import aiohttp
-import aiohttp.server
+'''
+run-benchmarks.py - some benchmarking code for warcprox
+
+Copyright (C) 2015-2017 Internet Archive
+
+This program is free software; you can redistribute it and/or
+modify it under the terms of the GNU General Public License
+as published by the Free Software Foundation; either version 2
+of the License, or (at your option) any later version.
+
+This program is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with this program; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301,
+USA.
+'''
+
+import aiohttp.web
 import asyncio
 import ssl
-import tempfile
 import OpenSSL.crypto
 import OpenSSL.SSL
+import tempfile
 import random
 import os
-import threading
-import time
 import logging
+import sys
+import time
+import argparse
+import hashlib
+import datetime
+import cryptography.hazmat.backends.openssl
+import warcprox
 import warcprox.main
+import threading
 
-logging.basicConfig(stream=sys.stdout, level=logging.INFO,
-        format='%(asctime)s %(process)d %(levelname)s %(threadName)s %(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s')
+# https://medium.com/@generativist/a-simple-streaming-http-server-in-aiohttp-4233dbc173c7
+async def do_get(request):
+    # return aiohttp.web.Response(text='foo=%s' % request.match_info.get('foo'))
+    n = int(request.match_info.get('n'))
+    response = aiohttp.web.StreamResponse(
+            status=200, reason='OK', headers={'Content-Type': 'text/plain'})
+    await response.prepare(request)
+    for i in range(n):
+        for i in range(10):
+            response.write(b'x' * 99 + b'\n')
+        await response.drain()
+    return response
 
 def self_signed_cert():
     key = OpenSSL.crypto.PKey()
@@ -44,7 +59,7 @@ def self_signed_cert():
 
     cert = OpenSSL.crypto.X509()
     cert.set_serial_number(random.randint(0, 2 ** 64 - 1))
-    cert.get_subject().CN = 'localhost'
+    cert.get_subject().CN = '127.0.0.1'
     cert.set_version(2)
     cert.gmtime_adj_notBefore(0)
@@ -52,121 +67,211 @@ def self_signed_cert():
     cert.set_issuer(cert.get_subject())
     cert.set_pubkey(key)
-    cert.sign(key, "sha1")
+    cert.sign(key, 'sha1')
     return key, cert
 
-class HttpRequestHandler(aiohttp.server.ServerHttpProtocol):
-    @asyncio.coroutine
-    def handle_request(self, message, payload):
-        response = aiohttp.Response(
-            self.writer, 200, http_version=message.version
-        )
-        n = int(message.path.partition('/')[2])
-        response.add_header('Content-Type', 'text/plain')
-        # response.add_header('Content-Length', '18')
-        response.send_headers()
-        for i in range(n):
-            response.write(b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\n')
-        yield from response.write_eof()
-
-def run_servers():
-    loop.run_forever()
+def ssl_context():
+    sslc = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
+    with tempfile.NamedTemporaryFile(delete=False) as certfile:
+        key, cert = self_signed_cert()
+        certfile.write(
+                OpenSSL.crypto.dump_privatekey(OpenSSL.SSL.FILETYPE_PEM, key))
+        certfile.write(
+                OpenSSL.crypto.dump_certificate(OpenSSL.SSL.FILETYPE_PEM, cert))
+    sslc.load_cert_chain(certfile.name)
+    os.remove(certfile.name)
+    return sslc
 
 def start_servers():
+    app = aiohttp.web.Application()
+    app.router.add_get('/{n}', do_get)
+
     loop = asyncio.get_event_loop()
-    http = loop.create_server(lambda: HttpRequestHandler(debug=True, keep_alive=75), '127.0.0.1', '8080')
-    sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
-    key, cert = self_signed_cert()
-    with tempfile.NamedTemporaryFile(delete=False) as certfile:
-        certfile.write(OpenSSL.crypto.dump_privatekey(OpenSSL.SSL.FILETYPE_PEM, key))
-        certfile.write(OpenSSL.crypto.dump_certificate(OpenSSL.SSL.FILETYPE_PEM, cert))
-    sslcontext.load_cert_chain(certfile.name)
-    os.remove(certfile.name)
-    https = loop.create_server(lambda: HttpRequestHandler(debug=True, keep_alive=75), '127.0.0.1', '8443', ssl=sslcontext)
-    srv = loop.run_until_complete(http)
-    srv = loop.run_until_complete(https)
-    logging.info('serving on http://127.0.0.1:8080 and https://127.0.0.1:8443')
-
-class AsyncClient(object):
-    def __init__(self, proxy=None):
-        self.n_urls = 0
-        self.n_bytes = 0
-        self.proxy = proxy
-        if proxy:
-            self.connector = aiohttp.connector.ProxyConnector(proxy, verify_ssl=False)
-        else:
-            self.connector = aiohttp.connector.TCPConnector(verify_ssl=False)
-
-    @asyncio.coroutine
-    def read_response(self, r, url):
-        # time.sleep(random.random() * 10)
-        while True:
-            chunk = yield from r.content.read(2**16)
-            self.n_bytes += len(chunk)
-            if not chunk:
-                self.n_urls += 1
-                logging.debug("finished reading from %s", url)
-                r.close()
-                break
-
-    @asyncio.coroutine
-    def one_request(self, url):
-        logging.debug("issuing request to %s", url)
-        r = yield from aiohttp.get(url, connector=self.connector)
-        logging.debug("issued request to %s", url)
-        yield from self.read_response(r, url)
-
-def benchmark(client):
-    try:
-        start = time.time()
-        tasks_https = [client.one_request('https://localhost:8443/%s' % int(1.1**i)) for i in range(80)]
-        asyncio.get_event_loop().run_until_complete(asyncio.wait(tasks_https))
-        tasks_http = [client.one_request('http://localhost:8080/%s' % int(1.1**i)) for i in range(80)]
-        asyncio.get_event_loop().run_until_complete(asyncio.wait(tasks_http))
-    finally:
-        finish = time.time()
-        logging.info("proxy=%s: %s urls totaling %s bytes in %s seconds", client.proxy, client.n_urls, client.n_bytes, (finish - start))
+    http = loop.create_server(
+            app.make_handler(access_log=None), '127.0.0.1', 4080)
+    loop.run_until_complete(http)
+
+    sslc = ssl_context()
+    https = loop.create_server(
+            app.make_handler(access_log=None), '127.0.0.1', 4443, ssl=sslc)
+    loop.run_until_complete(https)
+
+async def benchmarking_client(base_url, n, proxy=None):
+    n_urls = 0
+    n_bytes = 0
+    for i in range(n):
+        url = '%s/%s' % (base_url, i)
+        connector = aiohttp.TCPConnector(verify_ssl=False)
+        async with aiohttp.ClientSession(connector=connector) as session:
+            async with session.get(url, proxy=proxy) as response:
+                assert response.status == 200
+                while True:
+                    chunk = await response.content.read(2**16)
+                    n_bytes += len(chunk)
+                    if not chunk:
+                        n_urls += 1
+                        break
+    return n_urls, n_bytes
+
+def build_arg_parser(tmpdir, prog=os.path.basename(sys.argv[0])):
+    arg_parser = argparse.ArgumentParser(
+            prog=prog, description='warcprox benchmarker',
+            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    arg_parser.add_argument(
+            '-z', '--gzip', dest='gzip', action='store_true',
+            help='write gzip-compressed warc records')
+    arg_parser.add_argument(
+            '-s', '--size', dest='size', default=1000*1000*1000, type=int,
+            help='WARC file rollover size threshold in bytes')
+    arg_parser.add_argument(
+            '--rollover-idle-time', dest='rollover_idle_time', default=None,
+            type=int, help=(
+                'WARC file rollover idle time threshold in seconds (so that '
+                "Friday's last open WARC doesn't sit there all weekend "
+                'waiting for more data)'))
+    try:
+        hash_algos = hashlib.algorithms_guaranteed
+    except AttributeError:
+        hash_algos = hashlib.algorithms
+    arg_parser.add_argument(
+            '-g', '--digest-algorithm', dest='digest_algorithm',
+            default='sha1', help='digest algorithm, one of %s' % hash_algos)
+    arg_parser.add_argument('--base32', dest='base32', action='store_true',
+            default=False, help='write digests in Base32 instead of hex')
+    arg_parser.add_argument(
+            '--method-filter', metavar='HTTP_METHOD',
+            action='append', help=(
+                'only record requests with the given http method(s) (can be '
+                'used more than once)'))
+    arg_parser.add_argument(
+            '--stats-db-file', dest='stats_db_file',
+            default=os.path.join(tmpdir, 'stats.db'), help=(
+                'persistent statistics database file; empty string or '
+                '/dev/null disables statistics tracking'))
+    group = arg_parser.add_mutually_exclusive_group()
+    group.add_argument(
+            '-j', '--dedup-db-file', dest='dedup_db_file',
+            default=os.path.join(tmpdir, 'dedup.db'), help=(
+                'persistent deduplication database file; empty string or '
+                '/dev/null disables deduplication'))
+    group.add_argument(
+            '--rethinkdb-servers', dest='rethinkdb_servers', help=(
+                'rethinkdb servers, used for dedup and stats if specified; '
+                'e.g. db0.foo.org,db0.foo.org:38015,db1.foo.org'))
+    # arg_parser.add_argument(
+    #         '--rethinkdb-db', dest='rethinkdb_db', default='warcprox', help=(
+    #             'rethinkdb database name (ignored unless --rethinkdb-servers '
+    #             'is specified)'))
+    arg_parser.add_argument(
+            '--rethinkdb-big-table', dest='rethinkdb_big_table',
+            action='store_true', default=False, help=(
+                'use a big rethinkdb table called "captures", instead of a '
+                'small table called "dedup"; table is suitable for use as '
+                'index for playback (ignored unless --rethinkdb-servers is '
+                'specified)'))
+    arg_parser.add_argument(
+            '--kafka-broker-list', dest='kafka_broker_list', default=None,
+            help='kafka broker list for capture feed')
+    arg_parser.add_argument(
+            '--kafka-capture-feed-topic', dest='kafka_capture_feed_topic',
+            default=None, help='kafka capture feed topic')
+    arg_parser.add_argument(
+            '--queue-size', dest='queue_size', type=int, default=1)
+    arg_parser.add_argument(
+            '--max-threads', dest='max_threads', type=int)
+    arg_parser.add_argument(
+            '--version', action='version',
+            version='warcprox %s' % warcprox.__version__)
+    arg_parser.add_argument(
+            '-v', '--verbose', dest='verbose', action='store_true')
+    arg_parser.add_argument('--trace', dest='trace', action='store_true')
+    arg_parser.add_argument('-q', '--quiet', dest='quiet', action='store_true')
+    return arg_parser
 
 if __name__ == '__main__':
-    args = warcprox.main.parse_args()
-
-    start_servers()
-
-    baseline_client = AsyncClient()
-    logging.info("===== baseline benchmark starting (no proxy) =====")
-    benchmark(baseline_client)
-    logging.info("===== baseline benchmark finished =====")
-
-    # Queue size of 1 makes warcprox behave as though it were synchronous (each
-    # request blocks until the warc writer starts working on the last request).
-    # This gives us a better sense of sustained max throughput. The
-    # asynchronous nature of warcprox helps with bursty traffic, as long as the
-    # average throughput stays below the sustained max.
+    # see https://github.com/pyca/cryptography/issues/2911
+    cryptography.hazmat.backends.openssl.backend.activate_builtin_random()
+
     with tempfile.TemporaryDirectory() as tmpdir:
-        args.queue_size = 1
-        args.cacert = os.path.join(tmpdir, "benchmark-warcprox-ca.pem")
-        args.certs_dir = os.path.join(tmpdir, "benchmark-warcprox-ca")
-        args.directory = os.path.join(tmpdir, "warcs")
-        args.gzip = True
-        args.base32 = True
-        args.stats_db_file = os.path.join(tmpdir, "stats.db")
-        args.dedup_db_file = os.path.join(tmpdir, "dedup.db")
+        arg_parser = build_arg_parser(tmpdir)
+        args = arg_parser.parse_args(args=sys.argv[1:])
+
+        if args.trace:
+            loglevel = warcprox.TRACE
+        elif args.verbose:
+            loglevel = logging.DEBUG
+        elif args.quiet:
+            loglevel = logging.WARNING
+        else:
+            loglevel = logging.INFO
+
+        logging.basicConfig(
+                stream=sys.stdout, level=loglevel, format=(
+                    '%(asctime)s %(process)d %(levelname)s %(threadName)s '
+                    '%(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s'))
+        logging.getLogger('warcprox').setLevel(loglevel + 5)
+
+        args.playback_port = None
+        args.address = '127.0.0.1'
+        args.port = 0
+        args.cacert = os.path.join(tmpdir, 'benchmark-warcprox-ca.pem')
+        args.certs_dir = os.path.join(tmpdir, 'benchmark-warcprox-ca')
+        args.directory = os.path.join(tmpdir, 'warcs')
+        if args.rethinkdb_servers:
+            args.rethinkdb_db = 'benchmarks_{:%Y%m%d%H%M%S}'.format(
+                    datetime.datetime.utcnow())
 
         warcprox_controller = warcprox.main.init_controller(args)
-        warcprox_controller_thread = threading.Thread(target=warcprox_controller.run_until_shutdown)
+        warcprox_controller_thread = threading.Thread(
+                target=warcprox_controller.run_until_shutdown)
         warcprox_controller_thread.start()
 
-        proxy = "http://%s:%s" % (args.address, args.port)
-        proxied_client = AsyncClient(proxy=proxy)
+        start_servers()
+        logging.info(
+                'servers running at http://127.0.0.1:4080 and '
+                'https://127.0.0.1:4443')
+
+        loop = asyncio.get_event_loop()
+
+        logging.info("===== baseline benchmark starting (no proxy) =====")
+        start = time.time()
+        n_urls, n_bytes = loop.run_until_complete(
+                benchmarking_client('http://127.0.0.1:4080', 100))
+        finish = time.time()
+        logging.info(
+                'http baseline (no proxy): n_urls=%s n_bytes=%s in %.1f sec',
+                n_urls, n_bytes, finish - start)
+        start = time.time()
+        n_urls, n_bytes = loop.run_until_complete(
+                benchmarking_client('https://127.0.0.1:4443', 100))
+        finish = time.time()
+        logging.info(
+                'https baseline (no proxy): n_urls=%s n_bytes=%s in %.1f sec',
+                n_urls, n_bytes, finish - start)
+        logging.info("===== baseline benchmark finished =====")
+
+        proxy = "http://%s:%s" % (
+                warcprox_controller.proxy.server_address[0],
+                warcprox_controller.proxy.server_address[1])
 
         logging.info("===== warcprox benchmark starting =====")
-        benchmark(proxied_client)
+        start = time.time()
+        n_urls, n_bytes = loop.run_until_complete(
+                benchmarking_client('http://127.0.0.1:4080', 100, proxy))
+        finish = time.time()
+        logging.info(
+                'http: n_urls=%s n_bytes=%s in %.1f sec',
+                n_urls, n_bytes, finish - start)
+        start = time.time()
+        n_urls, n_bytes = loop.run_until_complete(
+                benchmarking_client('https://127.0.0.1:4443', 100, proxy))
+        finish = time.time()
+        logging.info(
+                'https: n_urls=%s n_bytes=%s in %.1f sec',
+                n_urls, n_bytes, finish - start)
        logging.info("===== warcprox benchmark finished =====")
 
         warcprox_controller.stop.set()
         warcprox_controller_thread.join()
-
-    asyncio.get_event_loop().stop()
-    logging.info("finished")
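Note on the migration above: aiohttp 2.0 dropped the aiohttp.server.ServerHttpProtocol API the old benchmark subclassed, and removed ProxyConnector in favor of a per-request proxy= argument, which is why the script had to be rewritten rather than patched. A minimal self-contained sketch of the same server/client pattern, assuming the aiohttp 2.0.x API that the requirements pin targets (handler and fetch are illustrative names, not part of the commit):

import asyncio
import aiohttp
import aiohttp.web

async def handler(request):
    # stream n 100-byte lines, like the benchmark's do_get()
    n = int(request.match_info['n'])
    resp = aiohttp.web.StreamResponse(
            status=200, headers={'Content-Type': 'text/plain'})
    await resp.prepare(request)
    for _ in range(n):
        resp.write(b'x' * 99 + b'\n')  # plain method in aiohttp 2.x
        await resp.drain()             # flush the write buffer
    return resp

async def fetch(url, proxy=None):
    # aiohttp 2.x passes the proxy per request, not via a connector
    async with aiohttp.ClientSession() as session:
        async with session.get(url, proxy=proxy) as response:
            return await response.read()

if __name__ == '__main__':
    app = aiohttp.web.Application()
    app.router.add_get('/{n}', handler)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(
            loop.create_server(app.make_handler(), '127.0.0.1', 8080))
    body = loop.run_until_complete(fetch('http://127.0.0.1:8080/10'))
    assert body == (b'x' * 99 + b'\n') * 10

In aiohttp 3.x, StreamResponse.write() became a coroutine and drain() was folded into it, so this sketch, like the script itself, is specific to the 2.x line.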

View File

@@ -51,7 +51,7 @@ except:
 setuptools.setup(
         name='warcprox',
-        version='2.1b1.dev77',
+        version='2.1b1.dev78',
         description='WARC writing MITM HTTP/S proxy',
         url='https://github.com/internetarchive/warcprox',
         author='Noah Levitt',

View File

@@ -4,7 +4,7 @@ the table is "big" in the sense that it is designed to be usable as an index
 for playback software outside of warcprox, and contains information not
 needed merely for deduplication
 
-Copyright (C) 2015-2016 Internet Archive
+Copyright (C) 2015-2017 Internet Archive
 
 This program is free software; you can redistribute it and/or
 modify it under the terms of the GNU General Public License
@@ -64,26 +64,30 @@ class RethinkCaptures:
     def _insert_batch(self):
         try:
             with self._batch_lock:
-                if len(self._batch) > 0:
-                    result = self.rr.table(self.table).insert(
-                            self._batch, conflict="replace").run()
-                    if (result["inserted"] + result["replaced"]
-                            + result["unchanged"] != len(self._batch)):
-                        raise Exception(
-                                "unexpected result saving batch of %s: %s "
-                                "entries" % (len(self._batch), result))
-                    if result["replaced"] > 0 or result["unchanged"] > 0:
-                        self.logger.warn(
-                                "inserted=%s replaced=%s unchanged=%s in big "
-                                "captures table (normally replaced=0 and "
-                                "unchanged=0)", result["inserted"],
-                                result["replaced"], result["unchanged"])
-                    else:
-                        self.logger.debug(
-                                "inserted %s entries to big captures table",
-                                len(self._batch))
-                    self._batch = []
-        except BaseException as e:
+                batch = self._batch
+                self._batch = []
+
+            if batch:
+                result = self.rr.table(self.table).insert(
+                        batch, conflict="replace").run()
+                if (result["inserted"] + result["replaced"]
+                        + result["unchanged"] != len(batch)):
+                    raise Exception(
+                            "unexpected result saving batch of %s: %s "
+                            "entries" % (len(batch), result))
+                if result["replaced"] > 0 or result["unchanged"] > 0:
+                    self.logger.warn(
+                            "inserted=%s replaced=%s unchanged=%s in big "
+                            "captures table (normally replaced=0 and "
+                            "unchanged=0)", result["inserted"],
+                            result["replaced"], result["unchanged"])
+                else:
+                    self.logger.debug(
+                            "inserted %s entries to big captures table",
+                            len(batch))
+        except Exception as e:
+            with self._batch_lock:
+                self._batch.extend(batch)
             self.logger.error(
                     "caught exception trying to save %s entries, they will "
                     "be included in the next batch", len(self._batch),

View File

@@ -238,11 +238,11 @@ class WarcproxController(object):
                         ).total_seconds() > self.HEARTBEAT_INTERVAL):
                     self._service_heartbeat()
 
-                if self.options.profile and (
-                        datetime.datetime.utcnow() - last_mem_dbg
-                        ).total_seconds() > 60:
-                    self.debug_mem()
-                    last_mem_dbg = datetime.datetime.utcnow()
+                # if self.options.profile and (
+                #         datetime.datetime.utcnow() - last_mem_dbg
+                #         ).total_seconds() > 60:
+                #     self.debug_mem()
+                #     last_mem_dbg = datetime.datetime.utcnow()
 
                 time.sleep(0.5)
         except:

View File

@@ -52,6 +52,7 @@ class CaptureFeed:
         return self.__producer
 
     def notify(self, recorded_url, records):
+        import pdb; pdb.set_trace()
         if records[0].type not in (b'revisit', b'response'):
             return

View File

@@ -104,7 +104,8 @@ def _build_arg_parser(prog=os.path.basename(sys.argv[0])):
             default=500, help=argparse.SUPPRESS)
     arg_parser.add_argument('--max-threads', dest='max_threads', type=int,
             help=argparse.SUPPRESS)
-    arg_parser.add_argument('--profile', action='store_true', default=False,
+    arg_parser.add_argument(
+            '--profile', dest='profile', action='store_true', default=False,
             help=argparse.SUPPRESS)
     arg_parser.add_argument('--onion-tor-socks-proxy', dest='onion_tor_socks_proxy',
             default=None, help='host:port of tor socks proxy, used only to connect to .onion sites')

View File

@@ -110,7 +110,9 @@ class WarcWriter:
             record.offset = offset
             record.length = writer.tell() - offset
             record.warc_filename = self._f_finalname
-            self.logger.debug('wrote warc record: warc_type=%s content_length=%s url=%s warc=%s offset=%d',
+            self.logger.debug(
+                    'wrote warc record: warc_type=%s content_length=%s url=%s '
+                    'warc=%s offset=%d',
                     record.get_header(warctools.WarcRecord.TYPE),
                     record.get_header(warctools.WarcRecord.CONTENT_LENGTH),
                     record.get_header(warctools.WarcRecord.URL),

View File

@@ -89,8 +89,9 @@ class WarcWriterThread(threading.Thread):
                 self.idle = None
                 if self._filter_accepts(recorded_url):
                     if self.dedup_db:
-                        warcprox.dedup.decorate_with_dedup_info(self.dedup_db,
-                                recorded_url, base32=self.options.base32)
+                        warcprox.dedup.decorate_with_dedup_info(
+                                self.dedup_db, recorded_url,
+                                base32=self.options.base32)
                     records = self.writer_pool.write_records(recorded_url)
                     self._final_tasks(recorded_url, records)