import asyncio
import time
import urllib.parse
from typing import TYPE_CHECKING, Optional

from limits.aio.storage.base import Storage
from limits.errors import ConcurrentUpdateError

if TYPE_CHECKING:
    import aetcd


class EtcdStorage(Storage):
    """
    Rate limit storage with etcd as backend.

    Depends on :pypi:`aetcd`.

    Each counter is stored under ``<PREFIX>/<key>`` with the value
    ``b"<count>:<window_end>"``, where ``window_end`` is the unix timestamp
    (as a float) at which the current window closes. The key is written with
    an etcd lease that is granted with a TTL of ``expiry`` seconds.
    """

    STORAGE_SCHEME = ["async+etcd"]
    """The async storage scheme for etcd"""
    DEPENDENCIES = ["aetcd"]
    PREFIX = "limits"
    MAX_RETRIES = 5

    def __init__(
        self,
        uri: str,
        max_retries: int = MAX_RETRIES,
        **options: str,
    ) -> None:
        """
        :param uri: etcd location of the form
         ``async+etcd://host:port``. If host or port is omitted, the
         :class:`aetcd.client.Client` default is used instead of passing ``None``.
        :param max_retries: Maximum number of attempts to retry
         in the case of concurrent updates to a rate limit key
        :param options: all remaining keyword arguments are passed
         directly to the constructor of :class:`aetcd.client.Client`
        :raise ConfigurationError: when :pypi:`aetcd` is not available
        """
        parsed = urllib.parse.urlparse(uri)
        self.lib = self.dependencies["aetcd"].module
        # Only forward host/port that the URI actually specifies, so a URI
        # without them does not hand ``None`` to the client constructor.
        client_kwargs: dict = {}
        if parsed.hostname is not None:
            client_kwargs["host"] = parsed.hostname
        if parsed.port is not None:
            client_kwargs["port"] = parsed.port
        self.storage: "aetcd.Client" = self.lib.Client(**client_kwargs, **options)
        self.max_retries = max_retries

    def prefixed_key(self, key: str) -> bytes:
        """
        :param key: rate limit key
        :return: ``key`` namespaced under :attr:`PREFIX`, encoded as bytes
        """
        return f"{self.PREFIX}/{key}".encode()

    async def incr(
        self, key: str, expiry: int, elastic_expiry: bool = False, amount: int = 1
    ) -> int:
        """
        Increment the counter for ``key`` by ``amount``, starting a new window
        if the key does not exist or its window has already ended.

        :param key: rate limit key (without the storage prefix)
        :param expiry: window length in seconds; also the TTL of the lease
         granted for a newly created key
        :param elastic_expiry: if ``True``, every increment refreshes the
         key's lease and moves the stored window end to ``now + expiry``
        :param amount: value to add to the counter
        :return: the counter value after the increment
        :raise ConcurrentUpdateError: if the increment could not be applied
         within :attr:`max_retries` attempts because of concurrent writers
        """
        retries = 0
        etcd_key = self.prefixed_key(key)
        while retries < self.max_retries:
            now = time.time()
            lease = await self.storage.lease(expiry)
            window_end = now + expiry
            # Create the key only if it does not exist yet (create revision 0);
            # otherwise the failure branch reads the current key-value.
            create_attempt = await self.storage.transaction(
                compare=[self.storage.transactions.create(etcd_key) == b"0"],
                success=[
                    self.storage.transactions.put(
                        etcd_key,
                        f"{amount}:{window_end}".encode(),
                        lease=lease.id,
                    )
                ],
                failure=[self.storage.transactions.get(etcd_key)],
            )
            if create_attempt[0]:
                return amount

            # The key already existed, so the lease granted above was never
            # attached to anything. Release it instead of leaving it in etcd
            # until its TTL runs out (this path is taken on most increments).
            await self.storage.revoke_lease(lease.id)

            # Result of the failure-branch get. Presumably [0][1] is the key's
            # KeyValue exposing ``.value`` and ``.lease``; verify against aetcd.
            cur = create_attempt[1][0][0][1]
            cur_value, raw_window_end = cur.value.split(b":")
            window_end = float(raw_window_end)
            if window_end <= now:
                # Stale window: drop the key and its lease. The next loop
                # iteration will then create a fresh window.
                await asyncio.gather(
                    self.storage.revoke_lease(cur.lease),
                    self.storage.delete(etcd_key),
                )
            else:
                if elastic_expiry:
                    await self.storage.refresh_lease(cur.lease)
                    window_end = now + expiry
                new = int(cur_value) + amount
                # Compare-and-swap: write only if the stored value is still the
                # one read above; otherwise another writer won and we retry.
                update_attempt = await self.storage.transaction(
                    compare=[
                        self.storage.transactions.value(etcd_key) == cur.value
                    ],
                    success=[
                        self.storage.transactions.put(
                            etcd_key,
                            f"{new}:{window_end}".encode(),
                            lease=cur.lease,
                        )
                    ],
                    failure=[],
                )
                if update_attempt[0]:
                    return new
            retries += 1
        raise ConcurrentUpdateError(key, retries)

    async def get(self, key: str) -> int:
        """
        :param key: the key to get the counter value for
        :return: the stored counter value, or ``0`` if the key is missing or
         its stored window end is not in the future
        """
        cur = await self.storage.get(self.prefixed_key(key))
        if cur:
            amount, expiry = cur.value.split(b":")
            if float(expiry) > time.time():
                return int(amount)
        return 0

    async def get_expiry(self, key: str) -> int:
        """
        :param key: the key to get the window end for
        :return: the stored window end truncated to whole seconds (even if it
         is already in the past), or the current time if the key is missing
        """
        cur = await self.storage.get(self.prefixed_key(key))
        if cur:
            window_end = float(cur.value.split(b":")[1])
            return int(window_end)
        return int(time.time())

    async def check(self) -> bool:
        """
        Check whether etcd is reachable.

        :return: ``True`` if a status request succeeds, ``False`` if it raises
        """
        try:
            await self.storage.status()
            return True
        except Exception:  # deliberately broad: any failure means "unhealthy"
            # Exception (not a bare except) so that task cancellation and
            # KeyboardInterrupt still propagate.
            return False

    async def reset(self) -> Optional[int]:
        """
        Delete every key stored under :attr:`PREFIX`.

        :return: the number of deleted keys, as reported by etcd
        """
        return (await self.storage.delete_prefix(f"{self.PREFIX}/".encode())).deleted

    async def clear(self, key: str) -> None:
        """
        :param key: the key whose counter should be deleted
        """
        await self.storage.delete(self.prefixed_key(key))