Close database connections (if applicable) on CachedSession.__exit__ and close()
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 250b5e1..0a551ea 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -100,6 +100,12 @@
self.responses.clear()
self.redirects.clear()
+ def close(self):
+ """Close any open backend connections"""
+ logger.debug('Closing backend connections')
+ self.responses.close()
+ self.redirects.close()
+
def create_key(self, request: PreparedRequest = None, **kwargs) -> str:
"""Create a normalized cache key from a request object"""
key_fn = self._settings.key_fn or create_key
@@ -272,6 +278,9 @@
except KeyError:
pass
+ def close(self):
+ """Close any open backend connections"""
+
def __str__(self):
return str(list(self.keys()))
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index 7623319..6d4175d 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -136,6 +136,9 @@
def clear(self):
self.collection.drop()
+ def close(self):
+ self.connection.close()
+
class MongoPickleDict(MongoDict):
"""Same as :class:`MongoDict`, but serializes values before saving.
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index d79af35..5836714 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -102,6 +102,9 @@
def clear(self):
self.bulk_delete(self.keys())
+ def close(self):
+ self.connection.close()
+
def keys(self):
return [
decode(key).replace(f'{self.namespace}:', '')
diff --git a/requests_cache/session.py b/requests_cache/session.py
index 68d0c68..9f46d0e 100644
--- a/requests_cache/session.py
+++ b/requests_cache/session.py
@@ -292,6 +292,11 @@
finally:
self.settings.disabled = False
+ def close(self):
+ """Close the session and any open backend connections"""
+ super().close()
+ self.cache.close()
+
def remove_expired_responses(self, expire_after: ExpirationTime = None):
"""Remove expired responses from the cache, optionally with revalidation
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 6381a42..0ec1d51 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -71,7 +71,8 @@
@classmethod
def teardown_class(cls):
- cls().init_session(clear=True)
+ session = cls().init_session(clear=True)
+ session.close()
@pytest.mark.parametrize('serializer', TEST_SERIALIZERS.values())
@pytest.mark.parametrize('method', HTTPBIN_METHODS)