Improvements for MongoDB:
* Use native document format (BSON) instead of binary blob
* Add option to use native TTL feature
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index a6b94ee..8ba68da 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -12,14 +12,21 @@
it is not quite as fast as :py:mod:`~requests_cache.backends.redis`, but may be preferable if you
already have a MongoDB instance you're using for other purposes, or if you find it easier to use.
+Expiration
+^^^^^^^^^^
+MongoDB natively supports TTL, and can automatically remove expired responses from the cache.
+Note that this is `not guaranteed to happen immediately
+<https://www.mongodb.com/docs/v4.0/core/index-ttl/#timing-of-the-delete-operation>`_. This is the
+recommended way to expire responses, and you can leave the session ``expire_after`` as the default
+(never expire). Example:
+
+ >>> backend = MongoCache(ttl=3600)
+ >>> session = CachedSession('http_cache', backend=backend)
+
Connection Options
^^^^^^^^^^^^^^^^^^
The MongoDB backend accepts any keyword arguments for :py:class:`pymongo.mongo_client.MongoClient`.
-These can be passed via :py:class:`.CachedSession`:
-
- >>> session = CachedSession('http_cache', backend='mongodb', host='192.168.1.63', port=27017)
-
-Or via :py:class:`.MongoCache`:
+These can be passed via :py:class:`.MongoCache`:
>>> backend = MongoCache(host='192.168.1.63', port=27017)
>>> session = CachedSession('http_cache', backend=backend)
@@ -35,6 +42,7 @@
from pymongo import MongoClient
from .._utils import get_valid_kwargs
+from ..serializers import dict_serializer
from . import BaseCache, BaseStorage
@@ -47,13 +55,26 @@
kwargs: Additional keyword arguments for :py:class:`pymongo.mongo_client.MongoClient`
"""
- def __init__(self, db_name: str = 'http_cache', connection: MongoClient = None, **kwargs):
+ def __init__(
+ self,
+ db_name: str = 'http_cache',
+ connection: MongoClient = None,
+ ttl: int = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
- self.responses = MongoPickleDict(db_name, 'responses', connection=connection, **kwargs)
+ self.responses = MongoPickleDict(
+ db_name,
+ collection_name='responses',
+ connection=connection,
+ ttl=ttl,
+ **kwargs,
+ )
self.redirects = MongoDict(
db_name,
collection_name='redirects',
connection=self.responses.connection,
+ ttl=ttl,
**kwargs,
)
@@ -68,11 +89,29 @@
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
- def __init__(self, db_name, collection_name='http_cache', connection=None, **kwargs):
+ def __init__(
+ self,
+ db_name: str,
+ collection_name: str = 'http_cache',
+ connection: MongoClient = None,
+ ttl: int = None,
+ **kwargs,
+ ):
super().__init__(**kwargs)
connection_kwargs = get_valid_kwargs(MongoClient, kwargs)
self.connection = connection or MongoClient(**connection_kwargs)
self.collection = self.connection[db_name][collection_name]
+ # Index will not be recreated if it already exists
+ # TODO: If TTL changes, drop and recreate index? Or just document that you need to manually
+ # call update_ttl()?
+ # TODO: Accept timedelta TTL
+ if ttl:
+ self.collection.create_index('created_at', expireAfterSeconds=ttl)
+
+ def update_ttl(self, ttl: int = None):
+ self.collection.drop_index('created_at')
+ if ttl:
+ self.collection.create_index('created_at', expireAfterSeconds=ttl)
def __getitem__(self, key):
result = self.collection.find_one({'_id': key})
@@ -105,7 +144,29 @@
class MongoPickleDict(MongoDict):
- """Same as :class:`MongoDict`, but pickles values before saving"""
+ """Same as :class:`MongoDict`, but serializes values before saving.
+
+ By default, responses are only partially serialized (unstructured into a dict), and stored as a
+ document.
+ """
+
+ def __init__(
+ self,
+ db_name: str,
+ collection_name: str = 'http_cache',
+ connection: MongoClient = None,
+ ttl: int = None,
+ serializer=None,
+ **kwargs,
+ ):
+ super().__init__(
+ db_name,
+ collection_name=collection_name,
+ connection=connection,
+ ttl=ttl,
+ serializer=serializer or dict_serializer,
+ **kwargs,
+ )
def __setitem__(self, key, item):
super().__setitem__(key, self.serializer.dumps(item))
diff --git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py
index 1e6e6c0..dc78489 100644
--- a/requests_cache/serializers/__init__.py
+++ b/requests_cache/serializers/__init__.py
@@ -29,6 +29,7 @@
SERIALIZERS = {
'bson': bson_serializer,
+ 'dict': dict_serializer,
'json': json_serializer,
'pickle': pickle_serializer,
'yaml': yaml_serializer,
diff --git a/requests_cache/serializers/preconf.py b/requests_cache/serializers/preconf.py
index cb099b8..1236c44 100644
--- a/requests_cache/serializers/preconf.py
+++ b/requests_cache/serializers/preconf.py
@@ -30,9 +30,8 @@
return get_placeholder_class(e)
+# Pre-serialization stages
base_stage = CattrStage() #: Base stage for all serializer pipelines
-dict_serializer = base_stage #: Partial serializer that unstructures responses into dicts
-pickle_serializer = SerializerPipeline([base_stage, pickle], is_binary=True) #: Pickle serializer
utf8_encoder = Stage(dumps=str.encode, loads=lambda x: x.decode()) #: Encode to bytes
bson_preconf_stage = make_stage('cattr.preconf.bson') #: Pre-serialization steps for BSON
json_preconf_stage = make_stage('cattr.preconf.json') #: Pre-serialization steps for JSON
@@ -42,6 +41,12 @@
ujson_preconf_stage = make_stage('cattr.preconf.ujson') #: Pre-serialization steps for ultrajson
yaml_preconf_stage = make_stage('cattr.preconf.pyyaml') #: Pre-serialization steps for YAML
+# Basic serializers with no additional dependencies
+dict_serializer = SerializerPipeline(
+ [base_stage], is_binary=False
+) #: Partial serializer that unstructures responses into dicts
+pickle_serializer = SerializerPipeline([base_stage, pickle], is_binary=True) #: Pickle serializer
+
# Safe pickle serializer
def signer_stage(secret_key=None, salt='requests-cache') -> Stage:
@@ -68,6 +73,7 @@
safe_pickle_serializer = get_placeholder_class(e)
+# BSON serializer
def _get_bson_functions():
"""Handle different function names between pymongo's bson and standalone bson"""
try:
@@ -78,7 +84,6 @@
return {'dumps': 'dumps', 'loads': 'loads'}
-# BSON serializer
try:
import bson
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index e0a2d06..5d87f0f 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -17,7 +17,13 @@
from requests_cache import ALL_METHODS, CachedResponse, CachedSession
from requests_cache.backends.base import BaseCache
-from requests_cache.serializers import SERIALIZERS, SerializerPipeline, safe_pickle_serializer
+from requests_cache.serializers import (
+ SERIALIZERS,
+ SerializerPipeline,
+ Stage,
+ dict_serializer,
+ safe_pickle_serializer,
+)
from tests.conftest import (
CACHE_NAME,
ETAG,
@@ -36,7 +42,8 @@
logger = getLogger(__name__)
-# Handle optional dependencies if they're not installed; if so, skips will be shown in pytest output
+# Handle optional dependencies that are not installed,
+# so any skipped tests will be explicitly shown in pytest output
TEST_SERIALIZERS = SERIALIZERS.copy()
try:
TEST_SERIALIZERS['safe_pickle'] = safe_pickle_serializer(secret_key='hunter2')
@@ -49,6 +56,7 @@
"""Base class for testing cache backend classes"""
backend_class: Type[BaseCache] = None
+ document_support: bool = False
init_kwargs: Dict = {}
def init_session(self, cache_name=CACHE_NAME, clear=True, **kwargs) -> CachedSession:
@@ -68,11 +76,13 @@
@pytest.mark.parametrize('method', HTTPBIN_METHODS)
@pytest.mark.parametrize('field', ['params', 'data', 'json'])
def test_all_methods(self, field, method, serializer):
- """Test all relevant combinations of methods X data fields X serializers.
+ """Test all relevant combinations of (methods X data fields X serializers).
Requests with different request params, data, or json should be cached under different keys.
"""
- if not isinstance(serializer, SerializerPipeline):
+ if not isinstance(serializer, (SerializerPipeline, Stage)):
pytest.skip(f'Dependencies not installed for {serializer}')
+ if serializer is dict_serializer and not self.document_support:
+ return
url = httpbin(method.lower())
session = self.init_session(serializer=serializer)
@@ -83,14 +93,16 @@
@pytest.mark.parametrize('serializer', TEST_SERIALIZERS.values())
@pytest.mark.parametrize('response_format', HTTPBIN_FORMATS)
def test_all_response_formats(self, response_format, serializer):
- """Test that all relevant combinations of response formats X serializers are cached correctly"""
+ """Test all relevant combinations of (response formats X serializers)"""
if not isinstance(serializer, SerializerPipeline):
pytest.skip(f'Dependencies not installed for {serializer}')
+ if serializer is dict_serializer and not self.document_support:
+ return
session = self.init_session(serializer=serializer)
# Workaround for this issue: https://github.com/kevin1024/pytest-httpbin/issues/60
if response_format == 'json' and USE_PYTEST_HTTPBIN:
- session.allowable_codes = (200, 404)
+ session.settings.allowable_codes = (200, 404)
r1 = session.get(httpbin(response_format))
r2 = session.get(httpbin(response_format))
@@ -116,8 +128,8 @@
@pytest.mark.parametrize('n_redirects', range(1, 5))
@pytest.mark.parametrize('endpoint', ['redirect', 'absolute-redirect', 'relative-redirect'])
def test_redirect_history(self, endpoint, n_redirects):
- """Test redirect caching (in separate `redirects` cache) with all types of redirect endpoints,
- using different numbers of consecutive redirects
+ """Test redirect caching (in separate `redirects` cache) with all types of redirect
+ endpoints, using different numbers of consecutive redirects
"""
session = self.init_session()
session.get(httpbin(f'{endpoint}/{n_redirects}'))
@@ -147,9 +159,11 @@
response_1 = get_json(httpbin('cookies/set/test1/test2'))
with session.cache_disabled():
assert get_json(httpbin('cookies')) == response_1
+
# From cache
response_2 = get_json(httpbin('cookies'))
assert response_2 == get_json(httpbin('cookies'))
+
# Not from cache
with session.cache_disabled():
response_3 = get_json(httpbin('cookies/set/test3/test4'))
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index fa5c9d6..776d494 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -1,7 +1,9 @@
"""Common tests to run for all backends (BaseStorage subclasses)"""
+from datetime import datetime
from typing import Dict, Type
import pytest
+from attrs import define, field
from requests_cache.backends import BaseStorage
from tests.conftest import CACHE_NAME
@@ -97,9 +99,19 @@
def test_picklable_dict(self):
if self.picklable:
cache = self.init_cache()
- cache['key_1'] = Picklable()
- assert cache['key_1'].attr_1 == 'value_1'
- assert cache['key_1'].attr_2 == 'value_2'
+ original_obj = BasicDataclass(
+ bool_attr=True,
+ datetime_attr=datetime(2022, 2, 2),
+ int_attr=2,
+ str_attr='value',
+ )
+ cache['key_1'] = original_obj
+
+ obj = cache['key_1']
+ assert obj.bool_attr == original_obj.bool_attr
+ assert obj.datetime_attr == original_obj.datetime_attr
+ assert obj.int_attr == original_obj.int_attr
+ assert obj.str_attr == original_obj.str_attr
def test_clear_and_work_again(self):
cache_1 = self.init_cache()
@@ -130,6 +142,9 @@
assert f'key_{i}' in str(cache)
-class Picklable:
- attr_1 = 'value_1'
- attr_2 = 'value_2'
+@define
+class BasicDataclass:
+ bool_attr: bool = field(default=None)
+ datetime_attr: datetime = field(default=None)
+ int_attr: int = field(default=None)
+ str_attr: str = field(default=None)
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index 2b55339..52cb24f 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -32,4 +32,5 @@
class TestDynamoDbCache(BaseCacheTest):
backend_class = DynamoDbCache
+ # document_support = True
init_kwargs = AWS_OPTIONS
diff --git a/tests/integration/test_filesystem.py b/tests/integration/test_filesystem.py
index 1b81dff..20af517 100644
--- a/tests/integration/test_filesystem.py
+++ b/tests/integration/test_filesystem.py
@@ -9,6 +9,9 @@
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest
+FILE_SERIALIZERS = SERIALIZERS.copy()
+FILE_SERIALIZERS.pop('dict')
+
class TestFileDict(BaseStorageTest):
storage_class = FileDict
@@ -52,9 +55,9 @@
backend_class = FileCache
init_kwargs = {'use_temp': True}
- @pytest.mark.parametrize('serializer_name', SERIALIZERS.keys())
+ @pytest.mark.parametrize('serializer_name', FILE_SERIALIZERS.keys())
def test_paths(self, serializer_name):
- if not isinstance(SERIALIZERS[serializer_name], SerializerPipeline):
+ if not isinstance(FILE_SERIALIZERS[serializer_name], SerializerPipeline):
pytest.skip(f'Dependencies not installed for {serializer_name}')
session = self.init_session(serializer=serializer_name)
diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py
index c798e1c..0db05ab 100644
--- a/tests/integration/test_mongodb.py
+++ b/tests/integration/test_mongodb.py
@@ -49,6 +49,7 @@
class TestMongoCache(BaseCacheTest):
backend_class = MongoCache
+ document_support = True
class TestGridFSPickleDict(BaseStorageTest):
@@ -86,3 +87,4 @@
class TestGridFSCache(BaseCacheTest):
backend_class = GridFSCache
+ document_support = False