Source code for limits.storage.memcached

from __future__ import annotations

import inspect
import threading
import time
from collections.abc import Iterable
from math import ceil, floor
from types import ModuleType

from limits._storage_scheme import parse_storage_uri
from limits.errors import ConfigurationError
from limits.storage.base import (
    SlidingWindowCounterSupport,
    Storage,
    TimestampedSlidingWindow,
)
from limits.typing import (
    Any,
    Callable,
    MemcachedClientP,
    P,
    R,
    cast,
)
from limits.util import get_dependency


class MemcachedStorage(Storage, SlidingWindowCounterSupport, TimestampedSlidingWindow):
    """
    Rate limit storage with memcached as backend.

    Depends on :pypi:`pymemcache`.
    """

    STORAGE_SCHEME = ["memcached"]
    """The storage scheme for memcached"""
    DEPENDENCIES = ["pymemcache"]

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        **options: str | Callable[[], MemcachedClientP],
    ) -> None:
        """
        :param uri: memcached location of the form
         ``memcached://host:port,host:port``,
         ``memcached:///var/tmp/path/to/sock``
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of
         :class:`pymemcache.client.base.PooledClient` or
         :class:`pymemcache.client.hash.HashClient` (if there are more
         than one hosts specified). The keys ``library``,
         ``cluster_library`` and ``client_getter`` are consumed here and
         are not forwarded to the client.
        :raise ConfigurationError: when :pypi:`pymemcache` is not available
        """
        storage_uri_options = parse_storage_uri(uri)
        self.hosts: list[tuple[str, int]] | list[str]

        # A path in the URI denotes a unix socket; otherwise use host/port pairs.
        if storage_uri_options.path:
            self.hosts = [storage_uri_options.path]
        else:
            self.hosts = storage_uri_options.locations

        self.dependency = self.dependencies["pymemcache"].module
        self.library = str(options.pop("library", "pymemcache.client"))
        self.cluster_library = str(
            options.pop("cluster_library", "pymemcache.client.hash")
        )
        self.client_getter = cast(
            Callable[[ModuleType, list[tuple[str, int]] | list[str]], MemcachedClientP],
            options.pop("client_getter", self.get_client),
        )
        self.options = options

        # The ``storage`` property reads the module from index 0 of the
        # ``get_dependency`` result, so test that element here as well instead
        # of the truthiness of the container itself.
        # NOTE(review): assumes a failed import yields a falsy first element
        # (or a falsy return value) - confirm against ``get_dependency``.
        dependency = get_dependency(self.library)
        if not (dependency and dependency[0]):
            raise ConfigurationError(
                f"memcached prerequisite not available. please install {self.library}"
            )  # pragma: no cover

        # One client per thread; created lazily by the ``storage`` property.
        self.local_storage = threading.local()
        self.local_storage.storage = None
        super().__init__(uri, wrap_exceptions=wrap_exceptions)

    @property
    def base_exceptions(
        self,
    ) -> type[Exception] | tuple[type[Exception], ...]:  # pragma: no cover
        """The ``MemcacheError`` exception type exposed by the dependency module."""
        return self.dependency.MemcacheError  # type: ignore[no-any-return]

    def get_client(
        self,
        module: ModuleType,
        hosts: list[tuple[str, int]] | list[str],
        **kwargs: str,
    ) -> MemcachedClientP:
        """
        returns a memcached client.

        :param module: the memcached module
        :param hosts: list of memcached hosts
        :param kwargs: extra keyword arguments forwarded to the client
         constructor
        """
        return cast(
            MemcachedClientP,
            (
                module.HashClient(hosts, **kwargs)
                if len(hosts) > 1
                else module.PooledClient(*hosts, **kwargs)
            ),
        )

    def call_memcached_func(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """
        Call ``func``, dropping the ``noreply`` keyword argument when the
        signature of ``func`` cannot accept it.

        :param func: the client callable to invoke
        :param args: positional arguments forwarded to ``func``
        :param kwargs: keyword arguments forwarded to ``func``
        """
        if "noreply" in kwargs:
            argspec = inspect.getfullargspec(func)
            accepts_noreply = (
                "noreply" in argspec.args
                or "noreply" in argspec.kwonlyargs
                or bool(argspec.varkw)
            )
            if not accepts_noreply:
                kwargs.pop("noreply")

        return func(*args, **kwargs)

    @property
    def storage(self) -> MemcachedClientP:
        """
        lazily creates a memcached client instance using a thread local

        :raise ConfigurationError: when the client library cannot be imported
        """
        # Threads other than the constructing one have no ``storage``
        # attribute on the thread local yet, hence the ``getattr`` default.
        if not getattr(self.local_storage, "storage", None):
            library = self.cluster_library if len(self.hosts) > 1 else self.library
            dependency = get_dependency(library)[0]

            if not dependency:
                # Report the library that was actually requested.
                raise ConfigurationError(f"Unable to import {library}")
            self.local_storage.storage = self.client_getter(
                dependency, self.hosts, **self.options
            )

        return cast(MemcachedClientP, self.local_storage.storage)

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the counter value, ``0`` when the key does not exist
        """
        return int(self.storage.get(key, "0"))

    def get_many(self, keys: Iterable[str]) -> dict[str, Any]:  # type:ignore[explicit-any]
        """
        Return multiple counters at once

        :param keys: the keys to get the counter values for

        :meta private:
        """
        return self.storage.get_many(keys)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        self.storage.delete(key)

    def incr(
        self,
        key: str,
        expiry: float,
        amount: int = 1,
        set_expiration_key: bool = True,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param amount: the number to increment by
        :param set_expiration_key: set the expiration key with the expiration
         time if needed. If set to False, the key will still expire, but
         memcached cannot provide the expiration time.
        :return: the counter value after the increment
        """
        value = self.call_memcached_func(
            self.storage.incr, key, amount, noreply=False
        )
        if value is not None:
            return value

        # ``incr`` returned nothing: try to create the key with the initial amount.
        created = self.call_memcached_func(
            self.storage.add, key, amount, ceil(expiry), noreply=False
        )
        if not created:
            # ``add`` did not store the value (presumably because another
            # client created the key in between), so increment again.
            return self.storage.incr(key, amount) or amount

        if set_expiration_key:
            # Memcached cannot report a TTL, so the absolute expiration
            # time is stored under a companion key.
            self.call_memcached_func(
                self.storage.set,
                self._expiration_key(key),
                expiry + time.time(),
                expire=ceil(expiry),
                noreply=False,
            )

        return amount

    def get_expiry(self, key: str) -> float:
        """
        :param key: the key to get the expiry for
        :return: the stored expiration timestamp, or the current time when
         no expiration key is stored
        """
        return float(self.storage.get(self._expiration_key(key)) or time.time())

    def _expiration_key(self, key: str) -> str:
        """
        Return the expiration key for the given counter key.

        Memcached doesn't natively return the expiration time or TTL for a
        given key, so we implement the expiration time on a separate key.
        """
        return key + "/expires"

    def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``
        """
        try:
            self.call_memcached_func(self.storage.get, "limiter-check")
        except Exception:
            # Best-effort health probe: any failure means "unhealthy", but
            # KeyboardInterrupt / SystemExit are allowed to propagate.
            return False
        return True

    def reset(self) -> int | None:
        """
        Not supported by this storage.

        :raise NotImplementedError: always
        """
        raise NotImplementedError

    def acquire_sliding_window_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        amount: int = 1,
    ) -> bool:
        """
        Try to record ``amount`` hits against the sliding window of ``key``.

        :param key: the rate limit key
        :param limit: the maximum weighted count allowed in the window
        :param expiry: the window length in seconds
        :param amount: the number of hits to acquire
        :return: True when the hit was recorded, False when it was refused
        """
        if amount > limit:
            return False

        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        previous_count, previous_ttl, current_count, _ = self._get_sliding_window_info(
            previous_key, current_key, expiry, now=now
        )
        weighted_count = previous_count * previous_ttl / expiry + current_count

        if floor(weighted_count) + amount > limit:
            return False

        # Hit, increase the current counter.
        # If the counter doesn't exist yet, set twice the theoretical expiry.
        # We don't need the expiration key as it is estimated with the
        # timestamps directly.
        current_count = self.incr(
            current_key, 2 * expiry, amount=amount, set_expiration_key=False
        )

        # Account for the time spent talking to memcached. The remaining TTL
        # of the previous window can shrink to zero but must never become
        # negative, hence ``max`` (``min`` would zero out or negate the
        # weight of the previous window).
        actualised_previous_ttl = max(0, previous_ttl - (time.time() - now))
        weighted_count = (
            previous_count * actualised_previous_ttl / expiry + current_count
        )

        if floor(weighted_count) > limit:
            # Another hit won the race condition: revert the incrementation
            # and refuse this hit.
            # Limitation: during high concurrency at the end of the window,
            # the counter is shifted and cannot be decremented, so less
            # requests than expected are allowed.
            self.call_memcached_func(
                self.storage.decr,
                current_key,
                amount,
                noreply=True,
            )
            return False

        return True

    def get_sliding_window(
        self, key: str, expiry: int
    ) -> tuple[int, float, int, float]:
        """
        Return the sliding window state for ``key`` at the current time.

        :param key: the rate limit key
        :param expiry: the window length in seconds
        :return: ``(previous_count, previous_ttl, current_count, current_ttl)``
        """
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        return self._get_sliding_window_info(previous_key, current_key, expiry, now)

    def clear_sliding_window(self, key: str, expiry: int) -> None:
        """
        Delete the previous and current window counters of ``key``.

        :param key: the rate limit key
        :param expiry: the window length in seconds
        """
        now = time.time()
        previous_key, current_key = self.sliding_window_keys(key, expiry, now)
        self.clear(previous_key)
        self.clear(current_key)

    def _get_sliding_window_info(
        self, previous_key: str, current_key: str, expiry: int, now: float
    ) -> tuple[int, float, int, float]:
        """
        Fetch both window counters in one round trip and derive their TTLs
        from ``now``.

        :param previous_key: the counter key of the previous window
        :param current_key: the counter key of the current window
        :param expiry: the window length in seconds
        :param now: the reference timestamp
        :return: ``(previous_count, previous_ttl, current_count, current_ttl)``
        """
        result = self.get_many([previous_key, current_key])
        previous_count = int(result.get(previous_key, 0))
        current_count = int(result.get(current_key, 0))

        if previous_count == 0:
            # Nothing recorded in the previous window: it carries no weight.
            previous_ttl = float(0)
        else:
            previous_ttl = (1 - (((now - expiry) / expiry) % 1)) * expiry

        # The extra ``expiry`` added to the remainder of the current window
        # appears to cover the period in which it serves as the previous
        # window - TODO confirm.
        current_ttl = (1 - ((now / expiry) % 1)) * expiry + expiry

        return previous_count, previous_ttl, current_count, current_ttl