1
0
mirror of https://github.com/webrecorder/pywb.git synced 2025-03-24 06:59:52 +01:00

recorder: move skip_response() check to occur before response is sent, rather than at the end

filters: replace SkipNothingFilter with SkipDefaultFilter which checks for 'Recorder-Skip', call base filter checks on all filters
This commit is contained in:
Ilya Kreymer 2017-06-01 14:03:56 -07:00
parent 06b1134be5
commit eac5d18985
2 changed files with 25 additions and 20 deletions

View File

@ -59,16 +59,22 @@ class WriteDupePolicy(object):
# ============================================================================ # ============================================================================
# Skip Record Filters # Skip Record Filters
# ============================================================================ # ============================================================================
class SkipNothingFilter(object): class SkipDefaultFilter(object):
def skip_request(self, path, req_headers): def skip_request(self, path, req_headers):
if req_headers.get('Recorder-Skip') == '1':
return True
return False return False
def skip_response(self, path, req_headers, resp_headers, params): def skip_response(self, path, req_headers, resp_headers, params):
if resp_headers.get('Recorder-Skip') == '1':
return True
return False return False
# ============================================================================ # ============================================================================
class CollectionFilter(SkipNothingFilter): class CollectionFilter(SkipDefaultFilter):
def __init__(self, accept_colls): def __init__(self, accept_colls):
self.rx_accept_map = {} self.rx_accept_map = {}
@ -79,14 +85,9 @@ class CollectionFilter(SkipNothingFilter):
for name in accept_colls: for name in accept_colls:
self.rx_accept_map[name] = re.compile(accept_colls[name]) self.rx_accept_map[name] = re.compile(accept_colls[name])
def skip_request(self, path, req_headers):
if req_headers.get('Recorder-Skip') == '1':
return True
return False
def skip_response(self, path, req_headers, resp_headers, params): def skip_response(self, path, req_headers, resp_headers, params):
if resp_headers.get('Recorder-Skip') == '1': if super(CollectionFilter, self).skip_response(path, req_headers,
resp_headers, params):
return True return True
path = path[1:].split('/', 1)[0] path = path[1:].split('/', 1)[0]
@ -102,8 +103,12 @@ class CollectionFilter(SkipNothingFilter):
# ============================================================================ # ============================================================================
class SkipRangeRequestFilter(SkipNothingFilter): class SkipRangeRequestFilter(SkipDefaultFilter):
def skip_request(self, path, req_headers): def skip_request(self, path, req_headers):
if super(SkipRangeRequestFilter, self).skip_request(path,
req_headers):
return True
range_ = req_headers.get('Range') range_ = req_headers.get('Range')
if range_ and not range_.lower().startswith('bytes=0-'): if range_ and not range_.lower().startswith('bytes=0-'):
return True return True

View File

@ -225,7 +225,13 @@ class RecorderApp(object):
req_stream.out.close() req_stream.out.close()
return self.send_error(e, start_response) return self.send_error(e, start_response)
start_response('200 OK', list(res.headers.items())) if not skipping:
skipping = any(x.skip_response(path,
req_stream.headers,
res.headers,
params)
for x in self.skip_filters)
if not skipping: if not skipping:
resp_stream = RespWrapper(res.raw, resp_stream = RespWrapper(res.raw,
@ -233,14 +239,15 @@ class RecorderApp(object):
req_stream, req_stream,
params, params,
self.write_queue, self.write_queue,
self.skip_filters,
path, path,
self.create_buff_func) self.create_buff_func)
else: else:
resp_stream = res.raw resp_stream = res.raw
resp_iter = StreamIter(resp_stream) resp_iter = StreamIter(resp_stream)
start_response('200 OK', list(res.headers.items()))
return resp_iter return resp_iter
@ -267,13 +274,12 @@ class Wrapper(object):
#============================================================================== #==============================================================================
class RespWrapper(Wrapper): class RespWrapper(Wrapper):
def __init__(self, stream, headers, req, def __init__(self, stream, headers, req,
params, queue, skip_filters, path, create_func): params, queue, path, create_func):
super(RespWrapper, self).__init__(stream, params, create_func) super(RespWrapper, self).__init__(stream, params, create_func)
self.headers = headers self.headers = headers
self.req = req self.req = req
self.queue = queue self.queue = queue
self.skip_filters = skip_filters
self.path = path self.path = path
def close(self): def close(self):
@ -299,12 +305,6 @@ class RespWrapper(Wrapper):
try: try:
if self.interrupted: if self.interrupted:
skipping = True skipping = True
else:
skipping = any(x.skip_response(self.path,
self.req.headers,
self.headers,
self.params)
for x in self.skip_filters)
if not skipping: if not skipping:
entry = (self.req.headers, self.req.out, entry = (self.req.headers, self.req.out,