make sure connection closes when it's supposed to

This commit is contained in:
Noah Levitt 2015-09-21 23:30:43 +00:00
parent aa080ce2e9
commit 50df439706
2 changed files with 53 additions and 5 deletions

View File

@ -6,6 +6,7 @@ import time
import types
class RethinkerWrapper:
logger = logging.getLogger('rethinkstuff.RethinkerWrapper')
def __init__(self, rethinker, wrapped):
self.rethinker = rethinker
self.wrapped = wrapped
@ -22,23 +23,32 @@ class RethinkerWrapper:
try:
result = self.wrapped.run(conn, db=db or self.rethinker.db)
if hasattr(result, "__next__"):
is_iter = True
def gen():
try:
yield # empty yield, see comment below
for x in result:
yield x
finally:
self.logger.info("iterator finished, closing connection %s", conn)
conn.close()
return gen()
g = gen()
# Start executing the generator, leaving off after the
# empty yield. If we didn't do this, and the caller never
# started the generator, the finally block would never run
# and the connection would stay open.
next(g)
return g
else:
return result
except (r.ReqlAvailabilityError, r.ReqlTimeoutError) as e:
pass
finally:
if not is_iter:
self.logger.info("closing connection %s", conn)
conn.close(noreply_wait=False)
class Rethinker:
class Rethinker(object):
"""
>>> r = Rethinker(db="my_db")
>>> doc = r.table("my_table").get(1).run()

View File

@ -2,26 +2,64 @@ import rethinkstuff
import logging
import sys
import types
import gc
logging.basicConfig(stream=sys.stderr, level=logging.INFO,
format="%(asctime)s %(process)d %(levelname)s %(threadName)s %(name)s.%(funcName)s(%(filename)s:%(lineno)d) %(message)s")
class RethinkerForTesting(rethinkstuff.Rethinker):
def __init__(self, *args, **kwargs):
super(RethinkerForTesting, self).__init__(*args, **kwargs)
def _random_server_connection(self):
self.last_conn = super(RethinkerForTesting, self)._random_server_connection()
logging.info("self.last_conn=%s", self.last_conn)
return self.last_conn
def test_rethinker():
r = rethinkstuff.Rethinker()
r = RethinkerForTesting()
result = r.db_create("my_db").run()
assert not r.last_conn.is_open()
assert result["dbs_created"] == 1
r = rethinkstuff.Rethinker(db="my_db")
r = RethinkerForTesting(db="my_db")
assert r.table_list().run() == []
result = r.table_create("my_table").run()
assert not r.last_conn.is_open()
assert result["tables_created"] == 1
assert r.table("my_table").index_create("foo").run() == {"created": 1}
assert not r.last_conn.is_open()
result = r.table("my_table").insert(({"foo":i,"bar":"repeat"*i} for i in range(2000))).run()
assert not r.last_conn.is_open()
assert len(result["generated_keys"]) == 2000
assert result["inserted"] == 2000
result = r.table("my_table").run()
assert r.last_conn.is_open() # should still be open this time
assert isinstance(result, types.GeneratorType)
n = 0
for x in result:
n += 1
pass
# connection should be closed after finished iterating over results
assert not r.last_conn.is_open()
assert n == 2000
result = r.table("my_table").run()
assert r.last_conn.is_open() # should still be open this time
assert isinstance(result, types.GeneratorType)
next(result)
result = None
gc.collect()
# connection should be closed after result is garbage-collected
assert not r.last_conn.is_open()
result = r.table("my_table").run()
assert r.last_conn.is_open() # should still be open this time
assert isinstance(result, types.GeneratorType)
result = None
gc.collect()
assert not r.last_conn.is_open()