Improve GridFS backend thread safety
diff --git a/requests_cache/backends/gridfs.py b/requests_cache/backends/gridfs.py index ca370f6..418dba6 100644 --- a/requests_cache/backends/gridfs.py +++ b/requests_cache/backends/gridfs.py
@@ -11,13 +11,19 @@ :classes-only: :nosignatures: """ +from logging import getLogger +from threading import RLock + from gridfs import GridFS +from gridfs.errors import FileExists from pymongo import MongoClient from .._utils import get_valid_kwargs from .base import BaseCache, BaseStorage from .mongodb import MongoDict +logger = getLogger(__name__) + class GridFSCache(BaseCache): """GridFS cache backend. @@ -56,27 +62,32 @@ self.connection = connection or MongoClient(**connection_kwargs) self.db = self.connection[db_name] self.fs = GridFS(self.db) + self._lock = RLock() def __getitem__(self, key): - result = self.fs.find_one({'_id': key}) - if result is None: - raise KeyError - return self.serializer.loads(result.read()) + with self._lock: + result = self.fs.find_one({'_id': key}) + if result is None: + raise KeyError + return self.serializer.loads(result.read()) def __setitem__(self, key, item): - try: - self.__delitem__(key) - except KeyError: - pass value = self.serializer.dumps(item) encoding = None if isinstance(value, bytes) else 'utf-8' - self.fs.put(value, encoding=encoding, **{'_id': key}) + + with self._lock: + try: + self.fs.delete(key) + self.fs.put(value, encoding=encoding, **{'_id': key}) + except FileExists as e: + logger.warning(e, exc_info=True) def __delitem__(self, key): - res = self.fs.find_one({'_id': key}) - if res is None: - raise KeyError - self.fs.delete(res._id) + with self._lock: + res = self.fs.find_one({'_id': key}) + if res is None: + raise KeyError + self.fs.delete(res._id) def __len__(self): return self.db['fs.files'].estimated_document_count()
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py index 750c31f..13e2ede 100644 --- a/tests/integration/base_cache_test.py +++ b/tests/integration/base_cache_test.py
@@ -87,7 +87,7 @@ pytest.skip(f'Dependencies not installed for {serializer}') session = self.init_session(serializer=serializer) - # Temporary workaround for this issue: https://github.com/kevin1024/pytest-httpbin/issues/60 + # Workaround for this issue: https://github.com/kevin1024/pytest-httpbin/issues/60 if response_format == 'json' and USE_PYTEST_HTTPBIN: session.allowable_codes = (200, 404)