Merge *PickleDict storage classes into parent classes
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py
index 7695b8f..9dd206a 100644
--- a/requests_cache/backends/__init__.py
+++ b/requests_cache/backends/__init__.py
@@ -15,35 +15,35 @@
# Import all backend classes for which dependencies are installed
try:
- from .dynamodb import DynamoDbCache, DynamoDbDict, DynamoDbDocumentDict
+ from .dynamodb import DynamoDbCache, DynamoDbDict
except ImportError as e:
- DynamoDbCache = DynamoDbDict = DynamoDbDocumentDict = get_placeholder_class(e) # type: ignore
+ DynamoDbCache = DynamoDbDict = get_placeholder_class(e) # type: ignore
+
try:
- from .gridfs import GridFSCache, GridFSPickleDict
+ from .gridfs import GridFSCache, GridFSDict
except ImportError as e:
- GridFSCache = GridFSPickleDict = get_placeholder_class(e) # type: ignore
+ GridFSCache = GridFSDict = get_placeholder_class(e) # type: ignore
+
try:
- from .mongodb import MongoCache, MongoDict, MongoDocumentDict
+ from .mongodb import MongoCache, MongoDict
except ImportError as e:
- MongoCache = MongoDict = MongoDocumentDict = get_placeholder_class(e) # type: ignore
+ MongoCache = MongoDict = get_placeholder_class(e) # type: ignore
+
try:
from .redis import RedisCache, RedisDict, RedisHashDict
except ImportError as e:
RedisCache = RedisDict = RedisHashDict = get_placeholder_class(e) # type: ignore
+
try:
- # Note: Heroku doesn't support SQLite due to ephemeral storage
- from .sqlite import SQLiteCache, SQLiteDict, SQLitePickleDict
+ from .sqlite import SQLiteCache, SQLiteDict
except ImportError as e:
- SQLiteCache = SQLiteDict = SQLitePickleDict = get_placeholder_class(e) # type: ignore
+ SQLiteCache = SQLiteDict = get_placeholder_class(e) # type: ignore
+
try:
from .filesystem import FileCache, FileDict
except ImportError as e:
FileCache = FileDict = get_placeholder_class(e) # type: ignore
-# Aliases for backwards-compatibility
-DbCache = SQLiteCache
-DbDict = SQLiteDict
-DbPickleDict = SQLitePickleDict
BACKEND_CLASSES = {
'dynamodb': DynamoDbCache,
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 0a551ea..e837671 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -263,8 +263,8 @@
kwargs: Additional backend-specific keyword arguments
"""
- def __init__(self, serializer=None, **kwargs):
- self.serializer = init_serializer(serializer)
+ def __init__(self, **kwargs):
+ self.serializer = init_serializer(kwargs.get('serializer', 'pickle'))
logger.debug(f'Initializing {type(self).__name__} with serializer: {self.serializer}')
def bulk_delete(self, keys: Iterable[str]):
@@ -281,6 +281,14 @@
def close(self):
"""Close any open backend connections"""
+ def serialize(self, value):
+ """Serialize value, if a serializer is available"""
+ return self.serializer.dumps(value) if self.serializer else value
+
+ def deserialize(self, value):
+ """Deserialize value, if a serializer is available"""
+ return self.serializer.loads(value) if self.serializer else value
+
def __str__(self):
return str(list(self.keys()))
@@ -297,7 +305,6 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._serializer = None
self.serializer = None
def __getitem__(self, key):
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index d6a5fdb..f39d267 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -19,6 +19,7 @@
class DynamoDbCache(BaseCache):
"""DynamoDB cache backend.
+ By default, responses are only partially serialized into a DynamoDB-compatible document format.
Args:
table_name: DynamoDB table name
@@ -34,14 +35,25 @@
table_name: str = 'http_cache',
ttl: bool = True,
connection: ServiceResource = None,
+ serializer=None,
**kwargs,
):
super().__init__(cache_name=table_name, **kwargs)
- self.responses = DynamoDbDocumentDict(
- table_name, 'responses', ttl=ttl, connection=connection, **kwargs
+ self.responses = DynamoDbDict(
+ table_name,
+ 'responses',
+ ttl=ttl,
+ serializer=serializer or dynamodb_document_serializer,
+ connection=connection,
+ **kwargs,
)
self.redirects = DynamoDbDict(
- table_name, 'redirects', ttl=False, connection=self.responses.connection, **kwargs
+ table_name,
+ 'redirects',
+ ttl=False,
+ connection=self.responses.connection,
+ serializer=None,
+ **kwargs,
)
@@ -132,14 +144,15 @@
# With a custom serializer, the value may be a Binary object
raw_value = result['Item']['value']
- return raw_value.value if isinstance(raw_value, Binary) else raw_value
+ value = raw_value.value if isinstance(raw_value, Binary) else raw_value
+ return self.deserialize(value)
def __setitem__(self, key, value):
- item = {**self._composite_key(key), 'value': value}
+ item = {**self._composite_key(key), 'value': self.serialize(value)}
# If enabled, set TTL value as a timestamp in unix format
if self.ttl and getattr(value, 'ttl', None):
- item['ttl'] = int(time() + value.ttl)
+ item['ttl'] = round(time() + value.ttl)
self._table.put_item(Item=item)
@@ -169,19 +182,3 @@
def clear(self):
self.bulk_delete((k for k in self))
-
-
-class DynamoDbDocumentDict(DynamoDbDict):
- """Same as :class:`DynamoDbDict`, but serializes values before saving.
-
- By default, responses are only partially serialized into a DynamoDB-compatible document format.
- """
-
- def __init__(self, *args, serializer=None, **kwargs):
- super().__init__(*args, serializer=serializer or dynamodb_document_serializer, **kwargs)
-
- def __getitem__(self, key):
- return self.serializer.loads(super().__getitem__(key))
-
- def __setitem__(self, key, item):
- super().__setitem__(key, self.serializer.dumps(item))
diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py
index 834b0ab..c1cb77e 100644
--- a/requests_cache/backends/filesystem.py
+++ b/requests_cache/backends/filesystem.py
@@ -91,7 +91,7 @@
mode = 'rb' if self.is_binary else 'r'
with self._try_io():
with self._path(key).open(mode) as f:
- return self.serializer.loads(f.read())
+ return self.deserialize(f.read())
def __delitem__(self, key):
with self._try_io():
@@ -100,7 +100,7 @@
def __setitem__(self, key, value):
with self._try_io():
with self._path(key).open(mode='wb' if self.is_binary else 'w') as f:
- f.write(self.serializer.dumps(value))
+ f.write(self.serialize(value))
def __iter__(self):
yield from self.keys()
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index eb139d5..dc66323 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -29,7 +29,7 @@
def __init__(self, db_name: str, **kwargs):
super().__init__(cache_name=db_name, **kwargs)
- self.responses = GridFSPickleDict(db_name, **kwargs)
+ self.responses = GridFSDict(db_name, **kwargs)
self.redirects = MongoDict(
db_name, collection_name='redirects', connection=self.responses.connection, **kwargs
)
@@ -39,7 +39,7 @@
return super().remove_expired_responses(*args, **kwargs)
-class GridFSPickleDict(BaseStorage):
+class GridFSDict(BaseStorage):
"""A dictionary-like interface for a GridFS database
Args:
@@ -63,13 +63,13 @@
result = self.fs.find_one({'_id': key})
if result is None:
raise KeyError
- return self.serializer.loads(result.read())
+ return self.deserialize(result.read())
except CorruptGridFile as e:
logger.warning(e, exc_info=True)
raise KeyError
def __setitem__(self, key, item):
- value = self.serializer.dumps(item)
+ value = self.serialize(item)
encoding = None if isinstance(value, bytes) else 'utf-8'
with self._lock:
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index d0fe79c..9c881be 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -21,6 +21,7 @@
class MongoCache(BaseCache):
"""MongoDB cache backend.
+ By default, responses are only partially serialized into a MongoDB-compatible document format.
Args:
db_name: Database name
@@ -28,18 +29,22 @@
kwargs: Additional keyword arguments for :py:class:`pymongo.mongo_client.MongoClient`
"""
- def __init__(self, db_name: str = 'http_cache', connection: MongoClient = None, **kwargs):
+ def __init__(
+ self, db_name: str = 'http_cache', connection: MongoClient = None, serializer=None, **kwargs
+ ):
super().__init__(cache_name=db_name, **kwargs)
- self.responses: MongoDict = MongoDocumentDict(
+ self.responses: MongoDict = MongoDict(
db_name,
collection_name='responses',
connection=connection,
+ serializer=serializer or bson_document_serializer,
**kwargs,
)
self.redirects: MongoDict = MongoDict(
db_name,
collection_name='redirects',
connection=self.responses.connection,
+ serializer=None,
**kwargs,
)
@@ -107,15 +112,17 @@
result = self.collection.find_one({'_id': key})
if result is None:
raise KeyError
- return result['data'] if 'data' in result else result
+ value = result['data'] if 'data' in result else result
+ return self.deserialize(value)
- def __setitem__(self, key, item):
- """If ``item`` is already a dict, its values will be stored under top-level keys.
+ def __setitem__(self, key, value):
+ """If ``value`` is already a dict, its values will be stored under top-level keys.
Otherwise, it will be stored under a 'data' key.
"""
- if not isinstance(item, Mapping):
- item = {'data': item}
- self.collection.replace_one({'_id': key}, item, upsert=True)
+ value = self.serialize(value)
+ if not isinstance(value, Mapping):
+ value = {'data': value}
+ self.collection.replace_one({'_id': key}, value, upsert=True)
def __delitem__(self, key):
result = self.collection.find_one_and_delete({'_id': key}, {'_id': True})
@@ -138,19 +145,3 @@
def close(self):
self.connection.close()
-
-
-class MongoDocumentDict(MongoDict):
- """Same as :class:`MongoDict`, but serializes values before saving.
-
- By default, responses are only partially serialized into a MongoDB-compatible document format.
- """
-
- def __init__(self, *args, serializer=None, **kwargs):
- super().__init__(*args, serializer=serializer or bson_document_serializer, **kwargs)
-
- def __getitem__(self, key):
- return self.serializer.loads(super().__getitem__(key))
-
- def __setitem__(self, key, item):
- super().__setitem__(key, self.serializer.dumps(item))
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index 0697023..76c83c1 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -75,14 +75,15 @@
result = self.connection.get(self._bkey(key))
if result is None:
raise KeyError
- return self.serializer.loads(result)
+ return self.deserialize(result)
def __setitem__(self, key, item):
"""Save an item to the cache, optionally with TTL"""
- if self.ttl and getattr(item, 'ttl', None):
- self.connection.setex(self._bkey(key), item.ttl, self.serializer.dumps(item))
+ ttl_seconds = getattr(item, 'ttl', None)
+ if self.ttl and ttl_seconds and ttl_seconds > 0:
+ self.connection.setex(self._bkey(key), round(ttl_seconds), self.serialize(item))
else:
- self.connection.set(self._bkey(key), self.serializer.dumps(item))
+ self.connection.set(self._bkey(key), self.serialize(item))
def __delitem__(self, key):
if not self.connection.delete(self._bkey(key)):
@@ -141,10 +142,10 @@
result = self.connection.hget(self._hash_key, encode(key))
if result is None:
raise KeyError
- return self.serializer.loads(result)
+ return self.deserialize(result)
def __setitem__(self, key, item):
- self.connection.hset(self._hash_key, encode(key), self.serializer.dumps(item))
+ self.connection.hset(self._hash_key, encode(key), self.serialize(item))
def __delitem__(self, key):
if not self.connection.hdel(self._hash_key, encode(key)):
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index eb9b712..cb15697 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -43,10 +43,14 @@
kwargs: Additional keyword arguments for :py:func:`sqlite3.connect`
"""
- def __init__(self, db_path: AnyPath = 'http_cache', **kwargs):
+ def __init__(self, db_path: AnyPath = 'http_cache', serializer=None, **kwargs):
super().__init__(cache_name=str(db_path), **kwargs)
- self.responses: SQLiteDict = SQLitePickleDict(db_path, table_name='responses', **kwargs)
- self.redirects: SQLiteDict = SQLiteDict(db_path, table_name='redirects', **kwargs)
+ self.responses: SQLiteDict = SQLiteDict(
+ db_path, table_name='responses', serializer=serializer or 'pickle', **kwargs
+ )
+ self.redirects: SQLiteDict = SQLiteDict(
+ db_path, table_name='redirects', serializer=None, **kwargs
+ )
@property
def db_path(self) -> AnyPath:
@@ -211,7 +215,8 @@
# raise error after the with block, otherwise the connection will be locked
if not row:
raise KeyError
- return row[0]
+
+ return self.deserialize(row[0])
def __setitem__(self, key, value):
self._insert(key, value)
@@ -288,36 +293,13 @@
f' ORDER BY {key} {direction} {limit_expr}',
params,
):
- yield row[0]
+ yield self.deserialize(row[0])
def vacuum(self):
with self.connection(commit=True) as con:
con.execute('VACUUM')
-class SQLitePickleDict(SQLiteDict):
- """Same as :class:`SQLiteDict`, but serializes values before saving"""
-
- def __setitem__(self, key, value: CachedResponse):
- serialized_value = self.serializer.dumps(value)
- if isinstance(serialized_value, bytes):
- serialized_value = sqlite3.Binary(serialized_value)
- super()._insert(key, serialized_value, getattr(value, 'expires', None))
-
- def __getitem__(self, key):
- return self.serializer.loads(super().__getitem__(key))
-
- def sorted(
- self,
- key: str = 'expires',
- reversed: bool = False,
- limit: int = None,
- exclude_expired: bool = False,
- ):
- for value in super().sorted(key, reversed, limit, exclude_expired):
- yield self.serializer.loads(value)
-
-
def _format_sequence(values: Collection) -> Tuple[str, List]:
"""Get SQL parameter marks for a sequence-based query"""
return ','.join(['?'] * len(values)), list(values)
@@ -372,9 +354,3 @@
uri: bool = False,
):
"""Template function to get an accurate signature for the builtin :py:func:`sqlite3.connect`"""
-
-
-# Aliases for backwards-compatibility
-DbCache = SQLiteCache
-DbDict = SQLiteDict
-DbPickeDict = SQLitePickleDict
diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py
index c667856..fe6e4da 100755
--- a/requests_cache/models/response.py
+++ b/requests_cache/models/response.py
@@ -134,12 +134,12 @@
return self.expires is not None and datetime.utcnow() >= self.expires
@property
- def ttl(self) -> Optional[int]:
- """Get time to expiration in seconds"""
- if self.expires is None or self.is_expired:
+ def ttl(self) -> Optional[float]:
+ """Get time to expiration in seconds (rounded to the nearest second)"""
+ if self.expires is None:
return None
delta = self.expires - datetime.utcnow()
- return int(delta.total_seconds())
+ return delta.total_seconds()
@property
def next(self) -> Optional[PreparedRequest]:
diff --git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py
index d49545a..be2a338 100644
--- a/requests_cache/serializers/__init__.py
+++ b/requests_cache/serializers/__init__.py
@@ -9,6 +9,7 @@
dict_serializer,
dynamodb_document_serializer,
json_serializer,
+ no_op_serializer,
pickle_serializer,
safe_pickle_serializer,
utf8_encoder,
@@ -25,6 +26,7 @@
'dynamodb_document_serializer',
'dict_serializer',
'json_serializer',
+ 'no_op_serializer',
'pickle_serializer',
'safe_pickle_serializer',
'yaml_serializer',
@@ -40,9 +42,8 @@
}
-def init_serializer(serializer=None):
+def init_serializer(serializer):
"""Initialize a serializer from a name or instance"""
- serializer = serializer or 'pickle'
if isinstance(serializer, str):
serializer = SERIALIZERS[serializer]
return serializer
diff --git a/requests_cache/serializers/preconf.py b/requests_cache/serializers/preconf.py
index 95fef1f..1d7d70f 100644
--- a/requests_cache/serializers/preconf.py
+++ b/requests_cache/serializers/preconf.py
@@ -55,7 +55,7 @@
pickle_serializer = SerializerPipeline(
[base_stage, pickle], name='pickle', is_binary=True
) #: Pickle serializer
-
+no_op_serializer = SerializerPipeline([], name='no_op') #: Placeholder serializer that does nothing
# Safe pickle serializer
def signer_stage(secret_key=None, salt='requests-cache') -> Stage:
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 081b5a1..6d71018 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -44,6 +44,7 @@
# Handle optional dependencies if they're not installed,
# so any skipped tests will explicitly be shown in pytest output
TEST_SERIALIZERS = SERIALIZERS.copy()
+TEST_SERIALIZERS['no_op'] = None
try:
TEST_SERIALIZERS['safe_pickle'] = safe_pickle_serializer(secret_key='hunter2')
except ImportError:
@@ -51,6 +52,10 @@
VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
+def _valid_serializer(serializer) -> bool:
+ return isinstance(serializer, (SerializerPipeline, Stage)) or serializer is None
+
+
class BaseCacheTest:
"""Base class for testing cache backend classes"""
@@ -79,7 +84,7 @@
"""Test all relevant combinations of (methods X data fields X serializers).
Requests with different request params, data, or json should be cached under different keys.
"""
- if not isinstance(serializer, (SerializerPipeline, Stage)):
+ if not _valid_serializer(serializer):
pytest.skip(f'Dependencies not installed for {serializer}')
url = httpbin(method.lower())
@@ -92,7 +97,7 @@
@pytest.mark.parametrize('response_format', HTTPBIN_FORMATS)
def test_all_response_formats(self, response_format, serializer):
"""Test all relevant combinations of (response formats X serializers)"""
- if not isinstance(serializer, SerializerPipeline):
+ if not _valid_serializer(serializer):
pytest.skip(f'Dependencies not installed for {serializer}')
session = self.init_session(serializer=serializer)
@@ -311,6 +316,7 @@
# Cache a response and some redirects, which should be the only non-expired cache items
session.get(httpbin('get'), expire_after=-1)
session.get(httpbin('redirect/3'), expire_after=-1)
+ assert len(session.cache.redirects.keys()) == 4
session.cache.remove_expired_responses()
assert len(session.cache.responses.keys()) == 2
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index 7d38643..ee3bf31 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -14,7 +14,6 @@
storage_class: Type[BaseStorage] = None
init_kwargs: Dict = {}
- picklable: bool = False
num_instances: int = 10 # Max number of cache instances to test
def init_cache(self, cache_name=CACHE_NAME, index=0, clear=True, **kwargs):
@@ -98,21 +97,20 @@
cache['key']
def test_picklable_dict(self):
- if self.picklable:
- cache = self.init_cache(serializer='pickle')
- original_obj = BasicDataclass(
- bool_attr=True,
- datetime_attr=datetime(2022, 2, 2),
- int_attr=2,
- str_attr='value',
- )
- cache['key_1'] = original_obj
+ cache = self.init_cache(serializer='pickle')
+ original_obj = BasicDataclass(
+ bool_attr=True,
+ datetime_attr=datetime(2022, 2, 2),
+ int_attr=2,
+ str_attr='value',
+ )
+ cache['key_1'] = original_obj
- obj = cache['key_1']
- assert obj.bool_attr == original_obj.bool_attr
- assert obj.datetime_attr == original_obj.datetime_attr
- assert obj.int_attr == original_obj.int_attr
- assert obj.str_attr == original_obj.str_attr
+ obj = cache['key_1']
+ assert obj.bool_attr == original_obj.bool_attr
+ assert obj.datetime_attr == original_obj.datetime_attr
+ assert obj.int_attr == original_obj.int_attr
+ assert obj.str_attr == original_obj.str_attr
def test_clear_and_work_again(self):
cache_1 = self.init_cache()
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index 84c1008..59f8ae1 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -5,7 +5,7 @@
import pytest
from botocore.exceptions import ClientError
-from requests_cache.backends import DynamoDbCache, DynamoDbDict, DynamoDbDocumentDict
+from requests_cache.backends import DynamoDbCache, DynamoDbDict
from requests_cache.serializers import dynamodb_document_serializer
from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, fail_if_no_connection
from tests.integration.base_cache_test import TEST_SERIALIZERS, BaseCacheTest
@@ -86,12 +86,6 @@
assert ttl_value is None
-class TestDynamoDbDocumentDict(BaseStorageTest):
- storage_class = DynamoDbDocumentDict
- init_kwargs = DYNAMODB_OPTIONS
- picklable = True
-
-
class TestDynamoDbCache(BaseCacheTest):
backend_class = DynamoDbCache
init_kwargs = DYNAMODB_OPTIONS
diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py
index 262ddd6..5ff6ae9 100644
--- a/tests/integration/test_mongodb.py
+++ b/tests/integration/test_mongodb.py
@@ -6,14 +6,8 @@
from gridfs import GridFS
from gridfs.errors import CorruptGridFile, FileExists
-from requests_cache.backends import (
- GridFSCache,
- GridFSPickleDict,
- MongoCache,
- MongoDict,
- MongoDocumentDict,
-)
-from requests_cache.policy.expiration import NEVER_EXPIRE
+from requests_cache.backends import GridFSCache, GridFSDict, MongoCache, MongoDict
+from requests_cache.policy import NEVER_EXPIRE
from requests_cache.serializers import bson_document_serializer
from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, fail_if_no_connection, httpbin
from tests.integration.base_cache_test import TEST_SERIALIZERS, BaseCacheTest
@@ -37,11 +31,6 @@
class TestMongoDict(BaseStorageTest):
storage_class = MongoDict
-
-class TestMongoPickleDict(BaseStorageTest):
- storage_class = MongoDocumentDict
- picklable = True
-
def test_connection_kwargs(self):
"""A spot check to make sure optional connection kwargs gets passed to connection"""
# MongoClient prevents direct access to private members like __init_kwargs;
@@ -60,7 +49,6 @@
class TestMongoCache(BaseCacheTest):
backend_class = MongoCache
-
init_kwargs = {'serializer': None} # Use class default serializer instead of pickle
@pytest.mark.parametrize('serializer', MONGODB_SERIALIZERS)
@@ -113,14 +101,14 @@
assert session.cache.get_ttl() is None
-class TestGridFSPickleDict(BaseStorageTest):
- storage_class = GridFSPickleDict
+class TestGridFSDict(BaseStorageTest):
+ storage_class = GridFSDict
picklable = True
num_instances = 1 # Only test a single collecton instead of multiple
def test_connection_kwargs(self):
"""A spot check to make sure optional connection kwargs gets passed to connection"""
- cache = GridFSPickleDict(
+ cache = GridFSDict(
'test',
host='mongodb://0.0.0.0',
port=2222,
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 84b399a..b696236 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -9,13 +9,14 @@
from platformdirs import user_cache_dir
from requests_cache.backends.base import BaseCache
-from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict, SQLitePickleDict
+from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict
from requests_cache.models.response import CachedResponse
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest
-class SQLiteTestCase(BaseStorageTest):
+class TestSQLiteDict(BaseStorageTest):
+ storage_class = SQLiteDict
init_kwargs = {'use_temp': True}
@classmethod
@@ -178,15 +179,6 @@
with pytest.raises(ValueError):
list(cache.sorted(key='invalid_key'))
-
-class TestSQLiteDict(SQLiteTestCase):
- storage_class = SQLiteDict
-
-
-class TestSQLitePickleDict(SQLiteTestCase):
- storage_class = SQLitePickleDict
- picklable = True
-
@pytest.mark.parametrize('limit', [None, 50])
def test_sorted__by_expires(self, limit):
cache = self.init_cache()
@@ -213,7 +205,7 @@
for i in range(100):
delta = 101 - i
if i % 2 == 1:
- delta -= 100
+ delta -= 101
response = CachedResponse(status_code=i, expires=now + timedelta(seconds=delta))
cache[f'key_{i}'] = response
@@ -227,6 +219,13 @@
assert prev_item is None or prev_item.expires < item.expires
assert item.status_code % 2 == 0
+ def test_filesize(self):
+ """Test approximate expected size of database file, in bytes"""
+ cache = self.init_cache()
+ for i in range(100):
+ cache[f'key_{i}'] = f'value_{i}'
+ assert 50000 < cache.filesize() < 200000
+
class TestSQLiteCache(BaseCacheTest):
backend_class = SQLiteCache
diff --git a/tests/unit/test_base_cache.py b/tests/unit/test_base_cache.py
index 0b87411..a39b238 100644
--- a/tests/unit/test_base_cache.py
+++ b/tests/unit/test_base_cache.py
@@ -5,7 +5,7 @@
import pytest
from requests_cache import CachedResponse
-from requests_cache.backends import BaseCache, SQLitePickleDict
+from requests_cache.backends import BaseCache, SQLiteDict
from tests.conftest import MOCKED_URL, MOCKED_URL_HTTPS, MOCKED_URL_JSON, MOCKED_URL_REDIRECT
YESTERDAY = datetime.utcnow() - timedelta(days=1)
@@ -24,7 +24,7 @@
def test_urls__with_invalid_response(mock_session):
responses = [mock_session.get(url) for url in [MOCKED_URL, MOCKED_URL_JSON, MOCKED_URL_HTTPS]]
responses[2] = AttributeError
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=responses):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=responses):
expected_urls = [MOCKED_URL, MOCKED_URL_JSON]
assert set(mock_session.cache.urls) == set(expected_urls)
@@ -67,7 +67,7 @@
responses[1] = AttributeError
responses[2] = CachedResponse(expires=YESTERDAY, url='test')
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=responses):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=responses):
values = mock_session.cache.values(check_expiry=check_expiry)
assert len(list(values)) == expected_count
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index e64a457..12411eb 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -3,6 +3,7 @@
import time
from collections import UserDict, defaultdict
from datetime import datetime, timedelta
+from logging import getLogger
from pathlib import Path
from pickle import PickleError
from unittest.mock import patch
@@ -15,7 +16,7 @@
from requests_cache import ALL_METHODS, CachedSession
from requests_cache._utils import get_placeholder_class
-from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict, SQLitePickleDict
+from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict
from requests_cache.backends.base import DESERIALIZE_ERRORS
from requests_cache.policy.expiration import DO_NOT_CACHE, EXPIRE_IMMEDIATELY, NEVER_EXPIRE
from tests.conftest import (
@@ -32,6 +33,7 @@
# Some tests must disable url normalization to retain the custom `http+mock//` protocol
patch_normalize_url = patch('requests_cache.cache_keys.normalize_url', side_effect=lambda x, y: x)
+logger = getLogger(__name__)
# Basic initialization
# -----------------------------------------------------
@@ -421,7 +423,7 @@
"""If there is an error during deserialization, the request should be made again"""
assert mock_session.get(MOCKED_URL_JSON).from_cache is False
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=PickleError):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=PickleError):
resp = mock_session.get(MOCKED_URL_JSON)
assert resp.from_cache is False
assert resp.json()['message'] == 'mock json response'
@@ -649,7 +651,7 @@
raise PickleError
return response_1
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=error_on_key):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=error_on_key):
BaseCache.remove_expired_responses(mock_session.cache)
assert len(mock_session.cache.responses) == 1
assert mock_session.get(MOCKED_URL).from_cache is True
@@ -695,7 +697,7 @@
# All 3 responses should still be cached
mock_session.remove_expired_responses()
for response in mock_session.cache.responses.values():
- print('Expires:', response.expires - datetime.utcnow() if response.expires else None)
+ logger.info(f'Expires in {response.ttl} seconds')
assert len(mock_session.cache.responses) == 3
# One should be expired after 1s, and another should be expired after 2s