Merge pull request #610 from JWCook/merge-storage-classes
Merge serializer-specific storage classes into respective parent classes
diff --git a/HISTORY.md b/HISTORY.md
index 1edc48b..d8e6950 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -23,6 +23,7 @@
**Backends:**
* SQLite:
* Improve performance for removing expired items
+ * Add `size()` method to get estimated size of the database (including in-memory databases)
* Add `sorted()` method with sorting and other query options
* Add `wal` parameter to enable write-ahead logging
* MongoDB:
@@ -33,8 +34,15 @@
* Create default table in on-demand mode instead of provisioned
* Add optional integration with DynamoDB TTL to improve performance for removing expired responses
* This is enabled by default, but may be disabled
+* Filesystem:
+ * The default file format has been changed from pickle to JSON
* SQLite, Redis, MongoDB, and GridFS: Close open database connections when `CachedSession` is used as a contextmanager, or if `CachedSession.close()` is called
+**Request Matching & Filtering:**
+* Add serializer name to cache keys to avoid errors due to switching serializers
+* Always skip both cache read and write for requests excluded by `allowable_methods` (previously only skipped write)
+* Ignore and redact common authentication headers and request parameters by default. This provides recommended default values for `ignored_parameters`, to avoid accidentally storing common credentials (e.g., OAuth tokens) in the cache. This has no effect if you are already setting `ignored_parameters`.
+
**Type hints:**
* Add `OriginalResponse` type, which adds type hints to `requests.Response` objects for extra attributes added by requests-cache:
* `cache_key`
@@ -45,29 +53,28 @@
* `OriginalResponse.cache_key` and `expires` will be populated for any new response that was written to the cache
* Add request wrapper methods with return type hints for all HTTP methods (`CachedSession.get()`, `head()`, etc.)
-**Request Matching & Filtering:**
-* Add serializer name to cache keys to avoid errors due to switching serializers
-* Always skip both cache read and write for requests excluded by `allowable_methods` (previously only skipped write)
-* Ignore and redact common authentication params and headers (e.g., for OAuth2) by default
- * This is simply a default value for `ignored_parameters`, to avoid accidentally storing credentials in the cache
-
**Dependencies:**
* Replace `appdirs` with `platformdirs`
-**Potentially breaking changes:**
+**Breaking changes:**
+Some relatively minor breaking changes have been made that are not expected to affect most users.
+If you encounter a problem not listed here after updating, please file a bug report!
+
The following undocumented behaviors have been removed:
-* The arguments `match_headers` and `ignored_parameters` must be passed to `CachedSession`.
- * Previously, these could also be passed to a `BaseCache` instance.
-* The `CachedSession` `backend` argument must be either an instance or string alias.
- * Previously it would also accept a backend class.
+* The arguments `match_headers` and `ignored_parameters` must be passed to `CachedSession`. Previously, these could also be passed to a `BaseCache` instance.
+* The `CachedSession` `backend` argument must be either an instance or string alias. Previously it would also accept a backend class.
* After initialization, cache settings can only be accessed and modified via
- `CachedSession.settings`.
- * Previously, some settings could be modified by setting them on either `CachedSession` or `BaseCache`. In some cases this could silently fail or otherwise have undefined behavior.
+ `CachedSession.settings`. Previously, some settings could be modified by setting them on either `CachedSession` or `BaseCache`. In some cases this could silently fail or otherwise have undefined behavior.
-Internal module changes:
-* The contents of the `cache_control` module have been split up into multiple modules in a new `policy` subpackage
+The following changes are relevant for users who have created custom backends that extend built-in storage classes:
+* All `BaseStorage` subclasses now have a `serializer` attribute, which will be unused if
+ set to `None`.
+* All serializer-specific `BaseStorage` subclasses have been removed, and merged into their respective parent classes. This includes `SQLitePickleDict`, `MongoPickleDict`, and `GridFSPickleDict`.
-## 0.9.4 (2022-04-21)
+Internal utility module changes:
+* The `cache_control` module (added in `0.7`) has been split up into multiple modules in a new `policy` subpackage
+
+### 0.9.4 (2022-04-22)
* Fix forwarding connection parameters passed to `RedisCache` for redis-py 4.2 and python <=3.8
* Fix forwarding connection parameters passed to `MongoCache` for pymongo 4.1 and python <=3.8
diff --git a/docs/user_guide/backends/filesystem.md b/docs/user_guide/backends/filesystem.md
index 9fdee7c..07e41c4 100644
--- a/docs/user_guide/backends/filesystem.md
+++ b/docs/user_guide/backends/filesystem.md
@@ -26,17 +26,9 @@
```
## File Formats
-By default, responses are saved as pickle files. If you want to save responses in a human-readable
-format, you can use one of the other available {ref}`serializers`. For example, to save responses as
-JSON files:
-```python
->>> session = CachedSession('~/http_cache', backend='filesystem', serializer='json')
->>> session.get('https://httpbin.org/get')
->>> print(list(session.cache.paths()))
-> ['/home/user/http_cache/4dc151d95200ec.json']
-```
-
-Or as YAML (requires `pyyaml`):
+By default, responses are saved as JSON files. If you prefer a different format, you can use one of the
+other available {ref}`serializers` or provide your own. For example, to save responses as
+YAML files (requires `pyyaml`):
```python
>>> session = CachedSession('~/http_cache', backend='filesystem', serializer='yaml')
>>> session.get('https://httpbin.org/get')
diff --git a/docs/user_guide/backends/redis.md b/docs/user_guide/backends/redis.md
index cfe1898..141834c 100644
--- a/docs/user_guide/backends/redis.md
+++ b/docs/user_guide/backends/redis.md
@@ -47,6 +47,9 @@
Redis natively supports TTL on a per-key basis, and can automatically remove expired responses from
the cache. This will be set by default, according to normal {ref}`expiration settings <expiration>`.
+Expired items are not removed immediately, but will never be returned from the cache. See
+[Redis: EXPIRE](https://redis.io/commands/expire/) docs for more details.
+
If you intend to reuse expired responses, e.g. with {ref}`conditional-requests` or `stale_if_error`,
you can disable this behavior with the `ttl` argument:
```python
diff --git a/docs/user_guide/expiration.md b/docs/user_guide/expiration.md
index d0465f1..a53da87 100644
--- a/docs/user_guide/expiration.md
+++ b/docs/user_guide/expiration.md
@@ -142,15 +142,13 @@
>>> session.remove_expired_responses(expire_after=timedelta(days=30))
```
+(ttl)=
### Automatic Removal
The following backends have native TTL support, which can be used to automatically remove expired
responses:
+* {py:mod}`DynamoDB <requests_cache.backends.dynamodb>`
* {py:mod}`MongoDB <requests_cache.backends.mongodb>`
* {py:mod}`Redis <requests_cache.backends.redis>`
-<!--
-TODO: Not yet supported:
-* {py:mod}`DynamoDB <requests_cache.backends.dynamodb>`
--->
## Request Options
In addition to the base arguments for {py:func}`requests.request`, requests-cache adds some extra
diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py
index 7695b8f..9dd206a 100644
--- a/requests_cache/backends/__init__.py
+++ b/requests_cache/backends/__init__.py
@@ -15,35 +15,35 @@
# Import all backend classes for which dependencies are installed
try:
- from .dynamodb import DynamoDbCache, DynamoDbDict, DynamoDbDocumentDict
+ from .dynamodb import DynamoDbCache, DynamoDbDict
except ImportError as e:
- DynamoDbCache = DynamoDbDict = DynamoDbDocumentDict = get_placeholder_class(e) # type: ignore
+ DynamoDbCache = DynamoDbDict = get_placeholder_class(e) # type: ignore
+
try:
- from .gridfs import GridFSCache, GridFSPickleDict
+ from .gridfs import GridFSCache, GridFSDict
except ImportError as e:
- GridFSCache = GridFSPickleDict = get_placeholder_class(e) # type: ignore
+ GridFSCache = GridFSDict = get_placeholder_class(e) # type: ignore
+
try:
- from .mongodb import MongoCache, MongoDict, MongoDocumentDict
+ from .mongodb import MongoCache, MongoDict
except ImportError as e:
- MongoCache = MongoDict = MongoDocumentDict = get_placeholder_class(e) # type: ignore
+ MongoCache = MongoDict = get_placeholder_class(e) # type: ignore
+
try:
from .redis import RedisCache, RedisDict, RedisHashDict
except ImportError as e:
RedisCache = RedisDict = RedisHashDict = get_placeholder_class(e) # type: ignore
+
try:
- # Note: Heroku doesn't support SQLite due to ephemeral storage
- from .sqlite import SQLiteCache, SQLiteDict, SQLitePickleDict
+ from .sqlite import SQLiteCache, SQLiteDict
except ImportError as e:
- SQLiteCache = SQLiteDict = SQLitePickleDict = get_placeholder_class(e) # type: ignore
+ SQLiteCache = SQLiteDict = get_placeholder_class(e) # type: ignore
+
try:
from .filesystem import FileCache, FileDict
except ImportError as e:
FileCache = FileDict = get_placeholder_class(e) # type: ignore
-# Aliases for backwards-compatibility
-DbCache = SQLiteCache
-DbDict = SQLiteDict
-DbPickleDict = SQLitePickleDict
BACKEND_CLASSES = {
'dynamodb': DynamoDbCache,
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 0a551ea..a1c6863 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -20,7 +20,7 @@
from ..models import CachedResponse
from ..policy.expiration import ExpirationTime
from ..policy.settings import DEFAULT_CACHE_NAME, CacheSettings
-from ..serializers import init_serializer
+from ..serializers import SerializerType, init_serializer, pickle_serializer
# Specific exceptions that may be raised during deserialization
DESERIALIZE_ERRORS = (AttributeError, ImportError, PickleError, TypeError, ValueError)
@@ -239,10 +239,10 @@
"""Show a count of total **rows** currently stored in the backend. For performance reasons,
this does not check for invalid or expired responses.
"""
- return f'Total rows: {len(self.responses)} responses, {len(self.redirects)} redirects'
+ return f'<{self.__class__.__name__}(name={self.cache_name})>'
def __repr__(self):
- return f'<{self.__class__.__name__}(name={self.cache_name})>'
+ return str(self)
class BaseStorage(MutableMapping, ABC):
@@ -260,12 +260,20 @@
Args:
serializer: Custom serializer that provides ``loads`` and ``dumps`` methods
+ no_serializer: Explicitly disable serialization, and write values as-is; this is to avoid
+ ambiguity with ``serializer=None``
kwargs: Additional backend-specific keyword arguments
"""
- def __init__(self, serializer=None, **kwargs):
- self.serializer = init_serializer(serializer)
- logger.debug(f'Initializing {type(self).__name__} with serializer: {self.serializer}')
+ # Default serializer to use for responses, if one isn't specified; may be overridden
+ default_serializer: SerializerType = pickle_serializer
+
+ def __init__(self, serializer: SerializerType = None, no_serializer: bool = False, **kwargs):
+ if no_serializer:
+ self.serializer = None
+ else:
+ self.serializer = init_serializer(serializer or self.default_serializer)
+ logger.debug(f'Initialized {type(self).__name__} with serializer: {self.serializer}')
def bulk_delete(self, keys: Iterable[str]):
"""Delete multiple keys from the cache, without raising errors for missing keys. This is a
@@ -281,6 +289,14 @@
def close(self):
"""Close any open backend connections"""
+ def serialize(self, value):
+ """Serialize value, if a serializer is available"""
+ return self.serializer.dumps(value) if self.serializer else value
+
+ def deserialize(self, value):
+ """Deserialize value, if a serializer is available"""
+ return self.serializer.loads(value) if self.serializer else value
+
def __str__(self):
return str(list(self.keys()))
@@ -297,7 +313,6 @@
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._serializer = None
self.serializer = None
def __getitem__(self, key):
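The optional-serializer contract added to `BaseStorage` above (`serialize()`/`deserialize()` pass values through unchanged when no serializer is set) can be summarized with a small standalone sketch. `MemoryStorage` is a hypothetical stand-in, not part of the library:

```python
import pickle


class MemoryStorage(dict):
    """Minimal sketch of the BaseStorage serializer contract: values pass
    through unchanged when no serializer is configured."""

    def __init__(self, serializer=pickle, no_serializer=False):
        super().__init__()
        # no_serializer=True disambiguates from serializer=None, which
        # would otherwise fall back to a default serializer
        self.serializer = None if no_serializer else serializer

    def serialize(self, value):
        return self.serializer.dumps(value) if self.serializer else value

    def deserialize(self, value):
        return self.serializer.loads(value) if self.serializer else value

    def __setitem__(self, key, value):
        super().__setitem__(key, self.serialize(value))

    def __getitem__(self, key):
        return self.deserialize(super().__getitem__(key))
```

This is why the redirects mappings in this PR are created with `no_serializer=True`: they only store short cache-key strings, which need no serialization.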
diff --git a/requests_cache/backends/dynamodb.py b/requests_cache/backends/dynamodb.py
index d6a5fdb..11776d1 100644
--- a/requests_cache/backends/dynamodb.py
+++ b/requests_cache/backends/dynamodb.py
@@ -4,7 +4,6 @@
:classes-only:
:nosignatures:
"""
-from time import time
from typing import Dict, Iterable
import boto3
@@ -19,6 +18,7 @@
class DynamoDbCache(BaseCache):
"""DynamoDB cache backend.
+ By default, responses are only partially serialized into a DynamoDB-compatible document format.
Args:
table_name: DynamoDB table name
@@ -37,11 +37,20 @@
**kwargs,
):
super().__init__(cache_name=table_name, **kwargs)
- self.responses = DynamoDbDocumentDict(
- table_name, 'responses', ttl=ttl, connection=connection, **kwargs
+ self.responses = DynamoDbDict(
+ table_name,
+ namespace='responses',
+ ttl=ttl,
+ connection=connection,
+ **kwargs,
)
self.redirects = DynamoDbDict(
- table_name, 'redirects', ttl=False, connection=self.responses.connection, **kwargs
+ table_name,
+ namespace='redirects',
+ ttl=False,
+ connection=self.responses.connection,
+ no_serializer=True,
+ **kwargs,
)
@@ -57,6 +66,8 @@
kwargs: Additional keyword arguments for :py:meth:`~boto3.session.Session.resource`
"""
+ default_serializer = dynamodb_document_serializer
+
def __init__(
self,
table_name: str,
@@ -132,14 +143,15 @@
# With a custom serializer, the value may be a Binary object
raw_value = result['Item']['value']
- return raw_value.value if isinstance(raw_value, Binary) else raw_value
+ value = raw_value.value if isinstance(raw_value, Binary) else raw_value
+ return self.deserialize(value)
def __setitem__(self, key, value):
- item = {**self._composite_key(key), 'value': value}
+ item = {**self._composite_key(key), 'value': self.serialize(value)}
# If enabled, set TTL value as a timestamp in unix format
- if self.ttl and getattr(value, 'ttl', None):
- item['ttl'] = int(time() + value.ttl)
+ if self.ttl and getattr(value, 'expires_unix', None):
+ item['ttl'] = value.expires_unix
self._table.put_item(Item=item)
@@ -169,19 +181,3 @@
def clear(self):
self.bulk_delete((k for k in self))
-
-
-class DynamoDbDocumentDict(DynamoDbDict):
- """Same as :class:`DynamoDbDict`, but serializes values before saving.
-
- By default, responses are only partially serialized into a DynamoDB-compatible document format.
- """
-
- def __init__(self, *args, serializer=None, **kwargs):
- super().__init__(*args, serializer=serializer or dynamodb_document_serializer, **kwargs)
-
- def __getitem__(self, key):
- return self.serializer.loads(super().__getitem__(key))
-
- def __setitem__(self, key, item):
- super().__setitem__(key, self.serializer.dumps(item))
diff --git a/requests_cache/backends/filesystem.py b/requests_cache/backends/filesystem.py
index 834b0ab..ece7042 100644
--- a/requests_cache/backends/filesystem.py
+++ b/requests_cache/backends/filesystem.py
@@ -12,7 +12,7 @@
from threading import RLock
from typing import Iterator
-from ..serializers import SERIALIZERS
+from ..serializers import SERIALIZERS, json_serializer
from . import BaseCache, BaseStorage
from .sqlite import AnyPath, SQLiteDict, get_cache_path
@@ -33,7 +33,7 @@
super().__init__(cache_name=str(cache_name), **kwargs)
self.responses: FileDict = FileDict(cache_name, use_temp=use_temp, **kwargs)
self.redirects: SQLiteDict = SQLiteDict(
- self.cache_dir / 'redirects.sqlite', 'redirects', **kwargs
+ self.cache_dir / 'redirects.sqlite', 'redirects', no_serializer=True, **kwargs
)
@property
@@ -59,6 +59,8 @@
class FileDict(BaseStorage):
"""A dictionary-like interface to files on the local filesystem"""
+ default_serializer = json_serializer
+
def __init__(
self,
cache_name: AnyPath,
@@ -91,7 +93,7 @@
mode = 'rb' if self.is_binary else 'r'
with self._try_io():
with self._path(key).open(mode) as f:
- return self.serializer.loads(f.read())
+ return self.deserialize(f.read())
def __delitem__(self, key):
with self._try_io():
@@ -100,7 +102,7 @@
def __setitem__(self, key, value):
with self._try_io():
with self._path(key).open(mode='wb' if self.is_binary else 'w') as f:
- f.write(self.serializer.dumps(value))
+ f.write(self.serialize(value))
def __iter__(self):
yield from self.keys()
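`FileDict` chooses its file mode from the serializer's `is_binary` flag (pickle needs `'wb'`/`'rb'`, JSON needs `'w'`/`'r'`). A standalone sketch of that mode selection, using stdlib modules as stand-in serializers (the helper names are hypothetical):

```python
import json
import pickle
import tempfile
from pathlib import Path


def write_cached_value(path: Path, value, serializer, is_binary: bool):
    """Write a serialized value, selecting text vs binary mode as FileDict does"""
    with path.open('wb' if is_binary else 'w') as f:
        f.write(serializer.dumps(value))


def read_cached_value(path: Path, serializer, is_binary: bool):
    """Read a value back, using the matching mode for the serializer"""
    with path.open('rb' if is_binary else 'r') as f:
        return serializer.loads(f.read())
```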
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py
index eb139d5..0e01ad3 100644
--- a/requests_cache/backends/gridfs.py
+++ b/requests_cache/backends/gridfs.py
@@ -29,9 +29,13 @@
def __init__(self, db_name: str, **kwargs):
super().__init__(cache_name=db_name, **kwargs)
- self.responses = GridFSPickleDict(db_name, **kwargs)
+ self.responses = GridFSDict(db_name, **kwargs)
self.redirects = MongoDict(
- db_name, collection_name='redirects', connection=self.responses.connection, **kwargs
+ db_name,
+ collection_name='redirects',
+ connection=self.responses.connection,
+ no_serializer=True,
+ **kwargs
)
def remove_expired_responses(self, *args, **kwargs):
@@ -39,7 +43,7 @@
return super().remove_expired_responses(*args, **kwargs)
-class GridFSPickleDict(BaseStorage):
+class GridFSDict(BaseStorage):
"""A dictionary-like interface for a GridFS database
Args:
@@ -63,13 +67,13 @@
result = self.fs.find_one({'_id': key})
if result is None:
raise KeyError
- return self.serializer.loads(result.read())
+ return self.deserialize(result.read())
except CorruptGridFile as e:
logger.warning(e, exc_info=True)
raise KeyError
def __setitem__(self, key, item):
- value = self.serializer.dumps(item)
+ value = self.serialize(item)
encoding = None if isinstance(value, bytes) else 'utf-8'
with self._lock:
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index d0fe79c..eb1cc22 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -21,6 +21,7 @@
class MongoCache(BaseCache):
"""MongoDB cache backend.
+ By default, responses are only partially serialized into a MongoDB-compatible document format.
Args:
db_name: Database name
@@ -30,7 +31,7 @@
def __init__(self, db_name: str = 'http_cache', connection: MongoClient = None, **kwargs):
super().__init__(cache_name=db_name, **kwargs)
- self.responses: MongoDict = MongoDocumentDict(
+ self.responses: MongoDict = MongoDict(
db_name,
collection_name='responses',
connection=connection,
@@ -40,6 +41,7 @@
db_name,
collection_name='redirects',
connection=self.responses.connection,
+ no_serializer=True,
**kwargs,
)
@@ -68,6 +70,8 @@
kwargs: Additional keyword arguments for :py:class:`pymongo.MongoClient`
"""
+ default_serializer = bson_document_serializer
+
def __init__(
self,
db_name: str,
@@ -107,15 +111,17 @@
result = self.collection.find_one({'_id': key})
if result is None:
raise KeyError
- return result['data'] if 'data' in result else result
+ value = result['data'] if 'data' in result else result
+ return self.deserialize(value)
- def __setitem__(self, key, item):
- """If ``item`` is already a dict, its values will be stored under top-level keys.
+ def __setitem__(self, key, value):
+ """If ``value`` is already a dict, its values will be stored under top-level keys.
Otherwise, it will be stored under a 'data' key.
"""
- if not isinstance(item, Mapping):
- item = {'data': item}
- self.collection.replace_one({'_id': key}, item, upsert=True)
+ value = self.serialize(value)
+ if not isinstance(value, Mapping):
+ value = {'data': value}
+ self.collection.replace_one({'_id': key}, value, upsert=True)
def __delitem__(self, key):
result = self.collection.find_one_and_delete({'_id': key}, {'_id': True})
@@ -138,19 +144,3 @@
def close(self):
self.connection.close()
-
-
-class MongoDocumentDict(MongoDict):
- """Same as :class:`MongoDict`, but serializes values before saving.
-
- By default, responses are only partially serialized into a MongoDB-compatible document format.
- """
-
- def __init__(self, *args, serializer=None, **kwargs):
- super().__init__(*args, serializer=serializer or bson_document_serializer, **kwargs)
-
- def __getitem__(self, key):
- return self.serializer.loads(super().__getitem__(key))
-
- def __setitem__(self, key, item):
- super().__setitem__(key, self.serializer.dumps(item))
diff --git a/requests_cache/backends/redis.py b/requests_cache/backends/redis.py
index 0697023..95b44b5 100644
--- a/requests_cache/backends/redis.py
+++ b/requests_cache/backends/redis.py
@@ -11,13 +11,14 @@
from .._utils import get_valid_kwargs
from ..cache_keys import decode, encode
+from ..serializers import utf8_encoder
from . import BaseCache, BaseStorage
logger = getLogger(__name__)
# TODO: TTL tests
-# TODO: Option to set a different (typically longer) TTL than expire_after, like MongoCache
+# TODO: Option to set a TTL offset, for longer expiration than expire_after
class RedisCache(BaseCache):
"""Redis cache backend.
@@ -33,8 +34,13 @@
):
super().__init__(cache_name=namespace, **kwargs)
self.responses = RedisDict(namespace, connection=connection, ttl=ttl, **kwargs)
+ kwargs.pop('serializer', None)
self.redirects = RedisHashDict(
- namespace, 'redirects', connection=self.responses.connection, **kwargs
+ namespace,
+ 'redirects',
+ connection=self.responses.connection,
+ serializer=utf8_encoder, # Only needs encoding to/decoding from bytes
+ **kwargs,
)
@@ -75,14 +81,15 @@
result = self.connection.get(self._bkey(key))
if result is None:
raise KeyError
- return self.serializer.loads(result)
+ return self.deserialize(result)
def __setitem__(self, key, item):
"""Save an item to the cache, optionally with TTL"""
- if self.ttl and getattr(item, 'ttl', None):
- self.connection.setex(self._bkey(key), item.ttl, self.serializer.dumps(item))
+ expires_delta = getattr(item, 'expires_delta', None)
+ if self.ttl and (expires_delta or 0) > 0:
+ self.connection.setex(self._bkey(key), expires_delta, self.serialize(item))
else:
- self.connection.set(self._bkey(key), self.serializer.dumps(item))
+ self.connection.set(self._bkey(key), self.serialize(item))
def __delitem__(self, key):
if not self.connection.delete(self._bkey(key)):
@@ -115,14 +122,14 @@
return [(k, self[k]) for k in self.keys()]
def values(self):
- return [self.serializer.loads(v) for v in self.connection.mget(*self._bkeys(self.keys()))]
+ return [self.deserialize(v) for v in self.connection.mget(*self._bkeys(self.keys()))]
class RedisHashDict(BaseStorage):
"""A dictionary-like interface for operations on a single Redis hash
**Notes:**
- * All keys will be encoded as bytes, and all values will be serialized
+ * All keys will be encoded as bytes
* Items will be stored in a hash named ``namespace:collection_name``
"""
@@ -141,10 +148,10 @@
result = self.connection.hget(self._hash_key, encode(key))
if result is None:
raise KeyError
- return self.serializer.loads(result)
+ return self.deserialize(result)
def __setitem__(self, key, item):
- self.connection.hset(self._hash_key, encode(key), self.serializer.dumps(item))
+ self.connection.hset(self._hash_key, encode(key), self.serialize(item))
def __delitem__(self, key):
if not self.connection.hdel(self._hash_key, encode(key)):
@@ -170,10 +177,10 @@
def items(self):
"""Get all ``(key, value)`` pairs in the hash"""
return [
- (decode(k), self.serializer.loads(v))
+ (decode(k), self.deserialize(v))
for k, v in self.connection.hgetall(self._hash_key).items()
]
def values(self):
"""Get all values in the hash"""
- return [self.serializer.loads(v) for v in self.connection.hvals(self._hash_key)]
+ return [self.deserialize(v) for v in self.connection.hvals(self._hash_key)]
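The changed TTL logic in `RedisDict.__setitem__` now keys off `expires_delta` and only issues `SETEX` for a positive delta, falling back to a plain `SET` otherwise. A sketch with a stub connection (`StubRedis` and `store` are hypothetical names for illustration):

```python
class StubRedis:
    """Records whether SET or SETEX was used (stand-in for a redis client)"""

    def __init__(self):
        self.calls = []

    def set(self, key, value):
        self.calls.append(('set', key))

    def setex(self, key, seconds, value):
        self.calls.append(('setex', key, seconds))


def store(connection, key, item, ttl_enabled=True):
    # Mirrors RedisDict.__setitem__: use SETEX only for a positive TTL
    expires_delta = getattr(item, 'expires_delta', None)
    if ttl_enabled and (expires_delta or 0) > 0:
        connection.setex(key, expires_delta, item)
    else:
        connection.set(key, item)
```

Items with no expiration (or an already-elapsed one) are stored without a TTL, so they remain available for conditional requests or `stale_if_error`.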
diff --git a/requests_cache/backends/sqlite.py b/requests_cache/backends/sqlite.py
index eb9b712..76cd1ab 100644
--- a/requests_cache/backends/sqlite.py
+++ b/requests_cache/backends/sqlite.py
@@ -7,18 +7,17 @@
import sqlite3
import threading
from contextlib import contextmanager
-from datetime import datetime
from logging import getLogger
from os import unlink
-from os.path import isfile
+from os.path import getsize, isfile
from pathlib import Path
from tempfile import gettempdir
+from time import time
from typing import Collection, Iterator, List, Tuple, Type, Union
from platformdirs import user_cache_dir
from .._utils import chunkify, get_valid_kwargs
-from ..models import CachedResponse
from ..policy.expiration import ExpirationTime
from . import BaseCache, BaseStorage
@@ -45,8 +44,10 @@
def __init__(self, db_path: AnyPath = 'http_cache', **kwargs):
super().__init__(cache_name=str(db_path), **kwargs)
- self.responses: SQLiteDict = SQLitePickleDict(db_path, table_name='responses', **kwargs)
- self.redirects: SQLiteDict = SQLiteDict(db_path, table_name='redirects', **kwargs)
+ self.responses: SQLiteDict = SQLiteDict(db_path, table_name='responses', **kwargs)
+ self.redirects: SQLiteDict = SQLiteDict(
+ db_path, table_name='redirects', no_serializer=True, **kwargs
+ )
@property
def db_path(self) -> AnyPath:
@@ -211,17 +212,19 @@
# raise error after the with block, otherwise the connection will be locked
if not row:
raise KeyError
- return row[0]
+
+ return self.deserialize(row[0])
def __setitem__(self, key, value):
- self._insert(key, value)
-
- def _insert(self, key, value, expires: datetime = None):
- posix_expires = round(expires.timestamp()) if expires else None
+ # If available, set expiration as a timestamp in unix format
+ expires = value.expires_unix if getattr(value, 'expires_unix', None) else None
+ value = self.serialize(value)
+ if isinstance(value, bytes):
+ value = sqlite3.Binary(value)
with self.connection(commit=True) as con:
con.execute(
f'INSERT OR REPLACE INTO {self.table_name} (key,value,expires) VALUES (?,?,?)',
- (key, value, posix_expires),
+ (key, value, expires),
)
def __iter__(self):
@@ -257,11 +260,26 @@
def clear_expired(self):
"""Remove expired items from the cache"""
- posix_now = round(datetime.utcnow().timestamp())
with self._lock, self.connection(commit=True) as con:
- con.execute(f"DELETE FROM {self.table_name} WHERE expires <= ?", (posix_now,))
+ con.execute(f"DELETE FROM {self.table_name} WHERE expires <= ?", (round(time()),))
self.vacuum()
+ def size(self) -> int:
+ """Return the size of the database, in bytes. For an in-memory database, this will be an
+ estimate based on page size.
+ """
+ try:
+ return getsize(self.db_path)
+ except IOError:
+ return self._estimate_size()
+
+ def _estimate_size(self) -> int:
+ """Estimate the current size of the database based on page count * size"""
+ with self.connection() as conn:
+ page_count = conn.execute('PRAGMA page_count').fetchone()[0]
+ page_size = conn.execute('PRAGMA page_size').fetchone()[0]
+ return page_count * page_size
+
def sorted(
self, key: str = 'expires', reversed: bool = False, limit: int = None, exclude_expired=False
):
@@ -278,9 +296,8 @@
filter_expr = ''
params: Tuple = ()
if exclude_expired:
- posix_now = round(datetime.utcnow().timestamp())
filter_expr = 'WHERE expires is null or expires > ?'
- params = (posix_now,)
+ params = (time(),)
with self.connection(commit=True) as con:
for row in con.execute(
@@ -288,36 +305,13 @@
f' ORDER BY {key} {direction} {limit_expr}',
params,
):
- yield row[0]
+ yield self.deserialize(row[0])
def vacuum(self):
with self.connection(commit=True) as con:
con.execute('VACUUM')
-class SQLitePickleDict(SQLiteDict):
- """Same as :class:`SQLiteDict`, but serializes values before saving"""
-
- def __setitem__(self, key, value: CachedResponse):
- serialized_value = self.serializer.dumps(value)
- if isinstance(serialized_value, bytes):
- serialized_value = sqlite3.Binary(serialized_value)
- super()._insert(key, serialized_value, getattr(value, 'expires', None))
-
- def __getitem__(self, key):
- return self.serializer.loads(super().__getitem__(key))
-
- def sorted(
- self,
- key: str = 'expires',
- reversed: bool = False,
- limit: int = None,
- exclude_expired: bool = False,
- ):
- for value in super().sorted(key, reversed, limit, exclude_expired):
- yield self.serializer.loads(value)
-
-
def _format_sequence(values: Collection) -> Tuple[str, List]:
"""Get SQL parameter marks for a sequence-based query"""
return ','.join(['?'] * len(values)), list(values)
@@ -372,9 +366,3 @@
uri: bool = False,
):
"""Template function to get an accurate signature for the builtin :py:func:`sqlite3.connect`"""
-
-
-# Aliases for backwards-compatibility
-DbCache = SQLiteCache
-DbDict = SQLiteDict
-DbPickeDict = SQLitePickleDict
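The new `size()` fallback for in-memory databases multiplies `PRAGMA page_count` by `PRAGMA page_size`. The estimate can be reproduced standalone with stdlib `sqlite3` (the function name here is hypothetical):

```python
import sqlite3


def estimate_db_size(connection: sqlite3.Connection) -> int:
    """Estimate database size in bytes as page_count * page_size,
    as SQLiteDict._estimate_size() does for in-memory databases"""
    page_count = connection.execute('PRAGMA page_count').fetchone()[0]
    page_size = connection.execute('PRAGMA page_size').fetchone()[0]
    return page_count * page_size
```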
diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py
index c667856..9ad3a1a 100755
--- a/requests_cache/models/response.py
+++ b/requests_cache/models/response.py
@@ -2,6 +2,7 @@
from datetime import datetime, timedelta, timezone
from logging import getLogger
+from time import time
from typing import TYPE_CHECKING, List, Optional
import attr
@@ -134,12 +135,18 @@
return self.expires is not None and datetime.utcnow() >= self.expires
@property
- def ttl(self) -> Optional[int]:
- """Get time to expiration in seconds"""
- if self.expires is None or self.is_expired:
+ def expires_delta(self) -> Optional[int]:
+ """Get time to expiration in seconds (rounded to the nearest second)"""
+ if self.expires is None:
return None
delta = self.expires - datetime.utcnow()
- return int(delta.total_seconds())
+ return round(delta.total_seconds())
+
+ @property
+ def expires_unix(self) -> Optional[int]:
+ """Get expiration time as a Unix timestamp"""
+ seconds = self.expires_delta
+ return round(time() + seconds) if seconds else None
@property
def next(self) -> Optional[PreparedRequest]:
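The relationship between the two new properties: `expires_delta` is seconds until expiration (rounded), and `expires_unix` converts that delta into a Unix timestamp for backends with native TTL (DynamoDB, SQLite). A standalone sketch of the same arithmetic (free functions here instead of properties):

```python
from datetime import datetime, timedelta
from time import time


def expires_delta(expires: datetime) -> int:
    """Seconds until expiration, rounded to the nearest second
    (may be negative if already expired)"""
    return round((expires - datetime.utcnow()).total_seconds())


def expires_unix(expires: datetime) -> int:
    """Expiration as a Unix timestamp, derived from the delta"""
    return round(time() + expires_delta(expires))
```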
diff --git a/requests_cache/serializers/__init__.py b/requests_cache/serializers/__init__.py
index d49545a..a1fc1e5 100644
--- a/requests_cache/serializers/__init__.py
+++ b/requests_cache/serializers/__init__.py
@@ -1,6 +1,8 @@
"""Response serialization utilities. See :ref:`serializers` for general usage info.
"""
# flake8: noqa: F401
+from typing import Union
+
from .cattrs import CattrStage
from .pipeline import SerializerPipeline, Stage
from .preconf import (
@@ -39,10 +41,11 @@
'yaml': yaml_serializer,
}
+SerializerType = Union[str, SerializerPipeline, Stage]
-def init_serializer(serializer=None):
+
+def init_serializer(serializer: SerializerType = None):
"""Initialize a serializer from a name or instance"""
- serializer = serializer or 'pickle'
if isinstance(serializer, str):
serializer = SERIALIZERS[serializer]
return serializer
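With the `or 'pickle'` fallback removed, `init_serializer` now resolves string aliases and passes instances (and `None`) through unchanged; defaulting moves to `BaseStorage.default_serializer`. A minimal sketch of the lookup (registry values here are placeholder strings, not real pipelines):

```python
from typing import Optional, Union

# Hypothetical stand-ins for the registered serializer pipelines
SERIALIZERS = {'json': 'json_serializer', 'pickle': 'pickle_serializer'}


def init_serializer(serializer: Optional[Union[str, object]] = None):
    """Resolve a string alias to a registered serializer; pass
    instances (and None) through unchanged"""
    if isinstance(serializer, str):
        return SERIALIZERS[serializer]
    return serializer
```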
diff --git a/requests_cache/serializers/pipeline.py b/requests_cache/serializers/pipeline.py
index cf10714..3518229 100644
--- a/requests_cache/serializers/pipeline.py
+++ b/requests_cache/serializers/pipeline.py
@@ -9,7 +9,8 @@
class Stage:
- """Generic class to wrap serialization steps with consistent ``dumps()`` and ``loads()`` methods
+ """A single stage in a serializer pipeline. This wraps serialization steps with consistent
+ ``dumps()`` and ``loads()`` methods
Args:
obj: Serializer object or module, if applicable
diff --git a/requests_cache/serializers/preconf.py b/requests_cache/serializers/preconf.py
index 95fef1f..dc871a2 100644
--- a/requests_cache/serializers/preconf.py
+++ b/requests_cache/serializers/preconf.py
@@ -53,10 +53,9 @@
[base_stage], name='dict', is_binary=False
) #: Partial serializer that unstructures responses into dicts
pickle_serializer = SerializerPipeline(
- [base_stage, pickle], name='pickle', is_binary=True
+ [base_stage, Stage(pickle)], name='pickle', is_binary=True
) #: Pickle serializer
-
# Safe pickle serializer
def signer_stage(secret_key=None, salt='requests-cache') -> Stage:
"""Create a stage that uses ``itsdangerous`` to add a signature to responses on write, and
@@ -77,7 +76,7 @@
responses on write, and validate that signature with a secret key on read.
"""
return SerializerPipeline(
- [base_stage, pickle, signer_stage(secret_key, salt)],
+ [base_stage, Stage(pickle), signer_stage(secret_key, salt)],
name='safe_pickle',
is_binary=True,
)
diff --git a/requests_cache/session.py b/requests_cache/session.py
index 59a1d4f..40960ec 100644
--- a/requests_cache/session.py
+++ b/requests_cache/session.py
@@ -25,7 +25,7 @@
KeyCallback,
set_request_headers,
)
-from .serializers import SerializerPipeline
+from .serializers import SerializerType
__all__ = ['CachedSession', 'CacheMixin']
if TYPE_CHECKING:
@@ -45,7 +45,7 @@
self,
cache_name: str = DEFAULT_CACHE_NAME,
backend: BackendSpecifier = None,
- serializer: Union[str, SerializerPipeline] = None,
+ serializer: SerializerType = None,
expire_after: ExpirationTime = -1,
urls_expire_after: ExpirationPatterns = None,
cache_control: bool = False,
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 081b5a1..6cd0a3e 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -51,6 +51,10 @@
VALIDATOR_HEADERS = [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}]
+def _valid_serializer(serializer) -> bool:
+ return isinstance(serializer, (SerializerPipeline, Stage))
+
+
class BaseCacheTest:
"""Base class for testing cache backend classes"""
@@ -79,7 +83,7 @@
"""Test all relevant combinations of (methods X data fields X serializers).
Requests with different request params, data, or json should be cached under different keys.
"""
- if not isinstance(serializer, (SerializerPipeline, Stage)):
+ if not _valid_serializer(serializer):
pytest.skip(f'Dependencies not installed for {serializer}')
url = httpbin(method.lower())
@@ -92,7 +96,7 @@
@pytest.mark.parametrize('response_format', HTTPBIN_FORMATS)
def test_all_response_formats(self, response_format, serializer):
"""Test all relevant combinations of (response formats X serializers)"""
- if not isinstance(serializer, SerializerPipeline):
+ if not _valid_serializer(serializer):
pytest.skip(f'Dependencies not installed for {serializer}')
session = self.init_session(serializer=serializer)
@@ -311,6 +315,9 @@
# Cache a response and some redirects, which should be the only non-expired cache items
session.get(httpbin('get'), expire_after=-1)
session.get(httpbin('redirect/3'), expire_after=-1)
+ assert len(session.cache.redirects.keys()) == 4
+ print(list(session.cache.redirects.items()))
+ print(list(session.cache.responses.keys()))
session.cache.remove_expired_responses()
assert len(session.cache.responses.keys()) == 2
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index 7d38643..ee3bf31 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -14,7 +14,6 @@
storage_class: Type[BaseStorage] = None
init_kwargs: Dict = {}
- picklable: bool = False
num_instances: int = 10 # Max number of cache instances to test
def init_cache(self, cache_name=CACHE_NAME, index=0, clear=True, **kwargs):
@@ -98,21 +97,20 @@
cache['key']
def test_picklable_dict(self):
- if self.picklable:
- cache = self.init_cache(serializer='pickle')
- original_obj = BasicDataclass(
- bool_attr=True,
- datetime_attr=datetime(2022, 2, 2),
- int_attr=2,
- str_attr='value',
- )
- cache['key_1'] = original_obj
+ cache = self.init_cache(serializer='pickle')
+ original_obj = BasicDataclass(
+ bool_attr=True,
+ datetime_attr=datetime(2022, 2, 2),
+ int_attr=2,
+ str_attr='value',
+ )
+ cache['key_1'] = original_obj
- obj = cache['key_1']
- assert obj.bool_attr == original_obj.bool_attr
- assert obj.datetime_attr == original_obj.datetime_attr
- assert obj.int_attr == original_obj.int_attr
- assert obj.str_attr == original_obj.str_attr
+ obj = cache['key_1']
+ assert obj.bool_attr == original_obj.bool_attr
+ assert obj.datetime_attr == original_obj.datetime_attr
+ assert obj.int_attr == original_obj.int_attr
+ assert obj.str_attr == original_obj.str_attr
def test_clear_and_work_again(self):
cache_1 = self.init_cache()
diff --git a/tests/integration/test_dynamodb.py b/tests/integration/test_dynamodb.py
index 84c1008..52fe3f3 100644
--- a/tests/integration/test_dynamodb.py
+++ b/tests/integration/test_dynamodb.py
@@ -5,7 +5,7 @@
import pytest
from botocore.exceptions import ClientError
-from requests_cache.backends import DynamoDbCache, DynamoDbDict, DynamoDbDocumentDict
+from requests_cache.backends import DynamoDbCache, DynamoDbDict
from requests_cache.serializers import dynamodb_document_serializer
from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, fail_if_no_connection
from tests.integration.base_cache_test import TEST_SERIALIZERS, BaseCacheTest
@@ -69,7 +69,7 @@
"""
cache = self.init_cache(ttl=ttl_enabled)
item = OrderedDict(foo='bar')
- item.ttl = 60
+ item.expires_unix = 60
cache['key'] = item
# 'ttl' is a reserved word, so to retrieve it we need to alias it
@@ -86,12 +86,6 @@
assert ttl_value is None
-class TestDynamoDbDocumentDict(BaseStorageTest):
- storage_class = DynamoDbDocumentDict
- init_kwargs = DYNAMODB_OPTIONS
- picklable = True
-
-
class TestDynamoDbCache(BaseCacheTest):
backend_class = DynamoDbCache
init_kwargs = DYNAMODB_OPTIONS
diff --git a/tests/integration/test_mongodb.py b/tests/integration/test_mongodb.py
index 262ddd6..5ff6ae9 100644
--- a/tests/integration/test_mongodb.py
+++ b/tests/integration/test_mongodb.py
@@ -6,14 +6,8 @@
from gridfs import GridFS
from gridfs.errors import CorruptGridFile, FileExists
-from requests_cache.backends import (
- GridFSCache,
- GridFSPickleDict,
- MongoCache,
- MongoDict,
- MongoDocumentDict,
-)
-from requests_cache.policy.expiration import NEVER_EXPIRE
+from requests_cache.backends import GridFSCache, GridFSDict, MongoCache, MongoDict
+from requests_cache.policy import NEVER_EXPIRE
from requests_cache.serializers import bson_document_serializer
from tests.conftest import HTTPBIN_FORMATS, HTTPBIN_METHODS, fail_if_no_connection, httpbin
from tests.integration.base_cache_test import TEST_SERIALIZERS, BaseCacheTest
@@ -37,11 +31,6 @@
class TestMongoDict(BaseStorageTest):
storage_class = MongoDict
-
-class TestMongoPickleDict(BaseStorageTest):
- storage_class = MongoDocumentDict
- picklable = True
-
    def test_connection_kwargs(self):
        """A spot check to make sure optional connection kwargs get passed to connection"""
# MongoClient prevents direct access to private members like __init_kwargs;
@@ -60,7 +49,6 @@
class TestMongoCache(BaseCacheTest):
backend_class = MongoCache
-
init_kwargs = {'serializer': None} # Use class default serializer instead of pickle
@pytest.mark.parametrize('serializer', MONGODB_SERIALIZERS)
@@ -113,14 +101,14 @@
assert session.cache.get_ttl() is None
-class TestGridFSPickleDict(BaseStorageTest):
- storage_class = GridFSPickleDict
+class TestGridFSDict(BaseStorageTest):
+ storage_class = GridFSDict
picklable = True
    num_instances = 1  # Only test a single collection instead of multiple
    def test_connection_kwargs(self):
        """A spot check to make sure optional connection kwargs get passed to connection"""
- cache = GridFSPickleDict(
+ cache = GridFSDict(
'test',
host='mongodb://0.0.0.0',
port=2222,
diff --git a/tests/integration/test_sqlite.py b/tests/integration/test_sqlite.py
index 84b399a..54b0ceb 100644
--- a/tests/integration/test_sqlite.py
+++ b/tests/integration/test_sqlite.py
@@ -9,13 +9,14 @@
from platformdirs import user_cache_dir
from requests_cache.backends.base import BaseCache
-from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict, SQLitePickleDict
+from requests_cache.backends.sqlite import MEMORY_URI, SQLiteCache, SQLiteDict
from requests_cache.models.response import CachedResponse
from tests.integration.base_cache_test import BaseCacheTest
from tests.integration.base_storage_test import CACHE_NAME, BaseStorageTest
-class SQLiteTestCase(BaseStorageTest):
+class TestSQLiteDict(BaseStorageTest):
+ storage_class = SQLiteDict
init_kwargs = {'use_temp': True}
@classmethod
@@ -178,15 +179,6 @@
with pytest.raises(ValueError):
list(cache.sorted(key='invalid_key'))
-
-class TestSQLiteDict(SQLiteTestCase):
- storage_class = SQLiteDict
-
-
-class TestSQLitePickleDict(SQLiteTestCase):
- storage_class = SQLitePickleDict
- picklable = True
-
@pytest.mark.parametrize('limit', [None, 50])
def test_sorted__by_expires(self, limit):
cache = self.init_cache()
@@ -213,7 +205,7 @@
for i in range(100):
delta = 101 - i
if i % 2 == 1:
- delta -= 100
+ delta -= 101
response = CachedResponse(status_code=i, expires=now + timedelta(seconds=delta))
cache[f'key_{i}'] = response
@@ -227,6 +219,20 @@
assert prev_item is None or prev_item.expires < item.expires
assert item.status_code % 2 == 0
+ @pytest.mark.parametrize(
+ 'db_path, use_temp',
+ [
+ ('filesize_test', True),
+ (':memory:', False),
+ ],
+ )
+ def test_size(self, db_path, use_temp):
+ """Test approximate expected size of a database, for both file-based and in-memory databases"""
+ cache = self.init_cache(db_path, use_temp=use_temp)
+ for i in range(100):
+ cache[f'key_{i}'] = f'value_{i}'
+ assert 10000 < cache.size() < 200000
+
class TestSQLiteCache(BaseCacheTest):
backend_class = SQLiteCache
diff --git a/tests/unit/test_base_cache.py b/tests/unit/test_base_cache.py
index 0b87411..a39b238 100644
--- a/tests/unit/test_base_cache.py
+++ b/tests/unit/test_base_cache.py
@@ -5,7 +5,7 @@
import pytest
from requests_cache import CachedResponse
-from requests_cache.backends import BaseCache, SQLitePickleDict
+from requests_cache.backends import BaseCache, SQLiteDict
from tests.conftest import MOCKED_URL, MOCKED_URL_HTTPS, MOCKED_URL_JSON, MOCKED_URL_REDIRECT
YESTERDAY = datetime.utcnow() - timedelta(days=1)
@@ -24,7 +24,7 @@
def test_urls__with_invalid_response(mock_session):
responses = [mock_session.get(url) for url in [MOCKED_URL, MOCKED_URL_JSON, MOCKED_URL_HTTPS]]
responses[2] = AttributeError
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=responses):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=responses):
expected_urls = [MOCKED_URL, MOCKED_URL_JSON]
assert set(mock_session.cache.urls) == set(expected_urls)
@@ -67,7 +67,7 @@
responses[1] = AttributeError
responses[2] = CachedResponse(expires=YESTERDAY, url='test')
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=responses):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=responses):
values = mock_session.cache.values(check_expiry=check_expiry)
assert len(list(values)) == expected_count
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index e64a457..ad2ecbd 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -3,6 +3,7 @@
import time
from collections import UserDict, defaultdict
from datetime import datetime, timedelta
+from logging import getLogger
from pathlib import Path
from pickle import PickleError
from unittest.mock import patch
@@ -15,7 +16,7 @@
from requests_cache import ALL_METHODS, CachedSession
from requests_cache._utils import get_placeholder_class
-from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict, SQLitePickleDict
+from requests_cache.backends import BACKEND_CLASSES, BaseCache, SQLiteDict
from requests_cache.backends.base import DESERIALIZE_ERRORS
from requests_cache.policy.expiration import DO_NOT_CACHE, EXPIRE_IMMEDIATELY, NEVER_EXPIRE
from tests.conftest import (
@@ -32,6 +33,7 @@
# Some tests must disable url normalization to retain the custom `http+mock://` protocol
patch_normalize_url = patch('requests_cache.cache_keys.normalize_url', side_effect=lambda x, y: x)
+logger = getLogger(__name__)
# Basic initialization
# -----------------------------------------------------
@@ -67,14 +69,11 @@
def test_repr(mock_session):
"""Test session and cache string representations"""
mock_session.settings.expire_after = 11
- mock_session.cache.responses['key'] = 'value'
- mock_session.cache.redirects['key'] = 'value'
- mock_session.cache.redirects['key_2'] = 'value'
+ mock_session.settings.cache_control = True
assert mock_session.cache.cache_name in repr(mock_session)
- assert '11' in repr(mock_session)
- assert '2 redirects' in str(mock_session.cache)
- assert '1 responses' in str(mock_session.cache)
+ assert 'expire_after=11' in repr(mock_session)
+ assert 'cache_control=True' in repr(mock_session)
def test_response_defaults(mock_session):
@@ -421,7 +420,7 @@
"""If there is an error during deserialization, the request should be made again"""
assert mock_session.get(MOCKED_URL_JSON).from_cache is False
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=PickleError):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=PickleError):
resp = mock_session.get(MOCKED_URL_JSON)
assert resp.from_cache is False
assert resp.json()['message'] == 'mock json response'
@@ -649,7 +648,7 @@
raise PickleError
return response_1
- with patch.object(SQLitePickleDict, '__getitem__', side_effect=error_on_key):
+ with patch.object(SQLiteDict, '__getitem__', side_effect=error_on_key):
BaseCache.remove_expired_responses(mock_session.cache)
assert len(mock_session.cache.responses) == 1
assert mock_session.get(MOCKED_URL).from_cache is True
@@ -689,17 +688,17 @@
mock_session.mock_adapter.register_uri('GET', second_url, status_code=200)
mock_session.mock_adapter.register_uri('GET', third_url, status_code=200)
mock_session.get(MOCKED_URL)
- mock_session.get(second_url, expire_after=1)
- mock_session.get(third_url, expire_after=2)
+ mock_session.get(second_url, expire_after=2)
+ mock_session.get(third_url, expire_after=4)
# All 3 responses should still be cached
mock_session.remove_expired_responses()
for response in mock_session.cache.responses.values():
- print('Expires:', response.expires - datetime.utcnow() if response.expires else None)
+ logger.info(f'Expires in {response.expires_delta} seconds')
assert len(mock_session.cache.responses) == 3
- # One should be expired after 1s, and another should be expired after 2s
- time.sleep(1)
+ # One should be expired after 2s, and another should be expired after 4s
+ time.sleep(2)
mock_session.remove_expired_responses()
assert len(mock_session.cache.responses) == 2
time.sleep(2)