Add get_ttl() method for convenience
diff --git a/requests_cache/backends/mongodb.py b/requests_cache/backends/mongodb.py
index 16b65ff..cd39f49 100644
--- a/requests_cache/backends/mongodb.py
+++ b/requests_cache/backends/mongodb.py
@@ -97,7 +97,7 @@
"""
from datetime import timedelta
from logging import getLogger
-from typing import Iterable, Mapping, Union
+from typing import Iterable, Mapping, Optional, Union
from pymongo import MongoClient
from pymongo.errors import OperationFailure
@@ -112,7 +112,6 @@
logger = getLogger(__name__)
-# TODO: TTL tests
# TODO: Is there any reason to support custom serializers here?
# TODO: Save items with different cache keys to avoid conflicts with old serialization format?
class MongoCache(BaseCache):
@@ -139,6 +138,10 @@
**kwargs,
)
+ def get_ttl(self) -> Optional[int]:
+ """Get the currently defined TTL value in seconds, if any"""
+ return self.responses.get_ttl()
+
def set_ttl(self, ttl: Union[int, timedelta], overwrite: bool = False):
"""Create or update a TTL index. Notes:
@@ -172,15 +175,25 @@
self.connection = connection or MongoClient(**connection_kwargs)
self.collection = self.connection[db_name][collection_name]
- def set_ttl(self, ttl: Union[int, timedelta], overwrite: bool = False):
- if overwrite:
- try:
- self.collection.drop_index('ttl_idx')
- logger.info('Dropped TTL index')
- except OperationFailure:
- pass
+ def get_ttl(self) -> Optional[int]:
+ """Get the currently defined TTL value in seconds, if any"""
+ idx_info = self.collection.index_information().get('ttl_idx', {})
+ return idx_info.get('expireAfterSeconds')
- ttl = get_expiration_seconds(ttl)
+ def set_ttl(self, ttl: Union[int, timedelta], overwrite: bool = False):
+ """Create or update a TTL index, and ignore and log any errors due to dropping a nonexistent
+ index or attempting to overwrite without ```overwrite=True``.
+ """
+ try:
+ self._set_ttl(get_expiration_seconds(ttl), overwrite=overwrite)
+ except OperationFailure:
+ logger.warning('Failed to update TTL index', exc_info=True)
+
+ def _set_ttl(self, ttl: int, overwrite: bool = False):
+ if overwrite:
+ self.collection.drop_index('ttl_idx')
+ logger.info('Dropped TTL index')
+
if ttl and ttl != NEVER_EXPIRE:
logger.info(f'Creating TTL index for {ttl} seconds')
self.collection.create_index('created_at', name='ttl_idx', expireAfterSeconds=ttl)