diff --git a/pywb/rewrite/html_rewriter.py b/pywb/rewrite/html_rewriter.py index 9895ce2e..5a10d651 100644 --- a/pywb/rewrite/html_rewriter.py +++ b/pywb/rewrite/html_rewriter.py @@ -74,8 +74,6 @@ class HTMLRewriterMixin(object): self.url_rewriter = url_rewriter self._wb_parse_context = None - #self.out = outstream if outstream else self.AccumBuff() - self.out = self.AccumBuff() self.js_rewriter = js_rewriter_class(url_rewriter) self.css_rewriter = css_rewriter_class(url_rewriter) @@ -218,17 +216,32 @@ class HTMLRewriterMixin(object): self.out.write(data) def rewrite(self, string): - if not self.out: - self.out = self.AccumBuff() + self.out = self.AccumBuff() self.feed(string) result = self.out.getvalue() + # Clear buffer to create new one for next rewrite() self.out = None return result + def close(self): + self.out = self.AccumBuff() + + self._internal_close() + + result = self.out.getvalue() + + # Clear buffer to create new one for next rewrite() + self.out = None + + return result + + def _internal_close(self): + pass + #================================================================= class HTMLRewriter(HTMLRewriterMixin, HTMLParser): @@ -243,30 +256,23 @@ class HTMLRewriter(HTMLRewriterMixin, HTMLParser): js_rewriter_class, css_rewriter_class) - # HTMLParser overrides below def feed(self, string): try: HTMLParser.feed(self, string) except HTMLParseError: self.out.write(string) - def close(self): + def _internal_close(self): if (self._wb_parse_context): end_tag = '' - result = self.rewrite(end_tag) - if result.endswith(end_tag): - result = result[:-len(end_tag)] + self.feed(end_tag) self._wb_parse_context = None - else: - result = '' try: HTMLParser.close(self) except HTMLParseError: pass - return result - # called to unescape attrs -- do not unescape! def unescape(self, s): return s diff --git a/pywb/rewrite/lxml_html_rewriter.py b/pywb/rewrite/lxml_html_rewriter.py index b245d055..415334d3 100644 --- a/pywb/rewrite/lxml_html_rewriter.py +++ b/pywb/rewrite/lxml_html_rewriter.py @@ -36,19 +36,9 @@ class LXMLHTMLRewriter(HTMLRewriterMixin): #string = string.replace(u'', u'') self.parser.feed(string) - def close(self): - if not self.out: - self.out = self.AccumBuff() - - self.is_closing = True + def _internal_close(self): self.parser.close() - result = self.out.getvalue() - # Clear buffer to create new one for next rewrite() - self.out = None - - return result - #================================================================= class RewriterTarget(object): diff --git a/pywb/rewrite/test/test_html_rewriter.py b/pywb/rewrite/test/test_html_rewriter.py index ed117e9e..ac12c310 100644 --- a/pywb/rewrite/test/test_html_rewriter.py +++ b/pywb/rewrite/test/test_html_rewriter.py @@ -53,9 +53,9 @@ ur""" >>> parse('') -# Unterminated script tag, handle but don't auto-terminate +# Unterminated script tag, handle and auto-terminate >>> parse(' >>> parse('') @@ -66,9 +66,9 @@ ur""" >>> parse('') -# Unterminated style tag, handle but don't auto-terminate +# Unterminated style tag, handle and auto-terminate >>> parse(' # Head Insertion >>> parse('Test', head_insert = '')