Source code for limits.aio.storage.redis

import time
import urllib
import urllib.parse
from typing import TYPE_CHECKING, cast

from deprecated.sphinx import versionadded
from packaging.version import Version

from limits.aio.storage.base import MovingWindowSupport, Storage
from limits.errors import ConfigurationError
from limits.typing import AsyncRedisClient, Dict, Optional, Tuple, Union
from limits.util import get_package_data

if TYPE_CHECKING:
    import coredis
    import coredis.commands


class RedisInteractor:
    """
    Redis command helpers shared by the async redis storage backends.

    All keys are namespaced with :attr:`PREFIX`. The ``lua_*`` attributes are
    only annotated here; a concrete storage must assign them (e.g. by
    registering the ``SCRIPT_*`` sources) before the script based methods
    are used.
    """

    RES_DIR = "resources/redis/lua_scripts"

    # Lua script sources, loaded from the package resources at import time.
    SCRIPT_MOVING_WINDOW = get_package_data(f"{RES_DIR}/moving_window.lua")
    SCRIPT_ACQUIRE_MOVING_WINDOW = get_package_data(
        f"{RES_DIR}/acquire_moving_window.lua"
    )
    SCRIPT_CLEAR_KEYS = get_package_data(f"{RES_DIR}/clear_keys.lua")
    SCRIPT_INCR_EXPIRE = get_package_data(f"{RES_DIR}/incr_expire.lua")

    # Registered script objects; annotated only, assigned by subclasses.
    lua_moving_window: "coredis.commands.Script[bytes]"
    lua_acquire_window: "coredis.commands.Script[bytes]"
    lua_clear_keys: "coredis.commands.Script[bytes]"
    lua_incr_expire: "coredis.commands.Script[bytes]"

    PREFIX = "LIMITS"

    def prefixed_key(self, key: str) -> str:
        """
        :param key: the raw rate limit key
        :return: ``key`` namespaced as ``"<PREFIX>:<key>"``
        """
        return f"{self.PREFIX}:{key}"

    async def _incr(
        self,
        key: str,
        expiry: int,
        connection: AsyncRedisClient,
        elastic_expiry: bool = False,
        amount: int = 1,
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param connection: Redis connection
        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is refreshed on every
         increment, otherwise it is only set when the counter equals
         ``amount`` right after the increment (i.e. the key was just created)
        :param amount: the number to increment by
        :return: the counter value after the increment
        """
        key = self.prefixed_key(key)
        value = await connection.incrby(key, amount)

        # ``value == amount`` means this increment created the key.
        if elastic_expiry or value == amount:
            await connection.expire(key, expiry)

        return value

    async def _get(self, key: str, connection: AsyncRedisClient) -> int:
        """
        :param connection: Redis connection
        :param key: the key to get the counter value for
        :return: the stored counter, or ``0`` when the key does not exist
        """

        key = self.prefixed_key(key)
        return int(await connection.get(key) or 0)

    async def _clear(self, key: str, connection: AsyncRedisClient) -> None:
        """
        :param key: the key to clear rate limits for
        :param connection: Redis connection
        """
        key = self.prefixed_key(key)
        await connection.delete([key])

    async def get_moving_window(
        self, key: str, limit: int, expiry: int
    ) -> Tuple[int, int]:
        """
        returns the starting point and the number of entries in the moving
        window

        :param key: rate limit key
        :param limit: amount of entries allowed
        :param expiry: expiry of entry
        :return: (start of window, number of acquired entries); when the
         script returns nothing the window starts now with ``0`` entries
        """
        key = self.prefixed_key(key)
        timestamp = int(time.time())
        window = await self.lua_moving_window.execute(
            [key], [timestamp - expiry, limit]
        )
        if window:
            return tuple(window)  # type: ignore
        return timestamp, 0

    async def _acquire_entry(
        self,
        key: str,
        limit: int,
        expiry: int,
        connection: AsyncRedisClient,
        amount: int = 1,
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param connection: Redis connection (unused: the Lua script object
         executes through the client it was registered with)
        :param amount: the number of entries to acquire
        :return: ``True`` if the script reported a successful acquisition
        """
        key = self.prefixed_key(key)
        timestamp = time.time()
        acquired = await self.lua_acquire_window.execute(
            [key], [timestamp, limit, expiry, amount]
        )

        return bool(acquired)

    async def _get_expiry(self, key: str, connection: AsyncRedisClient) -> int:
        """
        :param key: the key to get the expiry for
        :param connection: Redis connection
        :return: the expiry as a unix timestamp; a negative TTL (missing key
         or no expiry set) is clamped to ``0``, i.e. "now"
        """

        key = self.prefixed_key(key)
        return int(max(await connection.ttl(key), 0) + time.time())

    async def _check(self, connection: AsyncRedisClient) -> bool:
        """
        check if storage is healthy

        :param connection: Redis connection
        :return: ``True`` if ``PING`` succeeded, ``False`` on any error
        """
        try:
            await connection.ping()
        except Exception:
            # Only ordinary errors mean "unhealthy"; BaseExceptions such as
            # asyncio.CancelledError must propagate so cancellation works.
            return False
        return True


@versionadded(version="2.1")
class RedisStorage(RedisInteractor, Storage, MovingWindowSupport):
    """
    Rate limit storage with redis as backend.

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis", "async+rediss", "async+redis+unix"]
    """
    The storage schemes for redis to be used in an async context
    """
    DEPENDENCIES = {"coredis": Version("3.4.0")}

    def __init__(
        self,
        uri: str,
        connection_pool: Optional["coredis.ConnectionPool"] = None,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: uri of the form:

         - ``async+redis://[:password]@host:port``
         - ``async+redis://[:password]@host:port/db``
         - ``async+rediss://[:password]@host:port``
         - ``async+redis+unix:///path/to/sock`` etc...

         This uri is passed directly to :meth:`coredis.Redis.from_url` with
         the initial ``async`` removed, except for the case of
         ``async+redis+unix`` where it is replaced with ``unix``.
        :param connection_pool: if provided, the redis client is initialized
         with the connection pool and any other params passed as
         :paramref:`options`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.Redis`
        :raise ConfigurationError: when the redis library is not available
        """
        uri = uri.replace("async+redis", "redis", 1)
        uri = uri.replace("redis+unix", "unix")

        super().__init__(uri, **options)

        self.dependency = self.dependencies["coredis"].module

        # Explicit ``is not None``: whether a supplied pool is used must not
        # depend on the pool object's truthiness.
        if connection_pool is not None:
            self.storage = self.dependency.Redis(
                connection_pool=connection_pool, **options
            )
        else:
            self.storage = self.dependency.Redis.from_url(uri, **options)

        self.initialize_storage(uri)

    def initialize_storage(self, _uri: str) -> None:
        """
        Register the Lua scripts used by this storage on :attr:`storage`.

        :param _uri: unused
        """
        # ``register_script`` is called synchronously here; the returned
        # script objects are run later through the awaitable ``execute``.
        self.lua_moving_window = self.storage.register_script(
            self.SCRIPT_MOVING_WINDOW
        )
        self.lua_acquire_window = self.storage.register_script(
            self.SCRIPT_ACQUIRE_MOVING_WINDOW
        )
        self.lua_clear_keys = self.storage.register_script(self.SCRIPT_CLEAR_KEYS)
        self.lua_incr_expire = self.storage.register_script(
            self.SCRIPT_INCR_EXPIRE
        )

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        increments the counter for a given rate limit key

        :param key: the key to increment
        :param expiry: amount in seconds for the key to expire in
        :param elastic_expiry: if ``True`` the expiry is refreshed on every
         increment; otherwise the increment and expiry are handled by the
         ``incr_expire`` Lua script
        :param amount: the number to increment by
        """
        if elastic_expiry:
            return await super()._incr(
                key, expiry, self.storage, elastic_expiry, amount
            )
        key = self.prefixed_key(key)
        return cast(int, await self.lua_incr_expire.execute([key], [expiry, amount]))

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await super()._get(key, self.storage)

    async def clear(self, key: str) -> None:
        """
        :param key: the key to clear rate limits for
        """
        return await super()._clear(key, self.storage)

    async def acquire_entry(
        self, key: str, limit: int, expiry: int, amount: int = 1
    ) -> bool:
        """
        :param key: rate limit key to acquire an entry in
        :param limit: amount of entries allowed
        :param expiry: expiry of the entry
        :param amount: the number of entries to acquire
        """
        return await super()._acquire_entry(key, limit, expiry, self.storage, amount)

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """
        return await super()._get_expiry(key, self.storage)

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling :meth:`coredis.Redis.ping`
        """
        return await super()._check(self.storage)

    async def reset(self) -> Optional[int]:
        """
        This function calls a Lua Script to delete keys prefixed with
        `self.PREFIX` in blocks of 5000.

        .. warning:: This operation was designed to be fast, but was not
           tested on a large production based system. Be careful with its
           usage as it could be slow on very large data sets.
        """
        prefix = self.prefixed_key("*")
        return cast(int, await self.lua_clear_keys.execute([prefix]))
@versionadded(version="2.1")
class RedisClusterStorage(RedisStorage):
    """
    Rate limit storage with redis cluster as backend

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis+cluster"]
    """
    The storage schemes for redis cluster to be used in an async context
    """

    DEFAULT_OPTIONS: Dict[str, Union[float, str, bool]] = {
        "max_connections": 1000,
    }
    "Default options passed to :class:`coredis.RedisCluster`"

    def __init__(self, uri: str, **options: Union[float, str, bool]) -> None:
        """
        :param uri: url of the form
         ``async+redis+cluster://[:password]@host:port,host:port``
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.RedisCluster`
         (they override :attr:`DEFAULT_OPTIONS` and any credentials
         parsed from ``uri``)
        :raise ConfigurationError: when the coredis library is not available
        """
        parsed = urllib.parse.urlparse(uri)
        parsed_auth: Dict[str, Union[float, str, bool]] = {}

        if parsed.username:
            parsed_auth["username"] = parsed.username
        if parsed.password:
            parsed_auth["password"] = parsed.password

        # Credentials end at the *last* "@" (this is also how urllib derives
        # ``username``/``password``), so split the host list off there.
        hosts = parsed.netloc.rpartition("@")[2]
        cluster_hosts = []
        for loc in hosts.split(","):
            host, port = loc.split(":")
            cluster_hosts.append({"host": host, "port": int(port)})

        # Skip RedisStorage.__init__ (which would build a single-node client
        # from the uri) and run the remaining base initialisers.
        super(RedisStorage, self).__init__(uri, **options)

        self.dependency = self.dependencies["coredis"].module

        self.storage: "coredis.RedisCluster[str]" = self.dependency.RedisCluster(
            startup_nodes=cluster_hosts,
            **{**self.DEFAULT_OPTIONS, **parsed_auth, **options},
        )
        self.initialize_storage(uri)

    async def reset(self) -> Optional[int]:
        """
        Redis clusters are sharded and deleting across shards can't be done
        atomically. Because of this, this reset loops over all keys that are
        prefixed with `self.PREFIX` and calls delete on them, one at a time.

        .. warning:: This operation was not tested with extremely large data
           sets. On a large production based system, care should be taken
           with its usage as it could be slow on very large data sets.

        :return: the number of keys deleted
        """
        prefix = self.prefixed_key("*")
        keys = await self.storage.keys(prefix)
        count = 0

        for key in keys:
            count += await self.storage.delete([key])

        return count
@versionadded(version="2.1")
class RedisSentinelStorage(RedisStorage):
    """
    Rate limit storage with redis sentinel as backend

    Depends on :pypi:`coredis`
    """

    STORAGE_SCHEME = ["async+redis+sentinel"]
    """The storage scheme for redis accessed via a redis sentinel installation"""

    DEPENDENCIES = {"coredis.sentinel": Version("3.4.0")}

    def __init__(
        self,
        uri: str,
        service_name: Optional[str] = None,
        use_replicas: bool = True,
        sentinel_kwargs: Optional[Dict[str, Union[float, str, bool]]] = None,
        **options: Union[float, str, bool],
    ) -> None:
        """
        :param uri: url of the form
         ``async+redis+sentinel://host:port,host:port/service_name``
        :param service_name, optional: sentinel service name
         (if not provided in `uri`)
        :param use_replicas: Whether to use replicas for read only operations
        :param sentinel_kwargs, optional: kwargs to pass as
         ``sentinel_kwargs`` to :class:`coredis.sentinel.Sentinel`
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`coredis.sentinel.Sentinel`
        :raise ConfigurationError: when the coredis library is not available,
         or when no service name is given either in ``uri`` or as
         ``service_name``.
        """
        parsed = urllib.parse.urlparse(uri)
        sentinel_configuration = []
        connection_options = options.copy()
        sentinel_options = sentinel_kwargs.copy() if sentinel_kwargs else {}

        parsed_auth: Dict[str, Union[float, str, bool]] = {}
        if parsed.username:
            parsed_auth["username"] = parsed.username
        if parsed.password:
            parsed_auth["password"] = parsed.password

        # Credentials end at the *last* "@" (this is also how urllib derives
        # ``username``/``password``), so split the host list off there.
        hosts = parsed.netloc.rpartition("@")[2]
        for loc in hosts.split(","):
            host, port = loc.split(":")
            sentinel_configuration.append((host, int(port)))

        self.service_name = (
            parsed.path.replace("/", "") if parsed.path else service_name
        )

        if self.service_name is None:
            raise ConfigurationError("'service_name' not provided")

        # Skip RedisStorage.__init__ (which would build a plain client from
        # the uri) and run the remaining base initialisers with the same
        # arguments the other redis storages pass.
        super(RedisStorage, self).__init__(uri, **options)

        self.dependency = self.dependencies["coredis.sentinel"].module

        # Parsed credentials are sent both to the sentinels and to the
        # data nodes; explicit kwargs win over the parsed ones.
        self.sentinel = self.dependency.Sentinel(
            sentinel_configuration,
            sentinel_kwargs={**parsed_auth, **sentinel_options},
            **{**parsed_auth, **connection_options},
        )
        self.storage = self.sentinel.primary_for(self.service_name)
        self.storage_replica = self.sentinel.replica_for(self.service_name)
        self.use_replicas = use_replicas
        self.initialize_storage(uri)

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        """
        return await super()._get(
            key, self.storage_replica if self.use_replicas else self.storage
        )

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the expiry for
        """
        return await super()._get_expiry(
            key, self.storage_replica if self.use_replicas else self.storage
        )

    async def check(self) -> bool:
        """
        Check if storage is healthy by calling :meth:`coredis.Redis.ping`
        on the replica when ``use_replicas`` is set, otherwise on the primary.
        """
        return await super()._check(
            self.storage_replica if self.use_replicas else self.storage
        )