1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-15 08:04:49 +01:00

recorder: move skip_response() check to occur before response is sent, rather than at the end

filters: replace SkipNothingFilter with SkipDefaultFilter which checks for 'Recorder-Skip', call base filter checks on all filters
This commit is contained in:
Ilya Kreymer 2017-06-01 14:03:56 -07:00
parent 06b1134be5
commit eac5d18985
2 changed files with 25 additions and 20 deletions

View File

@ -59,16 +59,22 @@ class WriteDupePolicy(object):
# ============================================================================
# Skip Record Filters
# ============================================================================
class SkipNothingFilter(object):
class SkipDefaultFilter(object):
def skip_request(self, path, req_headers):
if req_headers.get('Recorder-Skip') == '1':
return True
return False
def skip_response(self, path, req_headers, resp_headers, params):
if resp_headers.get('Recorder-Skip') == '1':
return True
return False
# ============================================================================
class CollectionFilter(SkipNothingFilter):
class CollectionFilter(SkipDefaultFilter):
def __init__(self, accept_colls):
self.rx_accept_map = {}
@ -79,14 +85,9 @@ class CollectionFilter(SkipNothingFilter):
for name in accept_colls:
self.rx_accept_map[name] = re.compile(accept_colls[name])
def skip_request(self, path, req_headers):
if req_headers.get('Recorder-Skip') == '1':
return True
return False
def skip_response(self, path, req_headers, resp_headers, params):
if resp_headers.get('Recorder-Skip') == '1':
if super(CollectionFilter, self).skip_response(path, req_headers,
resp_headers, params):
return True
path = path[1:].split('/', 1)[0]
@ -102,8 +103,12 @@ class CollectionFilter(SkipNothingFilter):
# ============================================================================
class SkipRangeRequestFilter(SkipNothingFilter):
class SkipRangeRequestFilter(SkipDefaultFilter):
def skip_request(self, path, req_headers):
if super(SkipRangeRequestFilter, self).skip_request(path,
req_headers):
return True
range_ = req_headers.get('Range')
if range_ and not range_.lower().startswith('bytes=0-'):
return True

View File

@ -225,7 +225,13 @@ class RecorderApp(object):
req_stream.out.close()
return self.send_error(e, start_response)
start_response('200 OK', list(res.headers.items()))
if not skipping:
skipping = any(x.skip_response(path,
req_stream.headers,
res.headers,
params)
for x in self.skip_filters)
if not skipping:
resp_stream = RespWrapper(res.raw,
@ -233,14 +239,15 @@ class RecorderApp(object):
req_stream,
params,
self.write_queue,
self.skip_filters,
path,
self.create_buff_func)
else:
resp_stream = res.raw
resp_iter = StreamIter(resp_stream)
start_response('200 OK', list(res.headers.items()))
return resp_iter
@ -267,13 +274,12 @@ class Wrapper(object):
#==============================================================================
class RespWrapper(Wrapper):
def __init__(self, stream, headers, req,
params, queue, skip_filters, path, create_func):
params, queue, path, create_func):
super(RespWrapper, self).__init__(stream, params, create_func)
self.headers = headers
self.req = req
self.queue = queue
self.skip_filters = skip_filters
self.path = path
def close(self):
@ -299,12 +305,6 @@ class RespWrapper(Wrapper):
try:
if self.interrupted:
skipping = True
else:
skipping = any(x.skip_response(self.path,
self.req.headers,
self.headers,
self.params)
for x in self.skip_filters)
if not skipping:
entry = (self.req.headers, self.req.out,