diff --git a/rethinkstuff/__init__.py b/rethinkstuff/__init__.py index d248c59..bcc985c 100644 --- a/rethinkstuff/__init__.py +++ b/rethinkstuff/__init__.py @@ -6,6 +6,7 @@ import time import types class RethinkerWrapper: + logger = logging.getLogger('rethinkstuff.RethinkerWrapper') def __init__(self, rethinker, wrapped): self.rethinker = rethinker self.wrapped = wrapped @@ -22,23 +23,32 @@ class RethinkerWrapper: try: result = self.wrapped.run(conn, db=db or self.rethinker.db) if hasattr(result, "__next__"): + is_iter = True def gen(): try: + yield # empty yield, see comment below for x in result: yield x finally: + self.logger.info("iterator finished, closing connection %s", conn) conn.close() - return gen() + g = gen() + # Start executing the generator, leaving off after the + # empty yield. If we didn't do this, and the caller never + # started the generator, the finally block would never run + # and the connection would stay open. + next(g) + return g else: return result except (r.ReqlAvailabilityError, r.ReqlTimeoutError) as e: pass finally: if not is_iter: + self.logger.info("closing connection %s", conn) conn.close(noreply_wait=False) - -class Rethinker: +class Rethinker(object): """ >>> r = Rethinker(db="my_db") >>> doc = r.table("my_table").get(1).run() diff --git a/tests/test_rethinker.py b/tests/test_rethinker.py index 34d8c02..f00b0d8 100644 --- a/tests/test_rethinker.py +++ b/tests/test_rethinker.py @@ -2,26 +2,64 @@ import rethinkstuff import logging import sys import types +import gc logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="%(asctime)s %(process)d %(levelname)s %(threadName)s %(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s") +class RethinkerForTesting(rethinkstuff.Rethinker): + def __init__(self, *args, **kwargs): + super(RethinkerForTesting, self).__init__(*args, **kwargs) + + def _random_server_connection(self): + self.last_conn = super(RethinkerForTesting, self)._random_server_connection() + logging.info("self.last_conn=%s", self.last_conn) + return self.last_conn + def test_rethinker(): - r = rethinkstuff.Rethinker() + r = RethinkerForTesting() result = r.db_create("my_db").run() + assert not r.last_conn.is_open() assert result["dbs_created"] == 1 - r = rethinkstuff.Rethinker(db="my_db") + r = RethinkerForTesting(db="my_db") assert r.table_list().run() == [] result = r.table_create("my_table").run() + assert not r.last_conn.is_open() assert result["tables_created"] == 1 assert r.table("my_table").index_create("foo").run() == {"created": 1} + assert not r.last_conn.is_open() result = r.table("my_table").insert(({"foo":i,"bar":"repeat"*i} for i in range(2000))).run() + assert not r.last_conn.is_open() assert len(result["generated_keys"]) == 2000 assert result["inserted"] == 2000 result = r.table("my_table").run() + assert r.last_conn.is_open() # should still be open this time assert isinstance(result, types.GeneratorType) + n = 0 + for x in result: + n += 1 + pass + # connection should be closed after finished iterating over results + assert not r.last_conn.is_open() + assert n == 2000 + + result = r.table("my_table").run() + assert r.last_conn.is_open() # should still be open this time + assert isinstance(result, types.GeneratorType) + next(result) + result = None + gc.collect() + # connection should be closed after result is garbage-collected + assert not r.last_conn.is_open() + + result = r.table("my_table").run() + assert r.last_conn.is_open() # should still be open this time + assert isinstance(result, types.GeneratorType) + result = None + gc.collect() + assert not r.last_conn.is_open()