# Source code for limits.storage.memcached

import inspect
import threading
import time
import urllib.parse
from types import ModuleType
from typing import cast

from limits.errors import ConfigurationError
from limits.storage.base import Storage
from limits.typing import (
    Callable,
    List,
    MemcachedClientP,
    Optional,
    P,
    R,
    Tuple,
    Type,
    Union,
)
from limits.util import get_dependency


class MemcachedStorage(Storage):
    """
    Rate limit storage with memcached as backend.

    Depends on :pypi:`pymemcache`.
    """

    STORAGE_SCHEME = ["memcached"]
    """The storage scheme for memcached"""
    DEPENDENCIES = ["pymemcache"]

    def __init__(
        self,
        uri: str,
        wrap_exceptions: bool = False,
        **options: Union[str, Callable[[], MemcachedClientP]],
    ) -> None:
        """
        :param uri: memcached location of the form
         ``memcached://host:port,host:port``,
         ``memcached:///var/tmp/path/to/sock``
        :param wrap_exceptions: Whether to wrap storage exceptions in
         :exc:`limits.errors.StorageError` before raising it.
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`pymemcache.client.base.PooledClient`
         or :class:`pymemcache.client.hash.HashClient` (if there are more than
         one hosts specified). The keys ``library``, ``cluster_library`` and
         ``client_getter`` are consumed here and are not forwarded.
        :raise ConfigurationError: when :pypi:`pymemcache` is not available
        """
        parsed = urllib.parse.urlparse(uri)
        self.hosts = []

        # Every non-empty comma separated entry of the netloc is a
        # ``host:port`` pair.
        for loc in parsed.netloc.strip().split(","):
            if not loc:
                continue
            host, port = loc.split(":")
            self.hosts.append((host, int(port)))

        # filesystem path to UDS: only a path and no network location given
        if parsed.path and not parsed.netloc and not parsed.port:
            self.hosts = [parsed.path]  # type: ignore

        self.dependency = self.dependencies["pymemcache"].module
        self.library = str(options.pop("library", "pymemcache.client"))
        self.cluster_library = str(
            options.pop("cluster_library", "pymemcache.client.hash")
        )
        self.client_getter = cast(
            Callable[[ModuleType, List[Tuple[str, int]]], MemcachedClientP],
            options.pop("client_getter", self.get_client),
        )
        # Whatever is left is handed to the client getter when the client
        # is lazily created (see ``storage``).
        self.options = options

        # ``get_dependency`` is indexed with ``[0]`` wherever its result is
        # used in this class, so the returned value itself is always truthy;
        # the first element is what tells whether the import succeeded.
        if not get_dependency(self.library)[0]:
            raise ConfigurationError(
                "memcached prerequisite not available."
                " please install %s" % self.library
            )  # pragma: no cover

        # One client per thread, created on first access of ``storage``.
        self.local_storage = threading.local()
        self.local_storage.storage = None
        super().__init__(uri, wrap_exceptions=wrap_exceptions)

    @property
    def base_exceptions(
        self,
    ) -> Union[Type[Exception], Tuple[Type[Exception], ...]]:  # pragma: no cover
        """
        The ``MemcacheError`` exception type exposed by the resolved
        :pypi:`pymemcache` module.
        """
        return self.dependency.MemcacheError  # type: ignore[no-any-return]

    def get_client(
        self, module: ModuleType, hosts: List[Tuple[str, int]], **kwargs: str
    ) -> MemcachedClientP:
        """
        returns a memcached client.

        :param module: the memcached module
        :param hosts: list of memcached hosts
        :param kwargs: forwarded unchanged to the client constructor
        """
        # More than one host -> hashing client over all hosts, otherwise a
        # pooled client for the single host.
        if len(hosts) > 1:
            client = module.HashClient(hosts, **kwargs)
        else:
            client = module.PooledClient(*hosts, **kwargs)

        return cast(MemcachedClientP, client)

    def call_memcached_func(
        self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> R:
        """
        Call ``func`` with the given arguments, dropping the ``noreply``
        keyword argument when ``func`` does not accept it.

        :param func: the client function to call
        :param args: positional arguments forwarded to ``func``
        :param kwargs: keyword arguments forwarded to ``func``
        :return: whatever ``func`` returns
        """
        if "noreply" in kwargs:
            argspec = inspect.getfullargspec(func)
            # ``noreply`` may be declared positionally, as keyword-only, or
            # be absorbed by a ``**kwargs`` catch-all.
            accepts_noreply = (
                "noreply" in argspec.args
                or "noreply" in argspec.kwonlyargs
                or argspec.varkw
            )

            if not accepts_noreply:
                kwargs.pop("noreply")

        return func(*args, **kwargs)

    @property
    def storage(self) -> MemcachedClientP:
        """
        lazily creates a memcached client instance using a thread local

        :raise ConfigurationError: when the client library can not be imported
        """
        if not (hasattr(self.local_storage, "storage") and self.local_storage.storage):
            # The cluster library is only needed when several hosts are
            # configured.
            library_name = (
                self.cluster_library if len(self.hosts) > 1 else self.library
            )
            dependency = get_dependency(library_name)[0]

            if not dependency:
                # Report the library that was actually attempted.
                raise ConfigurationError(f"Unable to import {library_name}")
            self.local_storage.storage = self.client_getter(
                dependency, self.hosts, **self.options
            )

        return cast(MemcachedClientP, self.local_storage.storage)

    def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the stored counter as an integer, ``0`` when the lookup
         yields a falsy value
        """
        return int(self.storage.get(key) or 0)

    def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        self.storage.delete(key)

    def _store_expiry(self, key: str, expiry: int) -> None:
        """
        Record the absolute expiry timestamp of ``key`` under
        ``key + "/expires"``, with the same time to live as the counter.

        :param key: the rate limit key
        :param expiry: amount in seconds for the key to expire in
        """
        self.call_memcached_func(
            self.storage.set,
            key + "/expires",
            expiry + time.time(),
            expire=expiry,
            noreply=False,
        )

    def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: whether to keep extending the rate limit
         window every hit.
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        # ``add`` is tried first: a truthy result means the counter was
        # created with ``amount`` as its initial value.
        created = self.call_memcached_func(
            self.storage.add, key, amount, expiry, noreply=False
        )

        if created:
            self._store_expiry(key, expiry)

            return amount

        # Fall back to ``amount`` when ``incr`` yields a falsy result.
        value = self.storage.incr(key, amount) or amount

        if elastic_expiry:
            # Extend the window of the counter and of its expiry marker.
            self.call_memcached_func(self.storage.touch, key, expiry)
            self._store_expiry(key, expiry)

        return value

    def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        :return: the timestamp stored under ``key + "/expires"`` truncated
         to an integer, or the current time when nothing is stored
        """
        return int(float(self.storage.get(key + "/expires") or time.time()))

    def check(self) -> bool:
        """
        Check if storage is healthy by calling the ``get`` command
        on the key ``limiter-check``

        :return: ``True`` when the call succeeds, ``False`` when it raises
        """
        try:
            self.call_memcached_func(self.storage.get, "limiter-check")

            return True
        except Exception:
            # Any failure means "unhealthy". ``KeyboardInterrupt`` and
            # ``SystemExit`` are deliberately left to propagate.
            return False

    def reset(self) -> Optional[int]:
        """
        Not implemented for this storage.

        :raise NotImplementedError: always
        """
        raise NotImplementedError