diff --git a/giftless/auth/github.py b/giftless/auth/github.py index cd515a0..7a4d86e 100644 --- a/giftless/auth/github.py +++ b/giftless/auth/github.py @@ -5,11 +5,12 @@ import math import os import threading +import weakref from collections.abc import Callable, Mapping, MutableMapping from contextlib import AbstractContextManager, suppress from operator import attrgetter, itemgetter from threading import Lock, RLock -from typing import Any, Protocol, cast, overload +from typing import Any, Protocol, TypeVar, cast, overload import cachetools.keys import flask @@ -24,6 +25,10 @@ # THREAD SAFE CACHING UTILS +# original type preserving "return type" for the decorators below +_RT = TypeVar("_RT") + + class _LockType(AbstractContextManager, Protocol): """Generic type for threading.Lock and RLock.""" @@ -55,7 +60,7 @@ def _ensure_lock( @overload -def single_call_method(_method: Callable[..., Any]) -> Callable[..., Any]: +def single_call_method(_method: Callable[..., _RT]) -> Callable[..., _RT]: ... @@ -64,16 +69,16 @@ def single_call_method( *, key: Callable[..., Any] = cachetools.keys.methodkey, lock: Callable[[Any], _LockType] | None = None, -) -> Callable[[Callable[..., Any]], Callable[..., Any]]: +) -> Callable[[Callable[..., _RT]], Callable[..., _RT]]: ... def single_call_method( - _method: Callable[..., Any] | None = None, + _method: Callable[..., _RT] | None = None, *, key: Callable[..., Any] = cachetools.keys.methodkey, lock: Callable[[Any], _LockType] | None = None, -) -> Callable[..., Any]: +) -> Callable[..., _RT] | Callable[[Callable[..., _RT]], Callable[..., _RT]]: """Thread-safe decorator limiting concurrency of an idempotent method call. When multiple threads concurrently call the decorated method with the same arguments (governed by the 'key' callable argument), only the first one @@ -91,12 +96,12 @@ def single_call_method( """ lock = _ensure_lock(lock) - def decorator(method: Callable[..., Any]) -> Callable[..., Any]: + def decorator(method: Callable[..., _RT]) -> Callable[..., _RT]: # tracking concurrent calls per method arguments concurrent_calls: dict[Any, SingleCallContext] = {} @functools.wraps(method) - def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: + def wrapper(self: Any, *args: tuple, **kwargs: dict) -> _RT: lck = lock(self) k = key(self, *args, **kwargs) with lck: @@ -128,7 +133,8 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: # call is done if ctx.error: raise ctx.error - return ctx.result + # https://github.com/python/mypy/issues/3737 + return cast(_RT, ctx.result) return wrapper @@ -139,18 +145,18 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: def cachedmethod_threadsafe( - cache: Callable[[Any], MutableMapping], + cache: Callable[[Any], MutableMapping[Any, _RT]], key: Callable[..., Any] = cachetools.keys.methodkey, lock: Callable[[Any], _LockType] | None = None, -) -> Callable[..., Any]: +) -> Callable[..., Callable[..., _RT]]: """Threadsafe variant of cachetools.cachedmethod.""" lock = _ensure_lock(lock) - def decorator(method: Callable[..., Any]) -> Callable[..., Any]: + def decorator(method: Callable[..., _RT]) -> Callable[..., _RT]: @cachetools.cachedmethod(cache=cache, key=key, lock=lock) @single_call_method(key=key, lock=lock) @functools.wraps(method) - def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: + def wrapper(self: Any, *args: tuple, **kwargs: dict) -> _RT: return method(self, *args, **kwargs) return wrapper @@ -163,8 +169,6 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: class CacheConfig: """Cache configuration.""" - # max number of entries in the unique user LRU cache - user_max_size: int # max number of entries in the token -> user LRU cache token_max_size: int # max number of authenticated org/repos TTL(LRU) for each user @@ -175,9 +179,6 @@ class CacheConfig: auth_other_ttl: float class Schema(ma.Schema): - user_max_size = ma.fields.Int( - load_default=32, validate=ma.validate.Range(min=0) - ) token_max_size = ma.fields.Int( load_default=32, validate=ma.validate.Range(min=0) ) @@ -237,25 +238,39 @@ def from_dict(cls, data: Mapping[str, Any]) -> "Config": # CORE AUTH +@dataclasses.dataclass(frozen=True, slots=True) +class _CoreGithubIdentity: + """Entries uniquely identifying a GitHub user (from a token). + + This serves as a key to mappings/caches of unique users. + """ + + id: str + github_id: str + + @classmethod + def from_token( + cls, token_data: Mapping[str, Any] + ) -> "_CoreGithubIdentity": + return cls(*itemgetter("login", "id")(token_data)) + + class GithubIdentity(Identity): """User identity belonging to an authentication token. + Tracks user's permission for particular organizations/repositories. """ def __init__( self, - login: str, - github_id: str, - name: str, - email: str, - *, + core_identity: _CoreGithubIdentity, + token_data: Mapping[str, Any], cc: CacheConfig, ) -> None: - super().__init__() - self.id = login - self.github_id = github_id - self.name = name - self.email = email + super().__init__( + token_data.get("name"), core_identity.id, token_data.get("email") + ) + self.core_identity = core_identity # Expiring cache of authorized repos with different TTL for each # permission type. It's assumed that anyone granted the WRITE @@ -278,20 +293,9 @@ def expiration(_key: Any, value: set[Permission], now: float) -> float: self._auth_cache = cachetools.TLRUCache(cc.auth_max_size, expiration) self._auth_cache_lock = Lock() - def __repr__(self) -> str: - return ( - f"<{self.__class__.__name__} " - f"id:{self.id} github_id:{self.github_id} name:{self.name}>" - ) - - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and ( - self.id, - self.github_id, - ) == (other.id, other.github_id) - - def __hash__(self) -> int: - return hash((self.id, self.github_id)) + def __getattr__(self, attr: str) -> Any: + # proxy to the core_identity for its attributes + return getattr(self.core_identity, attr) def permissions( self, org: str, repo: str, *, authoritative: bool = False @@ -341,17 +345,6 @@ def cache_ttl(self, permissions: set[Permission]) -> float: """Return default cache TTL [seconds] for a certain permission set.""" return self._auth_cache.ttu(None, permissions, 0.0) - @staticmethod - def cache_key(data: Mapping[str, Any]) -> tuple: - """Return caching key from significant fields.""" - return cachetools.keys.hashkey(*itemgetter("login", "id")(data)) - - @classmethod - def from_dict( - cls, data: Mapping[str, Any], cc: CacheConfig - ) -> "GithubIdentity": - return cls(*itemgetter("login", "id", "name", "email")(data), cc=cc) - class GithubAuthenticator: """Main class performing GitHub "proxy" authentication/authorization.""" @@ -393,14 +386,18 @@ def __init__(self, cfg: Config) -> None: self._api_headers = {"Accept": "application/vnd.github+json"} if cfg.api_version: self._api_headers["X-GitHub-Api-Version"] = cfg.api_version - # user identities per raw user data (keeping them authorized) - self._user_cache: MutableMapping[ - Any, GithubIdentity - ] = cachetools.LRUCache(maxsize=cfg.cache.user_max_size) - # user identities per token (shortcut to the cached entries above) + # user identities per token self._token_cache: MutableMapping[ Any, GithubIdentity ] = cachetools.LRUCache(maxsize=cfg.cache.token_max_size) + # unique user identities, to get the same identity that's + # potentially already cached for a different token (same user) + # If all the token entries for one user get evicted from the + # token cache, the user entry here automatically ceases to exist too. + self._cached_users: MutableMapping[ + Any, GithubIdentity + ] = weakref.WeakValueDictionary() + self._cache_lock = RLock() self._cache_config = cfg.cache def _api_get(self, uri: str, ctx: CallContext) -> Mapping[str, Any]: @@ -411,33 +408,32 @@ def _api_get(self, uri: str, ctx: CallContext) -> Mapping[str, Any]: response.raise_for_status() return cast(Mapping[str, Any], response.json()) - @cachedmethod_threadsafe( - attrgetter("_user_cache"), - lambda self, data: GithubIdentity.cache_key(data), - ) - def _get_user_cached(self, data: Mapping[str, Any]) -> GithubIdentity: - """Return internal GitHub user identity from raw GitHub user data - [cached per login & id]. - """ - return GithubIdentity.from_dict(data, self._cache_config) - @cachedmethod_threadsafe( attrgetter("_token_cache"), lambda self, ctx: cachetools.keys.hashkey(ctx.token), + attrgetter("_cache_lock"), ) def _authenticate(self, ctx: CallContext) -> GithubIdentity: - """Return internal GitHub user identity for a GitHub token in ctx - [cached per token]. - """ + """Return internal GitHub user identity for a GitHub token in ctx.""" _logger.debug("Authenticating user") try: - user_data = self._api_get("/user", ctx) + token_data = self._api_get("/user", ctx) except requests.exceptions.RequestException as e: _logger.warning(msg := f"Couldn't authenticate the user: {e}") raise Unauthorized(msg) from None - # different tokens can bear the same identity - return cast(GithubIdentity, self._get_user_cached(user_data)) + core_identity = _CoreGithubIdentity.from_token(token_data) + # check if we haven't seen this identity before + # guard the code with the same lock as the _token_cache + with self._cache_lock: + try: + user = self._cached_users[core_identity] + except KeyError: + user = GithubIdentity( + core_identity, token_data, self._cache_config + ) + self._cached_users[core_identity] = user + return user @staticmethod def _perm_list(permissions: set[Permission]) -> str: @@ -445,7 +441,7 @@ def _perm_list(permissions: set[Permission]) -> str: @single_call_method( key=lambda self, ctx, user: cachetools.keys.hashkey( - ctx.org, ctx.repo, user + ctx.org, ctx.repo, user.core_identity ) ) def _authorize(self, ctx: CallContext, user: GithubIdentity) -> None: diff --git a/giftless/auth/identity.py b/giftless/auth/identity.py index 11055c4..ec6e767 100644 --- a/giftless/auth/identity.py +++ b/giftless/auth/identity.py @@ -29,9 +29,15 @@ class Identity(ABC): perform some actions. """ - name: str | None = None - id: str | None = None - email: str | None = None + def __init__( + self, + name: str | None = None, + id: str | None = None, + email: str | None = None, + ) -> None: + self.name = name + self.id = id + self.email = email @abstractmethod def is_authorized( @@ -58,9 +64,7 @@ def __init__( id: str | None = None, email: str | None = None, ) -> None: - self.name = name - self.id = id - self.email = email + super().__init__(name, id, email) self._allowed: PermissionTree = defaultdict( lambda: defaultdict(lambda: defaultdict(set)) ) diff --git a/tests/auth/test_github.py b/tests/auth/test_github.py index 6dd406b..1409284 100644 --- a/tests/auth/test_github.py +++ b/tests/auth/test_github.py @@ -6,6 +6,7 @@ from time import sleep from typing import Any, cast +import cachetools.keys import flask import pytest import responses @@ -173,15 +174,14 @@ def test_config_schema_empty_cache() -> None: DEFAULT_CONFIG = gh.Config.from_dict({}) -DEFAULT_USER_DICT = { +DEFAULT_TOKEN_DICT = { "login": "kingofthebritons", "id": "12345678", "name": "arthur", "email": "arthur@camelot.gov.uk", } -DEFAULT_USER_ARGS = tuple(DEFAULT_USER_DICT.values()) +DEFAULT_USER_ARGS = tuple(DEFAULT_TOKEN_DICT.values()) ZERO_CACHE_CONFIG = gh.CacheConfig( - user_max_size=0, token_max_size=0, auth_max_size=0, # deliberately non-zero to not get rejected on setting by the timeout logic @@ -194,23 +194,13 @@ def test_config_schema_empty_cache() -> None: def test_github_identity_core() -> None: # use some value to get filtered out - user_dict = DEFAULT_USER_DICT | {"other_field": "other_value"} + token_dict = DEFAULT_TOKEN_DICT | {"other_field": "other_value"} cache_cfg = DEFAULT_CONFIG.cache - user = gh.GithubIdentity.from_dict(user_dict, cc=cache_cfg) - assert ( - user.id, - user.github_id, - user.name, - user.email, - ) == DEFAULT_USER_ARGS - assert all(arg in repr(user) for arg in DEFAULT_USER_ARGS[:3]) - assert hash(user) == hash((user.id, user.github_id)) - - args2 = (*DEFAULT_USER_ARGS[:2], "spammer", "spam@camelot.gov.uk") - user2 = gh.GithubIdentity(*args2, cc=cache_cfg) - assert user == user2 - user2.id = "654321" - assert user != user2 + core_identity = gh._CoreGithubIdentity.from_token(token_dict) + user = gh.GithubIdentity(core_identity, token_dict, cache_cfg) + assert (user.id, user.github_id, user.name, user.email) == tuple( + DEFAULT_TOKEN_DICT.values() + ) assert user.cache_ttl({Permission.WRITE}) == cache_cfg.auth_write_ttl assert ( @@ -220,7 +210,10 @@ def test_github_identity_core() -> None: def test_github_identity_authorization_cache() -> None: - user = gh.GithubIdentity(*DEFAULT_USER_ARGS, cc=DEFAULT_CONFIG.cache) + core_identity = gh._CoreGithubIdentity.from_token(DEFAULT_TOKEN_DICT) + user = gh.GithubIdentity( + core_identity, DEFAULT_TOKEN_DICT, DEFAULT_CONFIG.cache + ) assert not user.is_authorized(ORG, REPO, Permission.READ_META) user.authorize(ORG, REPO, {Permission.READ_META, Permission.READ}) assert user.permissions(ORG, REPO) == { @@ -233,7 +226,10 @@ def test_github_identity_authorization_cache() -> None: def test_github_identity_authorization_proxy_cache_only() -> None: - user = gh.GithubIdentity(*DEFAULT_USER_ARGS, cc=ZERO_CACHE_CONFIG) + core_identity = gh._CoreGithubIdentity.from_token(DEFAULT_TOKEN_DICT) + user = gh.GithubIdentity( + core_identity, DEFAULT_TOKEN_DICT, ZERO_CACHE_CONFIG + ) org, repo, repo2 = ORG, REPO, "repo2" user.authorize(org, repo, Permission.all()) user.authorize(org, repo2, Permission.all()) @@ -250,12 +246,12 @@ def auth_request( org: str = ORG, repo: str = REPO, req_auth_header: str | None = "", + token: str = "dummy-github-token", ) -> Identity | None: if req_auth_header is None: headers = None elif req_auth_header == "": # default - token - token = "dummy-github-token" basic_auth = base64.b64encode( b":".join([b"token", token.encode()]) ).decode() @@ -282,7 +278,7 @@ def mock_perm( auth: gh.GithubAuthenticator, org: str = ORG, repo: str = REPO, - login: str = DEFAULT_USER_DICT["login"], + login: str = DEFAULT_TOKEN_DICT["login"], *args: Any, **kwargs: Any, ) -> responses.BaseResponse: @@ -317,7 +313,7 @@ def test_github_auth_request_bad_user(app: flask.Flask) -> None: @responses.activate def test_github_auth_request_bad_perm(app: flask.Flask) -> None: auth = gh.factory(api_version=None) - mock_user(auth, json=DEFAULT_USER_DICT) + mock_user(auth, json=DEFAULT_TOKEN_DICT) mock_perm(auth, json={"error": "Forbidden"}, status=403) with pytest.raises(Unauthorized): @@ -327,7 +323,7 @@ def test_github_auth_request_bad_perm(app: flask.Flask) -> None: @responses.activate def test_github_auth_request_admin(app: flask.Flask) -> None: auth = gh.factory() - mock_user(auth, json=DEFAULT_USER_DICT) + mock_user(auth, json=DEFAULT_TOKEN_DICT) mock_perm(auth, json={"permission": "admin"}) identity = auth_request(app, auth) @@ -338,7 +334,7 @@ def test_github_auth_request_admin(app: flask.Flask) -> None: @responses.activate def test_github_auth_request_read(app: flask.Flask) -> None: auth = gh.factory() - mock_user(auth, json=DEFAULT_USER_DICT) + mock_user(auth, json=DEFAULT_TOKEN_DICT) mock_perm(auth, json={"permission": "read"}) identity = auth_request(app, auth) @@ -350,7 +346,7 @@ def test_github_auth_request_read(app: flask.Flask) -> None: @responses.activate def test_github_auth_request_none(app: flask.Flask) -> None: auth = gh.factory() - mock_user(auth, json=DEFAULT_USER_DICT) + mock_user(auth, json=DEFAULT_TOKEN_DICT) mock_perm(auth, json={"permission": "none"}) identity = auth_request(app, auth) @@ -362,7 +358,7 @@ def test_github_auth_request_none(app: flask.Flask) -> None: @responses.activate def test_github_auth_request_cached(app: flask.Flask) -> None: auth = gh.factory() - user_resp = mock_user(auth, json=DEFAULT_USER_DICT) + user_resp = mock_user(auth, json=DEFAULT_TOKEN_DICT) perm_resp = mock_perm(auth, json={"permission": "admin"}) auth_request(app, auth) @@ -372,3 +368,68 @@ def test_github_auth_request_cached(app: flask.Flask) -> None: assert identity.is_authorized(ORG, REPO, Permission.WRITE) assert user_resp.call_count == 1 assert perm_resp.call_count == 1 + + +@responses.activate +def test_github_auth_request_cache_no_leak(app: flask.Flask) -> None: + auth = gh.factory(cache={"token_max_size": 2}) + user_resp = mock_user(auth, json=DEFAULT_TOKEN_DICT) + perm_resp = mock_perm(auth, json={"permission": "admin"}) + + # authenticate 1st token, check it got cached properly + token1 = "token-1" + token1_cache_key = cachetools.keys.hashkey(token1) + identity1 = auth_request(app, auth, token=token1) + assert len(auth._token_cache) == 1 + assert token1_cache_key in auth._token_cache + assert len(auth._cached_users) == 1 + assert any(i is identity1 for i in auth._cached_users.values()) + # see both the authentication and authorization requests took place + assert user_resp.call_count == 1 + assert perm_resp.call_count == 1 + # remove local strong reference + del identity1 + + # authenticate the same user with different token (fill cache) + token2 = "token-2" + token2_cache_key = cachetools.keys.hashkey(token2) + identity2 = auth_request(app, auth, token=token2) + assert len(auth._token_cache) == 2 + assert token2_cache_key in auth._token_cache + assert len(auth._cached_users) == 1 + assert any(i is identity2 for i in auth._cached_users.values()) + # see only the authentication request took place + assert user_resp.call_count == 2 + assert perm_resp.call_count == 1 + del identity2 + + # authenticate once more (cache will evict oldest) + token3 = "token-3" + token3_cache_key = cachetools.keys.hashkey(token3) + identity3 = auth_request(app, auth, token=token3) + assert len(auth._token_cache) == 2 + assert token3_cache_key in auth._token_cache + assert token1_cache_key not in auth._token_cache + assert len(auth._cached_users) == 1 + assert any(i is identity3 for i in auth._cached_users.values()) + # see only the authentication request took place + assert user_resp.call_count == 3 + assert perm_resp.call_count == 1 + del identity3 + + # evict 2nd cached token + del auth._token_cache[token2_cache_key] + assert len(auth._token_cache) == 1 + assert len(auth._cached_users) == 1 + # evict 3rd + del auth._token_cache[token3_cache_key] + assert len(auth._token_cache) == 0 + assert len(auth._cached_users) == 0 + + # try once more with 1st token + auth_request(app, auth, token=token1) + assert len(auth._token_cache) == 1 + assert len(auth._cached_users) == 1 + # see both the authentication and authorization requests took place + assert user_resp.call_count == 4 + assert perm_resp.call_count == 2