diff --git a/warcprox/__init__.py b/warcprox/__init__.py index e2c8df7..852d3fc 100644 --- a/warcprox/__init__.py +++ b/warcprox/__init__.py @@ -81,11 +81,14 @@ class RequestBlockedByRule(Exception): class BasePostfetchProcessor(threading.Thread): logger = logging.getLogger("warcprox.BasePostfetchProcessor") - def __init__(self, options=Options()): + def __init__(self, options=Options(), controller=None, **kwargs): threading.Thread.__init__(self, name=self.__class__.__name__) self.options = options + self.controller = controller + self.stop = threading.Event() - # these should be set before thread is started + + # these should be set by the caller before thread is started self.inq = None self.outq = None self.profiler = None @@ -205,8 +208,8 @@ class BaseBatchPostfetchProcessor(BasePostfetchProcessor): raise Exception('not implemented') class ListenerPostfetchProcessor(BaseStandardPostfetchProcessor): - def __init__(self, listener, options=Options()): - BaseStandardPostfetchProcessor.__init__(self, options) + def __init__(self, listener, options=Options(), controller=None, **kwargs): + BaseStandardPostfetchProcessor.__init__(self, options, controller, **kwargs) self.listener = listener self.name = listener.__class__.__name__ diff --git a/warcprox/controller.py b/warcprox/controller.py index 0b3daef..fcdaa58 100644 --- a/warcprox/controller.py +++ b/warcprox/controller.py @@ -93,15 +93,19 @@ class Factory: return None @staticmethod - def plugin(qualname, options): + def plugin(qualname, options, controller=None): try: (module_name, class_name) = qualname.rsplit('.', 1) module_ = importlib.import_module(module_name) class_ = getattr(module_, class_name) - try: # new plugins take `options` argument - plugin = class_(options) - except: # backward-compatibility - plugin = class_() + try: + # new plugins take `options` and `controller` arguments + plugin = class_(options, controller) + except: + try: # medium plugins take `options` argument + plugin = class_(options) + except: # old plugins take no arguments + plugin = class_() # check that this is either a listener or a batch processor assert hasattr(plugin, 'notify') ^ hasattr(plugin, '_startup') return plugin @@ -229,7 +233,7 @@ class WarcproxController(object): crawl_logger, self.options)) for qualname in self.options.plugins or []: - plugin = Factory.plugin(qualname, self.options) + plugin = Factory.plugin(qualname, self.options, self) if hasattr(plugin, 'notify'): self._postfetch_chain.append( warcprox.ListenerPostfetchProcessor(