"""
Rate limiting strategies
"""
from abc import ABCMeta, abstractmethod
from typing import Dict, Type, Union, cast
from .limits import RateLimitItem
from .storage import MovingWindowSupport, Storage, StorageTypes
from .util import WindowStats
[docs]
class RateLimiter(metaclass=ABCMeta):
def __init__(self, storage: StorageTypes):
assert isinstance(storage, Storage)
self.storage: Storage = storage
[docs]
@abstractmethod
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
raise NotImplementedError
[docs]
@abstractmethod
def test(self, item: RateLimitItem, *identifiers: str) -> bool:
"""
Check the rate limit without consuming from it.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
"""
raise NotImplementedError
[docs]
@abstractmethod
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
raise NotImplementedError
def clear(self, item: RateLimitItem, *identifiers: str) -> None:
return self.storage.clear(item.key_for(*identifiers))
[docs]
class MovingWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:moving window`
"""
def __init__(self, storage: StorageTypes):
if not (
hasattr(storage, "acquire_entry") or hasattr(storage, "get_moving_window")
):
raise NotImplementedError(
"MovingWindowRateLimiting is not implemented for storage "
"of type %s" % storage.__class__
)
super().__init__(storage)
[docs]
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
:return: (reset time, remaining)
"""
return cast(MovingWindowSupport, self.storage).acquire_entry(
item.key_for(*identifiers), item.amount, item.get_expiry(), amount=cost
)
[docs]
def test(self, item: RateLimitItem, *identifiers: str) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
"""
return (
cast(MovingWindowSupport, self.storage).get_moving_window(
item.key_for(*identifiers),
item.amount,
item.get_expiry(),
)[1]
< item.amount
)
[docs]
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
returns the number of requests remaining within this limit.
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: tuple (reset time, remaining)
"""
window_start, window_items = cast(
MovingWindowSupport, self.storage
).get_moving_window(item.key_for(*identifiers), item.amount, item.get_expiry())
reset = window_start + item.get_expiry()
return WindowStats(reset, item.amount - window_items)
[docs]
class FixedWindowRateLimiter(RateLimiter):
"""
Reference: :ref:`strategies:fixed window`
"""
[docs]
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return (
self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
elastic_expiry=False,
amount=cost,
)
<= item.amount
)
[docs]
def test(self, item: RateLimitItem, *identifiers: str) -> bool:
"""
Check if the rate limit can be consumed
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
"""
return self.storage.get(item.key_for(*identifiers)) < item.amount
[docs]
def get_window_stats(self, item: RateLimitItem, *identifiers: str) -> WindowStats:
"""
Query the reset time and remaining amount for the limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:return: (reset time, remaining)
"""
remaining = max(0, item.amount - self.storage.get(item.key_for(*identifiers)))
reset = self.storage.get_expiry(item.key_for(*identifiers))
return WindowStats(reset, remaining)
[docs]
class FixedWindowElasticExpiryRateLimiter(FixedWindowRateLimiter):
"""
Reference: :ref:`strategies:fixed window with elastic expiry`
"""
[docs]
def hit(self, item: RateLimitItem, *identifiers: str, cost: int = 1) -> bool:
"""
Consume the rate limit
:param item: The rate limit item
:param identifiers: variable list of strings to uniquely identify this
instance of the limit
:param cost: The cost of this hit, default 1
"""
return (
self.storage.incr(
item.key_for(*identifiers),
item.get_expiry(),
elastic_expiry=True,
amount=cost,
)
<= item.amount
)
KnownStrategy = Union[
Type[FixedWindowRateLimiter],
Type[FixedWindowElasticExpiryRateLimiter],
Type[MovingWindowRateLimiter],
]
STRATEGIES: Dict[str, KnownStrategy] = {
"fixed-window": FixedWindowRateLimiter,
"fixed-window-elastic-expiry": FixedWindowElasticExpiryRateLimiter,
"moving-window": MovingWindowRateLimiter,
}