Skip to content

Commit

Permalink
Merge pull request #162 from the-mama-ai/fix-cache
Browse files Browse the repository at this point in the history
feat(github): simplify identity caching
  • Loading branch information
athornton committed May 15, 2024
2 parents f7594fa + f16d658 commit f105957
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 108 deletions.
144 changes: 70 additions & 74 deletions giftless/auth/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
import math
import os
import threading
import weakref
from collections.abc import Callable, Mapping, MutableMapping
from contextlib import AbstractContextManager, suppress
from operator import attrgetter, itemgetter
from threading import Lock, RLock
from typing import Any, Protocol, cast, overload
from typing import Any, Protocol, TypeVar, cast, overload

import cachetools.keys
import flask
Expand All @@ -24,6 +25,10 @@


# THREAD SAFE CACHING UTILS
# original type preserving "return type" for the decorators below
_RT = TypeVar("_RT")


class _LockType(AbstractContextManager, Protocol):
"""Generic type for threading.Lock and RLock."""

Expand Down Expand Up @@ -55,7 +60,7 @@ def _ensure_lock(


@overload
def single_call_method(_method: Callable[..., Any]) -> Callable[..., Any]:
def single_call_method(_method: Callable[..., _RT]) -> Callable[..., _RT]:
...


Expand All @@ -64,16 +69,16 @@ def single_call_method(
*,
key: Callable[..., Any] = cachetools.keys.methodkey,
lock: Callable[[Any], _LockType] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[[Callable[..., _RT]], Callable[..., _RT]]:
...


def single_call_method(
_method: Callable[..., Any] | None = None,
_method: Callable[..., _RT] | None = None,
*,
key: Callable[..., Any] = cachetools.keys.methodkey,
lock: Callable[[Any], _LockType] | None = None,
) -> Callable[..., Any]:
) -> Callable[..., _RT] | Callable[[Callable[..., _RT]], Callable[..., _RT]]:
"""Thread-safe decorator limiting concurrency of an idempotent method call.
When multiple threads concurrently call the decorated method with the same
arguments (governed by the 'key' callable argument), only the first one
Expand All @@ -91,12 +96,12 @@ def single_call_method(
"""
lock = _ensure_lock(lock)

def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
def decorator(method: Callable[..., _RT]) -> Callable[..., _RT]:
# tracking concurrent calls per method arguments
concurrent_calls: dict[Any, SingleCallContext] = {}

@functools.wraps(method)
def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
def wrapper(self: Any, *args: tuple, **kwargs: dict) -> _RT:
lck = lock(self)
k = key(self, *args, **kwargs)
with lck:
Expand Down Expand Up @@ -128,7 +133,8 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
# call is done
if ctx.error:
raise ctx.error
return ctx.result
# https://github.com/python/mypy/issues/3737
return cast(_RT, ctx.result)

return wrapper

Expand All @@ -139,18 +145,18 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:


def cachedmethod_threadsafe(
cache: Callable[[Any], MutableMapping],
cache: Callable[[Any], MutableMapping[Any, _RT]],
key: Callable[..., Any] = cachetools.keys.methodkey,
lock: Callable[[Any], _LockType] | None = None,
) -> Callable[..., Any]:
) -> Callable[..., Callable[..., _RT]]:
"""Threadsafe variant of cachetools.cachedmethod."""
lock = _ensure_lock(lock)

def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
def decorator(method: Callable[..., _RT]) -> Callable[..., _RT]:
@cachetools.cachedmethod(cache=cache, key=key, lock=lock)
@single_call_method(key=key, lock=lock)
@functools.wraps(method)
def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
def wrapper(self: Any, *args: tuple, **kwargs: dict) -> _RT:
return method(self, *args, **kwargs)

return wrapper
Expand All @@ -163,8 +169,6 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any:
class CacheConfig:
"""Cache configuration."""

# max number of entries in the unique user LRU cache
user_max_size: int
# max number of entries in the token -> user LRU cache
token_max_size: int
# max number of authenticated org/repos TTL(LRU) for each user
Expand All @@ -175,9 +179,6 @@ class CacheConfig:
auth_other_ttl: float

class Schema(ma.Schema):
user_max_size = ma.fields.Int(
load_default=32, validate=ma.validate.Range(min=0)
)
token_max_size = ma.fields.Int(
load_default=32, validate=ma.validate.Range(min=0)
)
Expand Down Expand Up @@ -237,25 +238,39 @@ def from_dict(cls, data: Mapping[str, Any]) -> "Config":


# CORE AUTH
@dataclasses.dataclass(frozen=True, slots=True)
class _CoreGithubIdentity:
"""Entries uniquely identifying a GitHub user (from a token).
This serves as a key to mappings/caches of unique users.
"""

id: str
github_id: str

@classmethod
def from_token(
cls, token_data: Mapping[str, Any]
) -> "_CoreGithubIdentity":
return cls(*itemgetter("login", "id")(token_data))


class GithubIdentity(Identity):
"""User identity belonging to an authentication token.
Tracks user's permission for particular organizations/repositories.
"""

def __init__(
self,
login: str,
github_id: str,
name: str,
email: str,
*,
core_identity: _CoreGithubIdentity,
token_data: Mapping[str, Any],
cc: CacheConfig,
) -> None:
super().__init__()
self.id = login
self.github_id = github_id
self.name = name
self.email = email
super().__init__(
token_data.get("name"), core_identity.id, token_data.get("email")
)
self.core_identity = core_identity

# Expiring cache of authorized repos with different TTL for each
# permission type. It's assumed that anyone granted the WRITE
Expand All @@ -278,20 +293,9 @@ def expiration(_key: Any, value: set[Permission], now: float) -> float:
self._auth_cache = cachetools.TLRUCache(cc.auth_max_size, expiration)
self._auth_cache_lock = Lock()

def __repr__(self) -> str:
return (
f"<{self.__class__.__name__} "
f"id:{self.id} github_id:{self.github_id} name:{self.name}>"
)

def __eq__(self, other: object) -> bool:
return isinstance(other, self.__class__) and (
self.id,
self.github_id,
) == (other.id, other.github_id)

def __hash__(self) -> int:
return hash((self.id, self.github_id))
def __getattr__(self, attr: str) -> Any:
# proxy to the core_identity for its attributes
return getattr(self.core_identity, attr)

def permissions(
self, org: str, repo: str, *, authoritative: bool = False
Expand Down Expand Up @@ -341,17 +345,6 @@ def cache_ttl(self, permissions: set[Permission]) -> float:
"""Return default cache TTL [seconds] for a certain permission set."""
return self._auth_cache.ttu(None, permissions, 0.0)

@staticmethod
def cache_key(data: Mapping[str, Any]) -> tuple:
"""Return caching key from significant fields."""
return cachetools.keys.hashkey(*itemgetter("login", "id")(data))

@classmethod
def from_dict(
cls, data: Mapping[str, Any], cc: CacheConfig
) -> "GithubIdentity":
return cls(*itemgetter("login", "id", "name", "email")(data), cc=cc)


class GithubAuthenticator:
"""Main class performing GitHub "proxy" authentication/authorization."""
Expand Down Expand Up @@ -393,14 +386,18 @@ def __init__(self, cfg: Config) -> None:
self._api_headers = {"Accept": "application/vnd.github+json"}
if cfg.api_version:
self._api_headers["X-GitHub-Api-Version"] = cfg.api_version
# user identities per raw user data (keeping them authorized)
self._user_cache: MutableMapping[
Any, GithubIdentity
] = cachetools.LRUCache(maxsize=cfg.cache.user_max_size)
# user identities per token (shortcut to the cached entries above)
# user identities per token
self._token_cache: MutableMapping[
Any, GithubIdentity
] = cachetools.LRUCache(maxsize=cfg.cache.token_max_size)
# unique user identities, to get the same identity that's
# potentially already cached for a different token (same user)
# If all the token entries for one user get evicted from the
# token cache, the user entry here automatically ceases to exist too.
self._cached_users: MutableMapping[
Any, GithubIdentity
] = weakref.WeakValueDictionary()
self._cache_lock = RLock()
self._cache_config = cfg.cache

def _api_get(self, uri: str, ctx: CallContext) -> Mapping[str, Any]:
Expand All @@ -411,41 +408,40 @@ def _api_get(self, uri: str, ctx: CallContext) -> Mapping[str, Any]:
response.raise_for_status()
return cast(Mapping[str, Any], response.json())

@cachedmethod_threadsafe(
attrgetter("_user_cache"),
lambda self, data: GithubIdentity.cache_key(data),
)
def _get_user_cached(self, data: Mapping[str, Any]) -> GithubIdentity:
"""Return internal GitHub user identity from raw GitHub user data
[cached per login & id].
"""
return GithubIdentity.from_dict(data, self._cache_config)

@cachedmethod_threadsafe(
attrgetter("_token_cache"),
lambda self, ctx: cachetools.keys.hashkey(ctx.token),
attrgetter("_cache_lock"),
)
def _authenticate(self, ctx: CallContext) -> GithubIdentity:
"""Return internal GitHub user identity for a GitHub token in ctx
[cached per token].
"""
"""Return internal GitHub user identity for a GitHub token in ctx."""
_logger.debug("Authenticating user")
try:
user_data = self._api_get("/user", ctx)
token_data = self._api_get("/user", ctx)
except requests.exceptions.RequestException as e:
_logger.warning(msg := f"Couldn't authenticate the user: {e}")
raise Unauthorized(msg) from None

# different tokens can bear the same identity
return cast(GithubIdentity, self._get_user_cached(user_data))
core_identity = _CoreGithubIdentity.from_token(token_data)
# check if we haven't seen this identity before
# guard the code with the same lock as the _token_cache
with self._cache_lock:
try:
user = self._cached_users[core_identity]
except KeyError:
user = GithubIdentity(
core_identity, token_data, self._cache_config
)
self._cached_users[core_identity] = user
return user

@staticmethod
def _perm_list(permissions: set[Permission]) -> str:
return f"[{', '.join(sorted(p.value for p in permissions))}]"

@single_call_method(
key=lambda self, ctx, user: cachetools.keys.hashkey(
ctx.org, ctx.repo, user
ctx.org, ctx.repo, user.core_identity
)
)
def _authorize(self, ctx: CallContext, user: GithubIdentity) -> None:
Expand Down
16 changes: 10 additions & 6 deletions giftless/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@ class Identity(ABC):
perform some actions.
"""

name: str | None = None
id: str | None = None
email: str | None = None
def __init__(
self,
name: str | None = None,
id: str | None = None,
email: str | None = None,
) -> None:
self.name = name
self.id = id
self.email = email

@abstractmethod
def is_authorized(
Expand All @@ -58,9 +64,7 @@ def __init__(
id: str | None = None,
email: str | None = None,
) -> None:
self.name = name
self.id = id
self.email = email
super().__init__(name, id, email)
self._allowed: PermissionTree = defaultdict(
lambda: defaultdict(lambda: defaultdict(set))
)
Expand Down
Loading

0 comments on commit f105957

Please sign in to comment.