use tempfile.SpooledTemporaryFile to overflow recorded response to disk

This commit is contained in:
Noah Levitt 2013-10-17 12:58:17 -07:00
parent 039f892024
commit e6a897412b

View File

@ -22,6 +22,7 @@ import httplib
import re import re
import signal import signal
import time import time
import tempfile
class CertificateAuthority(object): class CertificateAuthority(object):
@ -115,23 +116,25 @@ class ProxyingRecorder:
def __init__(self, fp, proxy_dest): def __init__(self, fp, proxy_dest):
self.fp = fp self.fp = fp
self.data = bytearray('') self.buf = tempfile.SpooledTemporaryFile(max_size=1024)
self.block_sha1 = hashlib.sha1() self.block_sha1 = hashlib.sha1()
self.payload_sha1 = None self.payload_sha1 = None
self.proxy_dest = proxy_dest self.proxy_dest = proxy_dest
self._prev_hunk_last_two_bytes = ''
self.len = 0
def _update(self, hunk): def _update(self, hunk):
if self.payload_sha1 is None: if self.payload_sha1 is None:
# convoluted handling of two newlines crossing chunks # convoluted handling of two newlines crossing hunks
# XXX write tests for this # XXX write tests for this
if self.data.endswith('\n'): if self._prev_hunk_last_two_bytes.endswith('\n'):
if hunk.startswith('\n'): if hunk.startswith('\n'):
self.payload_sha1 = hashlib.sha1() self.payload_sha1 = hashlib.sha1()
self.payload_sha1.update(hunk[1:]) self.payload_sha1.update(hunk[1:])
elif hunk.startswith('\r\n'): elif hunk.startswith('\r\n'):
self.payload_sha1 = hashlib.sha1() self.payload_sha1 = hashlib.sha1()
self.payload_sha1.update(hunk[2:]) self.payload_sha1.update(hunk[2:])
elif self.data.endswith('\n\r'): elif self._prev_hunk_last_two_bytes == '\n\r':
if hunk.startswith('\n'): if hunk.startswith('\n'):
self.payload_sha1 = hashlib.sha1() self.payload_sha1 = hashlib.sha1()
self.payload_sha1.update(hunk[1:]) self.payload_sha1.update(hunk[1:])
@ -140,13 +143,18 @@ class ProxyingRecorder:
if m is not None: if m is not None:
self.payload_sha1 = hashlib.sha1() self.payload_sha1 = hashlib.sha1()
self.payload_sha1.update(hunk[m.end():]) self.payload_sha1.update(hunk[m.end():])
# if we still haven't found start of payload hold on to these bytes
if self.payload_sha1 is None:
self._prev_hunk_last_two_bytes = hunk[-2:]
else: else:
self.payload_sha1.update(hunk) self.payload_sha1.update(hunk)
self.block_sha1.update(hunk) self.block_sha1.update(hunk)
self.data.extend(hunk) self.buf.write(hunk)
self.proxy_dest.sendall(hunk) self.proxy_dest.sendall(hunk)
self.len += len(hunk)
def read(self, size=-1): def read(self, size=-1):
hunk = self.fp.read(size=size) hunk = self.fp.read(size=size)
@ -162,6 +170,9 @@ class ProxyingRecorder:
def close(self): def close(self):
return self.fp.close() return self.fp.close()
def __len__(self):
return self.len
class ProxyingRecordingHTTPResponse(httplib.HTTPResponse): class ProxyingRecordingHTTPResponse(httplib.HTTPResponse):
@ -173,9 +184,6 @@ class ProxyingRecordingHTTPResponse(httplib.HTTPResponse):
self.recorder = ProxyingRecorder(self.fp, proxy_dest) self.recorder = ProxyingRecorder(self.fp, proxy_dest)
self.fp = self.recorder self.fp = self.recorder
def recorded(self):
return self.recorder.recorded
class ProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler): class ProxyHandler(BaseHTTPServer.BaseHTTPRequestHandler):
@ -396,7 +404,7 @@ class WarcRecordQueuer:
def do_response(self, recorder): def do_response(self, recorder):
logging.info('{0} << {1}'.format(self.url, repr(recorder.data[:40]))) logging.info('{} << {} bytes'.format(self.url, len(recorder)))
record_id = WarcRecordQueuer.make_warc_uuid("{0} {1}".format(self.url, self._warc_date())) record_id = WarcRecordQueuer.make_warc_uuid("{0} {1}".format(self.url, self._warc_date()))
@ -406,13 +414,14 @@ class WarcRecordQueuer:
headers.append((warctools.WarcRecord.URL, self.url)) headers.append((warctools.WarcRecord.URL, self.url))
headers.append((warctools.WarcRecord.DATE, self._warc_date())) headers.append((warctools.WarcRecord.DATE, self._warc_date()))
headers.append((warctools.WarcRecord.BLOCK_DIGEST, 'sha1:{}'.format(recorder.block_sha1.hexdigest()))) headers.append((warctools.WarcRecord.BLOCK_DIGEST, 'sha1:{}'.format(recorder.block_sha1.hexdigest())))
headers.append((warctools.WarcRecord.CONTENT_TYPE, "application/http;msgtype=response"))
headers.append((warctools.WarcRecord.CONTENT_LENGTH, str(len(recorder))))
if recorder.payload_sha1 is not None: if recorder.payload_sha1 is not None:
headers.append((warctools.WarcRecord.PAYLOAD_DIGEST, 'sha1:{}'.format(recorder.payload_sha1.hexdigest()))) headers.append((warctools.WarcRecord.PAYLOAD_DIGEST, 'sha1:{}'.format(recorder.payload_sha1.hexdigest())))
# headers.append((warctools.WarcRecord.IP_ADDRESS, ip)) # headers.append((warctools.WarcRecord.IP_ADDRESS, ip))
content_tuple = ("application/http;msgtype=response", recorder.data) recorder.buf.seek(0)
response_record = warctools.WarcRecord(headers=headers, content_file=recorder.buf)
response_record = warctools.WarcRecord(headers=headers, content=content_tuple)
try: try:
self._request_record.set_header(warctools.WarcRecord.CONCURRENT_TO, record_id) self._request_record.set_header(warctools.WarcRecord.CONCURRENT_TO, record_id)
@ -441,7 +450,7 @@ class WarcWriterThread(threading.Thread):
self._serial = 0 self._serial = 0
if not os.path.exists(directory): if not os.path.exists(directory):
logging.info("warc destination directory {0} doesn't exist, creating it".format(directory)) logging.info("warc destination directory {} doesn't exist, creating it".format(directory))
os.mkdir(directory) os.mkdir(directory)
self.stop = threading.Event() self.stop = threading.Event()
@ -449,7 +458,7 @@ class WarcWriterThread(threading.Thread):
def timestamp17(self): def timestamp17(self):
now = datetime.now() now = datetime.now()
return '{0}{1}'.format(now.strftime('%Y%m%d%H%M%S'), now.microsecond//1000) return '{}{}'.format(now.strftime('%Y%m%d%H%M%S'), now.microsecond//1000)
def _close_writer(self): def _close_writer(self):