"""Asynchronous rate limiter for API requests."""
import asyncio
import time
from collections import deque
from typing import Any, Literal
from pybdl.api.exceptions import RateLimitDelayExceeded, RateLimitError
from pybdl.api.utils.quota_cache import PersistentQuotaCache
# [docs]  (Sphinx viewcode export artifact; commented out so the module parses)
class AsyncRateLimiter:
"""
Asyncio-compatible rate limiter for API requests.
Enforces multiple quota periods and persists usage if a cache is provided.
"""
def __init__(
self,
quotas: dict[int, int | tuple],
is_registered: bool,
cache: PersistentQuotaCache | None = None,
max_delay: float | None = None,
raise_on_limit: bool = True,
buffer_seconds: float = 0.05,
) -> None:
"""
Initialize the async rate limiter.
Args:
quotas: Dictionary of {period_seconds: limit or (anon_limit, reg_limit)}.
is_registered: Whether the user is registered (affects quota).
cache: Optional persistent cache for quota usage.
max_delay: Maximum delay in seconds to wait. None = wait forever, 0 = raise immediately.
raise_on_limit: If True, raise exception immediately when limit exceeded. If False, wait.
buffer_seconds: Small buffer time to add to wait calculations (default: 0.05s).
"""
self.quotas = quotas
self.is_registered = is_registered
self.lock = asyncio.Lock() # Single lock for all periods
self.calls: dict[int, deque[float]] = {period: deque() for period in quotas}
self.cache = cache
# Use unified cache key so sync and async share quota state
self.cache_key = f"{'reg' if is_registered else 'anon'}"
self.max_delay = max_delay
self.raise_on_limit = raise_on_limit
self.buffer_seconds = buffer_seconds
if self.cache and self.cache.enabled:
self._load_from_cache()
def _get_limit(self, period: int) -> int:
# quotas: {period: tuple of (anonymous_limit, registered_limit) or int}
limit_value = self.quotas[period]
if isinstance(limit_value, tuple):
return limit_value[1] if self.is_registered else limit_value[0]
return limit_value
def _load_from_cache(self, merge: bool = True) -> None:
"""
Load quota state from cache.
Args:
merge: If True, merge cached calls with current calls. If False, replace current with cached.
"""
if self.cache is not None:
for period in self.quotas:
cached = self.cache.get(f"{self.cache_key}_{period}")
if merge:
# Merge cached calls with current calls, keeping all unique timestamps
current_times = set(self.calls[period])
cached_times = set(cached)
merged = sorted(current_times | cached_times)
self.calls[period] = deque(merged)
else:
# Replace current calls with cached calls (for atomic check-before-record)
self.calls[period] = deque(cached)
def _save_to_cache(self) -> None:
if not self.cache or not self.cache.enabled:
return
for period in self.quotas:
self.cache.set(f"{self.cache_key}_{period}", list(self.calls[period]))
def _get_limit_info(self) -> dict[str, Any]:
"""Get current rate limit information."""
return {
"quotas": [
{
"period": period,
"limit": self._get_limit(period),
"current": len(self.calls[period]),
}
for period in self.quotas
],
"is_registered": self.is_registered,
}
# [docs]  (Sphinx viewcode export artifact; commented out so the module parses)
def get_remaining_quota(self) -> dict[int, int]:
"""Get remaining quota for each period."""
now = time.monotonic()
remaining = {}
# Note: This is a sync method, so we can't use async lock here
# For async usage, use await get_remaining_quota_async()
for period in self.quotas:
q = self.calls[period]
limit = self._get_limit(period)
# Clean up old calls
while q and q[0] <= now - period:
q.popleft()
remaining[period] = max(0, limit - len(q))
return remaining
# [docs]  (Sphinx viewcode export artifact; commented out so the module parses)
async def get_remaining_quota_async(self) -> dict[int, int]:
"""Get remaining quota for each period (async version)."""
now = time.monotonic()
remaining = {}
async with self.lock:
for period in self.quotas:
q = self.calls[period]
limit = self._get_limit(period)
# Clean up old calls
while q and q[0] <= now - period:
q.popleft()
remaining[period] = max(0, limit - len(q))
return remaining
# [docs]  (Sphinx viewcode export artifact; commented out so the module parses)
def reset(self) -> None:
"""Reset all quota counters."""
for period in self.quotas:
self.calls[period].clear()
self._save_to_cache()
# [docs]  (Sphinx viewcode export artifact; commented out so the module parses)
async def reset_async(self) -> None:
"""Reset all quota counters (async version)."""
async with self.lock:
for period in self.quotas:
self.calls[period].clear()
self._save_to_cache()
# [docs]  (Sphinx viewcode export artifact; commented out so the module parses)
    async def acquire(self) -> None:
        """
        Acquire a slot for an API request asynchronously.

        If rate limit is exceeded:
        - If raise_on_limit=True: Raises RateLimitError immediately
        - If raise_on_limit=False: Waits until quota available
        - If max_delay is set and wait_time > max_delay: Raises RateLimitDelayExceeded

        Raises:
            RateLimitError: If the rate limit is exceeded and raise_on_limit=True.
            RateLimitDelayExceeded: If required delay exceeds max_delay.
        """
        now = time.monotonic()
        # Phase 1: under the lock, compute the longest wait required across
        # all quota periods (check if we need to wait).
        max_wait = 0.0
        async with self.lock:
            # Reload from cache to get updates from other limiters (sync/async)
            if self.cache and self.cache.enabled:
                self._load_from_cache()
            for period in self.quotas:
                q = self.calls[period]
                limit = self._get_limit(period)
                # Expire timestamps that fell out of this period's window.
                while q and q[0] <= now - period:
                    q.popleft()
                if len(q) >= limit:
                    # Oldest call q[0] frees a slot once it ages past `period`.
                    wait_time = period - (now - q[0])
                    max_wait = max(max_wait, wait_time)
        if max_wait > 0:
            # Add buffer to be safe
            max_wait += self.buffer_seconds
            if self.raise_on_limit:
                self._save_to_cache()
                raise RateLimitError(
                    retry_after=max_wait,
                    limit_info=self._get_limit_info(),
                )
            if self.max_delay is not None and max_wait > self.max_delay:
                self._save_to_cache()
                raise RateLimitDelayExceeded(
                    actual_delay=max_wait,
                    max_delay=self.max_delay,
                    limit_info=self._get_limit_info(),
                )
            await asyncio.sleep(max_wait)  # Non-blocking sleep
        # Phase 2: record the call.
        async with self.lock:
            now = time.monotonic()  # Refresh after potential sleep
            # Use atomic cache operation to prevent race conditions
            if self.cache and self.cache.enabled:
                # Try to atomically record the call in cache for each period.
                # This ensures only one limiter can successfully record at a time.
                for period in self.quotas:
                    limit = self._get_limit(period)
                    cache_key = f"{self.cache_key}_{period}"
                    cleanup_before = now - period  # Remove calls older than this
                    # Try atomic append - this checks limits and records atomically
                    success = self.cache.try_append_if_under_limit(cache_key, now, limit, cleanup_before)
                    if not success:
                        # Failed to record - we're at the limit.
                        # Reload to get current state for error message.
                        self._load_from_cache(merge=False)
                        q = self.calls[period]
                        while q and q[0] <= now - period:
                            q.popleft()
                        # Wait until the oldest recorded call ages out; if the
                        # window looks empty (stale reload), fall back to a full period.
                        wait_time = (
                            (period - (now - q[0]) + self.buffer_seconds) if q and q[0] > now - period else period
                        )
                        if self.raise_on_limit:
                            raise RateLimitError(
                                retry_after=wait_time,
                                limit_info=self._get_limit_info(),
                            )
                        if self.max_delay is not None and wait_time > self.max_delay:
                            raise RateLimitDelayExceeded(
                                actual_delay=wait_time,
                                max_delay=self.max_delay,
                                limit_info=self._get_limit_info(),
                            )
                        # NOTE(review): this sleep runs while self.lock is still
                        # held, so every other acquire() is blocked for the whole
                        # wait -- confirm that serialization is intended.
                        await asyncio.sleep(wait_time)
                        now = time.monotonic()  # Refresh after sleep
                        cleanup_before = now - period
                        # Retry after waiting
                        success = self.cache.try_append_if_under_limit(cache_key, now, limit, cleanup_before)
                        if not success:
                            # Still at limit after waiting - reload and raise
                            self._load_from_cache(merge=False)
                            q = self.calls[period]
                            while q and q[0] <= now - period:
                                q.popleft()
                            wait_time = (
                                (period - (now - q[0]) + self.buffer_seconds) if q and q[0] > now - period else period
                            )
                            if self.raise_on_limit:
                                raise RateLimitError(
                                    retry_after=wait_time,
                                    limit_info=self._get_limit_info(),
                                )
                            # NOTE(review): when raise_on_limit is False and the
                            # retry also fails, execution falls through without
                            # recording this period -- confirm that is acceptable.
                # Reload from cache to sync local state with what was actually recorded
                self._load_from_cache(merge=False)
            else:
                # No cache - just record locally
                for period in self.quotas:
                    self.calls[period].append(now)
async def __aenter__(self) -> "AsyncRateLimiter":
"""Async context manager entry."""
await self.acquire()
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Literal[False]:
"""Async context manager exit."""
return False