Split datetime-related utility functions into a separate module
diff --git a/requests_cache/_utils.py b/requests_cache/_utils.py
index b3dfaee..b0067cd 100644
--- a/requests_cache/_utils.py
+++ b/requests_cache/_utils.py
@@ -1,7 +1,7 @@
-"""Miscellaneous minor utility functions that don't really belong anywhere else"""
+"""Minor internal utility functions that don't really belong anywhere else"""
from inspect import signature
from logging import getLogger
-from typing import Any, Callable, Dict, Iterable, Iterator, List
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional
logger = getLogger('requests_cache')
@@ -18,6 +18,18 @@
return next((v for v in values if v is not None), default)
+def decode(value, encoding='utf-8') -> str:
+ """Decode a value from bytes, if hasn't already been.
+ Note: ``PreparedRequest.body`` is always encoded in utf-8.
+ """
+ return value.decode(encoding) if isinstance(value, bytes) else value
+
+
+def encode(value, encoding='utf-8') -> bytes:
+ """Encode a value to bytes, if it hasn't already been"""
+ return value if isinstance(value, bytes) else str(value).encode(encoding)
+
+
def get_placeholder_class(original_exception: Exception = None):
"""Create a placeholder type for a class that does not have dependencies installed.
This allows delaying ImportErrors until init time, rather than at import time.
@@ -46,3 +58,11 @@
params = list(signature(func).parameters)
params.extend(extras or [])
return {k: v for k, v in kwargs.items() if k in params and v is not None}
+
+
+def try_int(value: Any) -> Optional[int]:
+ """Convert a value to an int, if possible, otherwise ``None``"""
+ try:
+ return int(value)
+ except (TypeError, ValueError):
+ return None
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 77df723..ffbe0ed 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -12,8 +12,8 @@
from logging import getLogger
from typing import Iterable, Iterator, Optional, Tuple, Union
-from ..cache_control import ExpirationTime
from ..cache_keys import create_key, redact_response
+from ..expiration import ExpirationTime
from ..models import AnyRequest, AnyResponse, CachedResponse, CacheSettings
from ..serializers import init_serializer
diff --git a/requests_cache/cache_control.py b/requests_cache/cache_control.py
index 10a1de8..2a32540 100644
--- a/requests_cache/cache_control.py
+++ b/requests_cache/cache_control.py
@@ -12,30 +12,29 @@
"""
from __future__ import annotations
-from datetime import datetime, timedelta, timezone
-from email.utils import parsedate_to_datetime
-from fnmatch import fnmatch
+from datetime import datetime
from logging import getLogger
-from math import ceil
-from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, MutableMapping, Optional, Tuple, Union
from attr import define, field
from requests import PreparedRequest, Response
from requests.models import CaseInsensitiveDict
-from ._utils import coalesce
+from ._utils import coalesce, try_int
+from .expiration import (
+ DO_NOT_CACHE,
+ NEVER_EXPIRE,
+ ExpirationTime,
+ get_expiration_datetime,
+ get_url_expiration,
+)
-__all__ = ['DO_NOT_CACHE', 'CacheActions']
+__all__ = ['CacheActions']
if TYPE_CHECKING:
from .models import CachedResponse, CacheSettings, RequestSettings
-# May be set by either headers or expire_after param to disable caching or disable expiration
-DO_NOT_CACHE = 0
-NEVER_EXPIRE = -1
CacheDirective = Union[None, int, bool]
-ExpirationTime = Union[None, int, float, str, datetime, timedelta]
-ExpirationPatterns = Dict[str, ExpirationTime]
logger = getLogger(__name__)
@@ -132,9 +131,10 @@
elif is_expired and not (self.settings.only_if_cached and self.settings.stale_if_error):
self.resend_request = True
- if response is None:
- return
+ if response is not None:
+ self._update_validation_headers(response)
+ def _update_validation_headers(self, response: CachedResponse):
# Revalidation may be triggered by either stale response or request/cached response headers
directives = get_cache_directives(response.headers)
self.revalidate = _has_validator(response.headers) and any(
@@ -146,6 +146,7 @@
]
)
+ # Add the appropriate validation headers, if needed
if self.revalidate:
self.send_request = True
if response.headers.get('ETag'):
@@ -156,7 +157,7 @@
def update_from_response(self, response: Response):
"""Update expiration + actions based on headers from a new response.
- Used after receiving a new response but before saving it to the cache.
+ Used after receiving a new response, but before saving it to the cache.
"""
if not response or not self.settings.cache_control:
return
@@ -222,32 +223,6 @@
return headers
-def get_expiration_datetime(expire_after: ExpirationTime) -> Optional[datetime]:
- """Convert an expiration value in any supported format to an absolute datetime"""
- # Never expire
- if expire_after is None or expire_after == NEVER_EXPIRE:
- return None
- # Expire immediately
- elif try_int(expire_after) == DO_NOT_CACHE:
- return datetime.utcnow()
- # Already a datetime or datetime str
- if isinstance(expire_after, str):
- return parse_http_date(expire_after)
- elif isinstance(expire_after, datetime):
- return to_utc(expire_after)
-
- # Otherwise, it must be a timedelta or time in seconds
- if not isinstance(expire_after, timedelta):
- expire_after = timedelta(seconds=expire_after)
- return datetime.utcnow() + expire_after
-
-
-def get_expiration_seconds(expire_after: ExpirationTime) -> int:
- """Convert an expiration value in any supported format to an expiration time in seconds"""
- expires = get_expiration_datetime(expire_after)
- return ceil((expires - datetime.utcnow()).total_seconds()) if expires else NEVER_EXPIRE
-
-
def get_cache_directives(headers: MutableMapping) -> Dict[str, CacheDirective]:
"""Get all Cache-Control directives as a dict. Handle duplicate headers and comma-separated
lists. Key-only directives are returned as ``{key: True}``.
@@ -265,41 +240,6 @@
return kv_directives
-def get_504_response(request: PreparedRequest) -> Response:
- from .models import CachedResponse
-
- return CachedResponse(
- url=request.url or '',
- status_code=504,
- reason='Not Cached',
- request=request, # type: ignore
- )
-
-
-def get_url_expiration(
- url: Optional[str], urls_expire_after: ExpirationPatterns = None
-) -> ExpirationTime:
- """Check for a matching per-URL expiration, if any"""
- if not url:
- return None
-
- for pattern, expire_after in (urls_expire_after or {}).items():
- if url_match(url, pattern):
- logger.debug(f'URL {url} matched pattern "{pattern}": {expire_after}')
- return expire_after
- return None
-
-
-def parse_http_date(value: str) -> Optional[datetime]:
- """Attempt to parse an HTTP (RFC 5322-compatible) timestamp"""
- try:
- expire_after = parsedate_to_datetime(value)
- return to_utc(expire_after)
- except (TypeError, ValueError):
- logger.debug(f'Failed to parse timestamp: {value}')
- return None
-
-
def split_kv_directive(header_value: str) -> Tuple[str, CacheDirective]:
"""Split a cache directive into a ``(key, int)`` pair, if possible; otherwise just
``(key, True)``.
@@ -312,43 +252,5 @@
return header_value, True
-def to_utc(dt: datetime):
- """All internal datetimes are UTC and timezone-naive. Convert any user/header-provided
- datetimes to the same format.
- """
- if dt.tzinfo:
- dt = dt.astimezone(timezone.utc)
- dt = dt.replace(tzinfo=None)
- return dt
-
-
-def try_int(value: Any) -> Optional[int]:
- """Convert a value to an int, if possible, otherwise ``None``"""
- try:
- return int(value)
- except (TypeError, ValueError):
- return None
-
-
-def url_match(url: str, pattern: str) -> bool:
- """Determine if a URL matches a pattern
-
- Args:
- url: URL to test. Its base URL (without protocol) will be used.
- pattern: Glob pattern to match against. A recursive wildcard will be added if not present
-
- Example:
- >>> url_match('https://httpbin.org/delay/1', 'httpbin.org/delay')
- True
- >>> url_match('https://httpbin.org/stream/1', 'httpbin.org/*/1')
- True
- >>> url_match('https://httpbin.org/stream/2', 'httpbin.org/*/1')
- False
- """
- url = url.split('://')[-1]
- pattern = pattern.split('://')[-1].rstrip('*') + '**'
- return fnmatch(url, pattern)
-
-
def _has_validator(headers: MutableMapping) -> bool:
return bool(headers.get('ETag') or headers.get('Last-Modified'))
diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py
index ba9f626..e28d533 100644
--- a/requests_cache/cache_keys.py
+++ b/requests_cache/cache_keys.py
@@ -16,12 +16,11 @@
from requests.models import CaseInsensitiveDict
from url_normalize import url_normalize
-from ._utils import get_valid_kwargs
-
-if TYPE_CHECKING:
- from .models import AnyPreparedRequest, AnyRequest, CachedResponse
+from ._utils import decode, encode, get_valid_kwargs
__all__ = ['create_key', 'normalize_request']
+if TYPE_CHECKING:
+ from .models import AnyPreparedRequest, AnyRequest, CachedResponse
# Request headers that are always excluded from cache keys, but not redacted from cached responses
DEFAULT_EXCLUDE_HEADERS = {'Cache-Control', 'If-None-Match', 'If-Modified-Since'}
@@ -201,18 +200,6 @@
return response
-def decode(value, encoding='utf-8') -> str:
- """Decode a value from bytes, if hasn't already been.
- Note: ``PreparedRequest.body`` is always encoded in utf-8.
- """
- return value.decode(encoding) if isinstance(value, bytes) else value
-
-
-def encode(value, encoding='utf-8') -> bytes:
- """Encode a value to bytes, if it hasn't already been"""
- return value if isinstance(value, bytes) else str(value).encode(encoding)
-
-
def filter_sort_json(data: Union[List, Mapping], ignored_parameters: ParamList):
if isinstance(data, Mapping):
return filter_sort_dict(data, ignored_parameters)
diff --git a/requests_cache/expiration.py b/requests_cache/expiration.py
new file mode 100644
index 0000000..e70cae4
--- /dev/null
+++ b/requests_cache/expiration.py
@@ -0,0 +1,100 @@
+"""Utility functions used for converting expiration values"""
+from datetime import datetime, timedelta, timezone
+from email.utils import parsedate_to_datetime
+from fnmatch import fnmatch
+from logging import getLogger
+from math import ceil
+from typing import Dict, Optional, Union
+
+from ._utils import try_int
+
+__all__ = ['DO_NOT_CACHE', 'NEVER_EXPIRE', 'get_expiration_datetime']
+
+# May be set by either headers or expire_after param to disable caching or disable expiration
+DO_NOT_CACHE = 0
+NEVER_EXPIRE = -1
+
+ExpirationTime = Union[None, int, float, str, datetime, timedelta]
+ExpirationPatterns = Dict[str, ExpirationTime]
+
+logger = getLogger(__name__)
+
+
+def get_expiration_datetime(expire_after: ExpirationTime) -> Optional[datetime]:
+ """Convert an expiration value in any supported format to an absolute datetime"""
+ # Never expire
+ if expire_after is None or expire_after == NEVER_EXPIRE:
+ return None
+ # Expire immediately
+ elif try_int(expire_after) == DO_NOT_CACHE:
+ return datetime.utcnow()
+ # Already a datetime or datetime str
+ if isinstance(expire_after, str):
+ return parse_http_date(expire_after)
+ elif isinstance(expire_after, datetime):
+ return to_utc(expire_after)
+
+ # Otherwise, it must be a timedelta or time in seconds
+ if not isinstance(expire_after, timedelta):
+ expire_after = timedelta(seconds=expire_after)
+ return datetime.utcnow() + expire_after
+
+
+def get_expiration_seconds(expire_after: ExpirationTime) -> int:
+ """Convert an expiration value in any supported format to an expiration time in seconds"""
+ expires = get_expiration_datetime(expire_after)
+ return ceil((expires - datetime.utcnow()).total_seconds()) if expires else NEVER_EXPIRE
+
+
+def get_url_expiration(
+ url: Optional[str], urls_expire_after: ExpirationPatterns = None
+) -> ExpirationTime:
+ """Check for a matching per-URL expiration, if any"""
+ if not url:
+ return None
+
+ for pattern, expire_after in (urls_expire_after or {}).items():
+ if url_match(url, pattern):
+ logger.debug(f'URL {url} matched pattern "{pattern}": {expire_after}')
+ return expire_after
+ return None
+
+
+def parse_http_date(value: str) -> Optional[datetime]:
+ """Attempt to parse an HTTP (RFC 5322-compatible) timestamp"""
+ try:
+ expire_after = parsedate_to_datetime(value)
+ return to_utc(expire_after)
+ except (TypeError, ValueError):
+ logger.debug(f'Failed to parse timestamp: {value}')
+ return None
+
+
+def to_utc(dt: datetime):
+ """All internal datetimes are UTC and timezone-naive. Convert any user/header-provided
+ datetimes to the same format.
+ """
+ if dt.tzinfo:
+ dt = dt.astimezone(timezone.utc)
+ dt = dt.replace(tzinfo=None)
+ return dt
+
+
+def url_match(url: str, pattern: str) -> bool:
+ """Determine if a URL matches a pattern
+
+ Args:
+ url: URL to test. Its base URL (without protocol) will be used.
+ pattern: Glob pattern to match against. A recursive wildcard will be added if not present
+
+ Example:
+ >>> url_match('https://httpbin.org/delay/1', 'httpbin.org/delay')
+ True
+ >>> url_match('https://httpbin.org/stream/1', 'httpbin.org/*/1')
+ True
+ >>> url_match('https://httpbin.org/stream/2', 'httpbin.org/*/1')
+ False
+ """
+ url = url.split('://')[-1]
+ pattern = pattern.split('://')[-1].rstrip('*') + '**'
+ return fnmatch(url, pattern)
diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py
index e0dbbce..b760776 100755
--- a/requests_cache/models/response.py
+++ b/requests_cache/models/response.py
@@ -9,7 +9,7 @@
from requests.structures import CaseInsensitiveDict
from urllib3._collections import HTTPHeaderDict
-from ..cache_control import ExpirationTime, get_expiration_datetime
+from ..expiration import ExpirationTime, get_expiration_datetime
from . import CachedHTTPResponse, CachedRequest
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S %Z' # Format used for __str__ only
diff --git a/requests_cache/models/settings.py b/requests_cache/models/settings.py
index b32df3d..adc03ec 100644
--- a/requests_cache/models/settings.py
+++ b/requests_cache/models/settings.py
@@ -3,7 +3,7 @@
from attr import asdict, define, field
from .._utils import get_valid_kwargs
-from ..cache_control import ExpirationTime
+from ..expiration import ExpirationTime
if TYPE_CHECKING:
from . import AnyResponse
diff --git a/requests_cache/patcher.py b/requests_cache/patcher.py
index c2ac61d..72ae15e 100644
--- a/requests_cache/patcher.py
+++ b/requests_cache/patcher.py
@@ -14,7 +14,7 @@
import requests
from .backends import BackendSpecifier, BaseCache
-from .cache_control import ExpirationTime
+from .expiration import ExpirationTime
from .session import CachedSession, OriginalSession
logger = getLogger(__name__)
diff --git a/requests_cache/session.py b/requests_cache/session.py
index 218605d..b957124 100644
--- a/requests_cache/session.py
+++ b/requests_cache/session.py
@@ -25,13 +25,8 @@
from ._utils import get_valid_kwargs
from .backends import BackendSpecifier, init_backend
-from .cache_control import (
- CacheActions,
- ExpirationTime,
- append_directive,
- get_504_response,
- get_expiration_seconds,
-)
+from .cache_control import CacheActions, append_directive
+from .expiration import ExpirationTime, get_expiration_seconds
from .models import (
AnyResponse,
CachedResponse,
@@ -285,6 +280,16 @@
"""
+def get_504_response(request: PreparedRequest) -> CachedResponse:
+ """Get a 504: Not Cached error response, for use with only-if-cached option"""
+ return CachedResponse(
+ url=request.url or '',
+ status_code=504,
+ reason='Not Cached',
+ request=request, # type: ignore
+ )
+
+
@contextmanager
def patch_form_boundary(**request_kwargs):
"""If the ``files`` param is present, patch the form boundary used to separate multipart
diff --git a/tests/unit/test_cache_control.py b/tests/unit/test_cache_control.py
index c942891..1d9f417 100644
--- a/tests/unit/test_cache_control.py
+++ b/tests/unit/test_cache_control.py
@@ -1,17 +1,12 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pytest
from requests import PreparedRequest
-from requests_cache.cache_control import (
- DO_NOT_CACHE,
- CacheActions,
- get_expiration_datetime,
- get_url_expiration,
-)
-from requests_cache.models.response import CachedResponse, CacheSettings, RequestSettings
-from tests.conftest import ETAG, HTTPDATE_DATETIME, HTTPDATE_STR, LAST_MODIFIED
+from requests_cache.cache_control import DO_NOT_CACHE, CacheActions
+from requests_cache.models import CachedResponse, CacheSettings, RequestSettings
+from tests.conftest import ETAG, HTTPDATE_STR, LAST_MODIFIED
IGNORED_DIRECTIVES = [
'no-transform',
@@ -261,7 +256,7 @@
@pytest.mark.parametrize('validator_headers', [{'ETag': ETAG}, {'Last-Modified': LAST_MODIFIED}])
@pytest.mark.parametrize('cache_headers', [{'Cache-Control': 'max-age=0'}, {'Expires': '0'}])
-@patch('requests_cache.cache_control.datetime')
+@patch('requests_cache.expiration.datetime')
def test_update_from_response__revalidate(mock_datetime, cache_headers, validator_headers):
"""If expiration is 0 and there's a validator, the response should be cached, but with immediate
expiration
@@ -289,81 +284,3 @@
assert actions.revalidate is False
assert actions.skip_read is False
assert actions.skip_write is False
-
-
-@patch('requests_cache.cache_control.datetime')
-def test_get_expiration_datetime__no_expiration(mock_datetime):
- assert get_expiration_datetime(None) is None
- assert get_expiration_datetime(-1) is None
- assert get_expiration_datetime(DO_NOT_CACHE) == mock_datetime.utcnow()
-
-
-@pytest.mark.parametrize(
- 'expire_after, expected_expiration_delta',
- [
- (timedelta(seconds=60), timedelta(seconds=60)),
- (60, timedelta(seconds=60)),
- (33.3, timedelta(seconds=33.3)),
- ],
-)
-def test_get_expiration_datetime__relative(expire_after, expected_expiration_delta):
- expires = get_expiration_datetime(expire_after)
- expected_expiration = datetime.utcnow() + expected_expiration_delta
- # Instead of mocking datetime (which adds some complications), check for approximate value
- assert abs((expires - expected_expiration).total_seconds()) <= 1
-
-
-def test_get_expiration_datetime__tzinfo():
- tz = timezone(-timedelta(hours=5))
- dt = datetime(2021, 2, 1, 7, 0, tzinfo=tz)
- assert get_expiration_datetime(dt) == datetime(2021, 2, 1, 12, 0)
-
-
-def test_get_expiration_datetime__httpdate():
- assert get_expiration_datetime(HTTPDATE_STR) == HTTPDATE_DATETIME
- assert get_expiration_datetime('P12Y34M56DT78H90M12.345S') is None
-
-
-@pytest.mark.parametrize(
- 'url, expected_expire_after',
- [
- ('img.site_1.com', 60 * 60),
- ('http://img.site_1.com/base/img.jpg', 60 * 60),
- ('https://img.site_2.com/base/img.jpg', None),
- ('site_2.com/resource_1', 60 * 60 * 2),
- ('http://site_2.com/resource_1/index.html', 60 * 60 * 2),
- ('http://site_2.com/resource_2/', 60 * 60 * 24),
- ('http://site_2.com/static/', -1),
- ('http://site_2.com/static/img.jpg', -1),
- ('site_2.com', None),
- ('some_other_site.com', None),
- (None, None),
- ],
-)
-def test_get_url_expiration(url, expected_expire_after, mock_session):
- urls_expire_after = {
- '*.site_1.com': 60 * 60,
- 'site_2.com/resource_1': 60 * 60 * 2,
- 'site_2.com/resource_2': 60 * 60 * 24,
- 'site_2.com/static': -1,
- }
- assert get_url_expiration(url, urls_expire_after) == expected_expire_after
-
-
-@pytest.mark.parametrize(
- 'url, expected_expire_after',
- [
- ('https://img.site_1.com/image.jpeg', 60 * 60),
- ('https://img.site_1.com/resource/1', 60 * 60 * 2),
- ('https://site_2.com', 1),
- ('https://any_other_site.com', 1),
- ],
-)
-def test_get_url_expiration__evaluation_order(url, expected_expire_after):
- """If there are multiple matches, the first match should be used in the order defined"""
- urls_expire_after = {
- '*.site_1.com/resource': 60 * 60 * 2,
- '*.site_1.com': 60 * 60,
- '*': 1,
- }
- assert get_url_expiration(url, urls_expire_after) == expected_expire_after
diff --git a/tests/unit/test_expiration.py b/tests/unit/test_expiration.py
new file mode 100644
index 0000000..05490dd
--- /dev/null
+++ b/tests/unit/test_expiration.py
@@ -0,0 +1,85 @@
+from datetime import datetime, timedelta, timezone
+from unittest.mock import patch
+
+import pytest
+
+from requests_cache.expiration import DO_NOT_CACHE, get_expiration_datetime, get_url_expiration
+from tests.conftest import HTTPDATE_DATETIME, HTTPDATE_STR
+
+
+@patch('requests_cache.expiration.datetime')
+def test_get_expiration_datetime__no_expiration(mock_datetime):
+ assert get_expiration_datetime(None) is None
+ assert get_expiration_datetime(-1) is None
+ assert get_expiration_datetime(DO_NOT_CACHE) == mock_datetime.utcnow()
+
+
+@pytest.mark.parametrize(
+ 'expire_after, expected_expiration_delta',
+ [
+ (timedelta(seconds=60), timedelta(seconds=60)),
+ (60, timedelta(seconds=60)),
+ (33.3, timedelta(seconds=33.3)),
+ ],
+)
+def test_get_expiration_datetime__relative(expire_after, expected_expiration_delta):
+ expires = get_expiration_datetime(expire_after)
+ expected_expiration = datetime.utcnow() + expected_expiration_delta
+ # Instead of mocking datetime (which adds some complications), check for approximate value
+ assert abs((expires - expected_expiration).total_seconds()) <= 1
+
+
+def test_get_expiration_datetime__tzinfo():
+ tz = timezone(-timedelta(hours=5))
+ dt = datetime(2021, 2, 1, 7, 0, tzinfo=tz)
+ assert get_expiration_datetime(dt) == datetime(2021, 2, 1, 12, 0)
+
+
+def test_get_expiration_datetime__httpdate():
+ assert get_expiration_datetime(HTTPDATE_STR) == HTTPDATE_DATETIME
+ assert get_expiration_datetime('P12Y34M56DT78H90M12.345S') is None
+
+
+@pytest.mark.parametrize(
+ 'url, expected_expire_after',
+ [
+ ('img.site_1.com', 60 * 60),
+ ('http://img.site_1.com/base/img.jpg', 60 * 60),
+ ('https://img.site_2.com/base/img.jpg', None),
+ ('site_2.com/resource_1', 60 * 60 * 2),
+ ('http://site_2.com/resource_1/index.html', 60 * 60 * 2),
+ ('http://site_2.com/resource_2/', 60 * 60 * 24),
+ ('http://site_2.com/static/', -1),
+ ('http://site_2.com/static/img.jpg', -1),
+ ('site_2.com', None),
+ ('some_other_site.com', None),
+ (None, None),
+ ],
+)
+def test_get_url_expiration(url, expected_expire_after, mock_session):
+ urls_expire_after = {
+ '*.site_1.com': 60 * 60,
+ 'site_2.com/resource_1': 60 * 60 * 2,
+ 'site_2.com/resource_2': 60 * 60 * 24,
+ 'site_2.com/static': -1,
+ }
+ assert get_url_expiration(url, urls_expire_after) == expected_expire_after
+
+
+@pytest.mark.parametrize(
+ 'url, expected_expire_after',
+ [
+ ('https://img.site_1.com/image.jpeg', 60 * 60),
+ ('https://img.site_1.com/resource/1', 60 * 60 * 2),
+ ('https://site_2.com', 1),
+ ('https://any_other_site.com', 1),
+ ],
+)
+def test_get_url_expiration__evaluation_order(url, expected_expire_after):
+ """If there are multiple matches, the first match should be used in the order defined"""
+ urls_expire_after = {
+ '*.site_1.com/resource': 60 * 60 * 2,
+ '*.site_1.com': 60 * 60,
+ '*': 1,
+ }
+ assert get_url_expiration(url, urls_expire_after) == expected_expire_after