diff --git a/giftless/auth/github.py b/giftless/auth/github.py index 0b29ff4..07e1b06 100644 --- a/giftless/auth/github.py +++ b/giftless/auth/github.py @@ -2,13 +2,14 @@ import dataclasses import functools import logging +import math import os import threading from collections.abc import Callable, Mapping, MutableMapping -from contextlib import AbstractContextManager +from contextlib import AbstractContextManager, suppress from operator import attrgetter, itemgetter -from threading import Condition, Lock, RLock -from typing import Any, cast, overload +from threading import Lock, RLock +from typing import Any, Protocol, cast, overload import cachetools.keys import flask @@ -23,22 +24,30 @@ # THREAD SAFE CACHING UTILS +class _LockType(AbstractContextManager, Protocol): + """Generic type for threading.Lock and RLock.""" + + def acquire(self, blocking: bool = ..., timeout: float = ...) -> bool: + ... + + def release(self) -> None: + ... + + @dataclasses.dataclass(kw_only=True) class SingleCallContext: """Thread-safety context for the single_call_method decorator.""" - # condition variable blocking a call with particular arguments - cond: Condition = dataclasses.field(default_factory=Condition) - # None - call not started, False - call ongoing, True - call done - # the three states are needed to cover any spurious (pthread-like) wake-ups - call_status: bool | None = None + # reentrant lock guarding a call with particular arguments + rlock: _LockType = dataclasses.field(default_factory=RLock) + start_call: bool = True result: Any = None error: BaseException | None = None def _ensure_lock( - existing_lock: Callable[[Any], AbstractContextManager] | None, -) -> Callable[[Any], AbstractContextManager]: + existing_lock: Callable[[Any], _LockType] | None = None, +) -> Callable[[Any], _LockType]: if existing_lock is None: default_lock = RLock() return lambda _self: default_lock @@ -54,7 +63,7 @@ def single_call_method(_method: Callable[..., Any]) -> Callable[..., Any]: def 
single_call_method( *, key: Callable[..., Any] = cachetools.keys.methodkey, - lock: Callable[[Any], AbstractContextManager] | None = None, + lock: Callable[[Any], _LockType] | None = None, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ... @@ -63,7 +72,7 @@ def single_call_method( _method: Callable[..., Any] | None = None, *, key: Callable[..., Any] = cachetools.keys.methodkey, - lock: Callable[[Any], AbstractContextManager] | None = None, + lock: Callable[[Any], _LockType] | None = None, ) -> Callable[..., Any]: """Thread-safe decorator limiting concurrency of an idempotent method call. When multiple threads concurrently call the decorated method with the same @@ -78,7 +87,7 @@ def single_call_method( It's possible to provide a "getter" callable for the lock guarding the main call cache, called as 'lock(self)'. There's a built-in lock by default. - Each concurrent call is then guarded by its own lock/conditional variable. + Each concurrent call is then guarded by its own reentrant lock variable. 
""" lock = _ensure_lock(lock) @@ -97,35 +106,29 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: concurrent_calls[k] = ctx = SingleCallContext() # start locked for the current thread, so the following # gap won't let other threads populate the result - ctx.cond.acquire() + ctx.rlock.acquire() - with ctx.cond: - if ctx.call_status is None: - # populating the result - ctx.call_status = False + with ctx.rlock: + if ctx.start_call: + ctx.start_call = False + ctx.rlock.release() # unlock the starting lock try: result = method(self, *args, **kwargs) except BaseException as e: ctx.error = e raise finally: - # call is done, cleanup its entry and notify threads + # call is done, cleanup its entry with lck: del concurrent_calls[k] - ctx.cond.release() # unlock the starting lock - ctx.cond.notify_all() ctx.result = result - ctx.call_status = True return result else: - # waiting for the result to get populated - while True: - if ctx.error: - raise ctx.error - if ctx.call_status: - return ctx.result - ctx.cond.wait() + # call is done + if ctx.error: + raise ctx.error + return ctx.result return wrapper @@ -138,7 +141,7 @@ def wrapper(self: Any, *args: tuple, **kwargs: dict) -> Any: def cachedmethod_threadsafe( cache: Callable[[Any], MutableMapping], key: Callable[..., Any] = cachetools.keys.methodkey, - lock: Callable[[Any], AbstractContextManager] | None = None, + lock: Callable[[Any], _LockType] | None = None, ) -> Callable[..., Any]: """Threadsafe variant of cachetools.cachedmethod.""" lock = _ensure_lock(lock) @@ -178,15 +181,14 @@ class Schema(ma.Schema): token_max_size = ma.fields.Int( load_default=32, validate=ma.validate.Range(min=0) ) - # the auth cache must have at least one valid slot auth_max_size = ma.fields.Int( - load_default=32, validate=ma.validate.Range(min=1) + load_default=32, validate=ma.validate.Range(min=0) ) auth_write_ttl = ma.fields.Float( - load_default=15 * 60.0, validate=ma.validate.Range(min=1.0) + load_default=15 * 60.0, 
validate=ma.validate.Range(min=0) ) auth_other_ttl = ma.fields.Float( - load_default=30.0, validate=ma.validate.Range(min=1.0) + load_default=30.0, validate=ma.validate.Range(min=0) ) @ma.post_load @@ -224,8 +226,9 @@ class Schema(ma.Schema): @ma.post_load def make_object( - self, data: Mapping[str, Any], **_kwargs: Mapping + self, data: MutableMapping[str, Any], **_kwargs: Mapping ) -> "Config": + data["api_url"] = data["api_url"].rstrip("/") return Config(**data) @classmethod @@ -262,6 +265,10 @@ def expiration(_key: Any, value: set[Permission], now: float) -> float: ) return now + ttl + # size-unlimited proxy cache to ensure at least one successful hit + self._auth_cache_read_proxy: MutableMapping[ + Any, set[Permission] + ] = cachetools.TTLCache(math.inf, 60.0) self._auth_cache = cachetools.TLRUCache(cc.auth_max_size, expiration) self._auth_cache_lock = Lock() @@ -280,17 +287,28 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: return hash((self.login, self.id)) - def permissions(self, org: str, repo: str) -> set[Permission] | None: + def permissions( + self, org: str, repo: str, *, authoritative: bool = False + ) -> set[Permission] | None: key = cachetools.keys.hashkey(org, repo) with self._auth_cache_lock: - return self._auth_cache.get(key) + if authoritative: + permission = self._auth_cache_read_proxy.pop(key, None) + else: + permission = self._auth_cache_read_proxy.get(key) + if permission is None: + return self._auth_cache.get(key) + if authoritative: + with suppress(ValueError): + self._auth_cache[key] = permission + return permission def authorize( self, org: str, repo: str, permissions: set[Permission] | None ) -> None: key = cachetools.keys.hashkey(org, repo) with self._auth_cache_lock: - self._auth_cache[key] = ( + self._auth_cache_read_proxy[key] = ( permissions if permissions is not None else set() ) @@ -301,7 +319,7 @@ def is_authorized( permission: Permission, oid: str | None = None, ) -> bool: - permissions = 
self.permissions(organization, repo) + permissions = self.permissions(organization, repo, authoritative=True) return permission in permissions if permissions else False def cache_ttl(self, permissions: set[Permission]) -> float: @@ -351,11 +369,12 @@ def _extract_token(self, request: flask.Request) -> str: return token def __post_init__(self, request: flask.Request) -> None: - self.org, self.repo = request.path.split("/", maxsplit=3)[1:3] + org_repo_getter = itemgetter("organization", "repo") + self.org, self.repo = org_repo_getter(request.view_args or {}) self.token = self._extract_token(request) def __init__(self, cfg: Config) -> None: - self._api_url = cfg.api_url.rstrip("/") + self._api_url = cfg.api_url self._api_headers = {"Accept": "application/vnd.github+json"} if cfg.api_version: self._api_headers["X-GitHub-Api-Version"] = cfg.api_version @@ -471,8 +490,12 @@ def __call__(self, request: flask.Request) -> Identity | None: self._authorize(ctx, user) return user + @property + def api_url(self) -> str: + return self._api_url + -def factory(**options: Mapping[str, Any]) -> GithubAuthenticator: +def factory(**options: Any) -> GithubAuthenticator: """Build GitHub Authenticator from supplied options.""" config = Config.from_dict(options) return GithubAuthenticator(config) diff --git a/requirements/dev.in b/requirements/dev.in index e32f3ba..99cea93 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -9,6 +9,7 @@ pytest-mypy pytest-env pytest-cov pytest-vcr +responses pytz types-pytz diff --git a/requirements/dev.txt b/requirements/dev.txt index d8030a4..2e05383 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -163,13 +163,17 @@ pytz==2023.3.post1 pyyaml==6.0.1 # via # -c requirements/main.txt + # responses # vcrpy recommonmark==0.7.1 # via -r requirements/dev.in requests==2.31.0 # via # -c requirements/main.txt + # responses # sphinx +responses==0.25.0 + # via -r requirements/dev.in rsa==4.9 # via # -c requirements/main.txt @@ -241,6 
+245,7 @@ urllib3==2.0.7 # via # -c requirements/main.txt # requests + # responses # types-requests vcrpy==5.1.0 # via pytest-vcr diff --git a/tests/auth/test_github.py b/tests/auth/test_github.py new file mode 100644 index 0000000..66b137e --- /dev/null +++ b/tests/auth/test_github.py @@ -0,0 +1,369 @@ +"""Unit tests for auth.github module.""" +import base64 +from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import shuffle +from time import sleep +from typing import Any, cast + +import flask +import pytest +import responses +from marshmallow.exceptions import ValidationError + +import giftless.auth.github as gh +from giftless.auth import Unauthorized +from giftless.auth.identity import Identity, Permission + + +def test_ensure_default_lock() -> None: + lock_getter = gh._ensure_lock() + lock = lock_getter(None) + with lock: + # Is it a RLock or just Lock? + if lock.acquire(blocking=False): + lock.release() + + +def _concurrent_side_effects( + decorator: Callable[[Callable[..., Any]], Any], + thread_cnt: int = 4, + exception: type[Exception] | None = None, +) -> tuple[list, list, list]: + @decorator + def decorated_method(_ignored_self: Any, index: int) -> int: + sleep(0.1) + side_effects[index] = index + if exception is not None: + raise exception(index) + return index + + results: list[int | None] = [None] * thread_cnt + side_effects: list[int | None] = [None] * thread_cnt + exceptions: list[Exception | None] = [None] * thread_cnt + + with ThreadPoolExecutor( + max_workers=thread_cnt, thread_name_prefix="scm-" + ) as executor: + thread_indices = list(range(thread_cnt)) + shuffle(thread_indices) + futures = { + executor.submit(decorated_method, i, i): i for i in thread_indices + } + for future in as_completed(futures): + i = futures[future] + try: + result = future.result() + except Exception as exc: + exceptions[i] = exc + else: + results[i] = result + + return results, side_effects, exceptions + + 
+def test_single_call_method_decorator_default_no_args() -> None: + decorator = gh.single_call_method + results, side_effects, exceptions = _concurrent_side_effects(decorator) + # the differing index (taken into account by default) breaks call coupling, + # so this decoration has no effect and all calls get through + assert results == side_effects + + +def test_single_call_method_decorator_default_args() -> None: + decorator = gh.single_call_method() + results, side_effects, exceptions = _concurrent_side_effects(decorator) + # same as test_single_call_method_decorator_default_no_args, but checking + # if the decorator factory works properly with no explicit args + assert results == side_effects + + +def test_single_call_method_decorator_default_exception() -> None: + decorator = gh.single_call_method() + results, side_effects, exceptions = _concurrent_side_effects( + decorator, exception=Exception + ) + # same setup as test_single_call_method_decorator_default_args, but the + # method raises: every uncoupled call must run and raise its own exception + assert all(r is None for r in results) + assert all(se is not None for se in side_effects) + assert all(e is not None for e in exceptions) + + +def test_single_call_method_decorator_call_once() -> None: + # using a constant hash key to put all threads in the same bucket + decorator = gh.single_call_method(key=lambda *args: 0) + threads = 4 + results, side_effects, exceptions = _concurrent_side_effects( + decorator, threads + ) + assert all(e is None for e in exceptions) + # as there's just a sleep in the decorated_method, technically multiple + # threads could enter the method call (and thus produce side_effects), + # but the expectation is the sleep is long enough for all to get stuck + chosen_ones = [se for se in side_effects if se is not None] + # at least one thread got stuck + assert len(chosen_ones) < threads + assert all(r in chosen_ones for r in results) + + +def test_single_call_method_decorator_call_once_exception() 
-> None: + # using a constant hash key to put all threads in the same bucket + decorator = gh.single_call_method(key=lambda *args: 0) + threads = 4 + results, side_effects, exceptions = _concurrent_side_effects( + decorator, threads, Exception + ) + assert all(r is None for r in results) + assert all(e is not None for e in exceptions) + chosen_ones = [se for se in side_effects if se is not None] + # at least one thread got stuck + assert len(chosen_ones) < threads + # make sure the exceptions come from the calling thread + assert all(e.args[0] in chosen_ones for e in exceptions) + + +def test_cachedmethod_threadsafe_default_key() -> None: + # cache all the uncoupled calls + cache: dict[Any, Any] = {} + threads = 4 + decorator = gh.cachedmethod_threadsafe(lambda _self: cache) + results, side_effects, exceptions = _concurrent_side_effects( + decorator, threads + ) + assert all(e is None for e in exceptions) + assert results == side_effects + assert len(cache) == threads + + +def test_cachedmethod_threadsafe_call_once() -> None: + # one result ends up cached, even if call produces different results + # (this is supposed to be used for idempotent methods, so multiple calls + # are supposed to produce identical results) + cache: dict[Any, Any] = {} + decorator = gh.cachedmethod_threadsafe( + lambda _self: cache, key=lambda *args: 0 + ) + results, side_effects, exceptions = _concurrent_side_effects(decorator) + assert all(e is None for e in exceptions) + chosen_ones = [se for se in side_effects if se is not None] + assert len(cache) == 1 + cached_result = next(iter(cache.values())) + assert cached_result in chosen_ones + + +def test_config_schema_defaults() -> None: + config = gh.Config.from_dict({}) + assert isinstance(config, gh.Config) + assert hasattr(config, "cache") + assert isinstance(config.cache, gh.CacheConfig) + + +def test_config_schema_default_cache() -> None: + config = gh.Config.from_dict({"cache": {}}) + assert isinstance(config, gh.Config) + assert 
hasattr(config, "cache") + assert isinstance(config.cache, gh.CacheConfig) + + +def test_config_schema_empty_cache() -> None: + options = {"cache": None} + with pytest.raises(ValidationError): + _config = gh.Config.from_dict(options) + + +DEFAULT_CONFIG = gh.Config.from_dict({}) +DEFAULT_USER_DICT = { + "login": "kingofthebritons", + "id": "12345678", + "name": "arthur", + "email": "arthur@camelot.gov.uk", +} +DEFAULT_USER_ARGS = tuple(DEFAULT_USER_DICT.values()) +ZERO_CACHE_CONFIG = gh.CacheConfig( + user_max_size=0, + token_max_size=0, + auth_max_size=0, + # deliberately non-zero to not get rejected on setting by the timeout logic + auth_write_ttl=60.0, + auth_other_ttl=30.0, +) +ORG = "my-org" +REPO = "my-repo" + + +def test_github_identity_core() -> None: + # use some value to get filtered out + user_dict = DEFAULT_USER_DICT | {"other_field": "other_value"} + cache_cfg = DEFAULT_CONFIG.cache + user = gh.GithubIdentity.from_dict(user_dict, cc=cache_cfg) + assert (user.login, user.id, user.name, user.email) == DEFAULT_USER_ARGS + assert all(arg in repr(user) for arg in DEFAULT_USER_ARGS[:3]) + assert hash(user) == hash((user.login, user.id)) + + args2 = (*DEFAULT_USER_ARGS[:2], "spammer", "spam@camelot.gov.uk") + user2 = gh.GithubIdentity(*args2, cc=cache_cfg) + assert user == user2 + user2.id = "654321" + assert user != user2 + + assert user.cache_ttl({Permission.WRITE}) == cache_cfg.auth_write_ttl + assert ( + user.cache_ttl({Permission.READ_META, Permission.READ}) + == cache_cfg.auth_other_ttl + ) + + +def test_github_identity_authorization_cache() -> None: + user = gh.GithubIdentity(*DEFAULT_USER_ARGS, cc=DEFAULT_CONFIG.cache) + assert not user.is_authorized(ORG, REPO, Permission.READ_META) + user.authorize(ORG, REPO, {Permission.READ_META, Permission.READ}) + assert user.permissions(ORG, REPO) == { + Permission.READ_META, + Permission.READ, + } + assert user.is_authorized(ORG, REPO, Permission.READ_META) + assert user.is_authorized(ORG, REPO, 
Permission.READ) + assert not user.is_authorized(ORG, REPO, Permission.WRITE) + + +def test_github_identity_authorization_proxy_cache_only() -> None: + user = gh.GithubIdentity(*DEFAULT_USER_ARGS, cc=ZERO_CACHE_CONFIG) + org, repo, repo2 = ORG, REPO, "repo2" + user.authorize(org, repo, Permission.all()) + user.authorize(org, repo2, Permission.all()) + assert user.is_authorized(org, repo, Permission.READ_META) + # without cache, the authorization expires after 1st is_authorized + assert not user.is_authorized(org, repo, Permission.READ_META) + assert user.is_authorized(org, repo2, Permission.READ_META) + assert not user.is_authorized(org, repo2, Permission.READ_META) + + +def auth_request( + app: flask.Flask, + auth: gh.GithubAuthenticator, + org: str = ORG, + repo: str = REPO, + req_auth_header: str | None = "", +) -> Identity | None: + if req_auth_header is None: + headers = None + elif req_auth_header == "": + # default - token + token = "dummy-github-token" + basic_auth = base64.b64encode( + b":".join([b"token", token.encode()]) + ).decode() + headers = {"Authorization": f"Basic {basic_auth}"} + else: + headers = {"Authorization": req_auth_header} + + with app.test_request_context( + f"/{org}/{repo}/objects/batch", + method="POST", + headers=headers, + ): + return auth(flask.request) + + +def mock_user( + auth: gh.GithubAuthenticator, *args: Any, **kwargs: Any +) -> responses.BaseResponse: + ret = responses.get(f"{auth.api_url}/user", *args, **kwargs) + return cast(responses.BaseResponse, ret) + + +def mock_perm( + auth: gh.GithubAuthenticator, + org: str = ORG, + repo: str = REPO, + login: str = DEFAULT_USER_DICT["login"], + *args: Any, + **kwargs: Any, +) -> responses.BaseResponse: + ret = responses.get( + f"{auth.api_url}/repos/{org}/{repo}/collaborators/{login}/permission", + *args, + **kwargs, + ) + return cast(responses.BaseResponse, ret) + + +def test_github_auth_request_missing_auth(app: flask.Flask) -> None: + auth = gh.factory() + with 
pytest.raises(Unauthorized): + auth_request(app, auth, req_auth_header=None) + + +def test_github_auth_request_funny_auth(app: flask.Flask) -> None: + auth = gh.factory() + with pytest.raises(Unauthorized): + auth_request(app, auth, req_auth_header="Funny key1=val1, key2=val2") + + +@responses.activate +def test_github_auth_request_bad_user(app: flask.Flask) -> None: + auth = gh.factory() + mock_user(auth, json={"error": "Forbidden"}, status=403) + with pytest.raises(Unauthorized): + auth_request(app, auth) + + +@responses.activate +def test_github_auth_request_bad_perm(app: flask.Flask) -> None: + auth = gh.factory(api_version=None) + mock_user(auth, json=DEFAULT_USER_DICT) + mock_perm(auth, json={"error": "Forbidden"}, status=403) + + with pytest.raises(Unauthorized): + auth_request(app, auth) + + +@responses.activate +def test_github_auth_request_admin(app: flask.Flask) -> None: + auth = gh.factory() + mock_user(auth, json=DEFAULT_USER_DICT) + mock_perm(auth, json={"permission": "admin"}) + + identity = auth_request(app, auth) + assert identity is not None + assert identity.is_authorized(ORG, REPO, Permission.WRITE) + + +@responses.activate +def test_github_auth_request_read(app: flask.Flask) -> None: + auth = gh.factory() + mock_user(auth, json=DEFAULT_USER_DICT) + mock_perm(auth, json={"permission": "read"}) + + identity = auth_request(app, auth) + assert identity is not None + assert not identity.is_authorized(ORG, REPO, Permission.WRITE) + assert identity.is_authorized(ORG, REPO, Permission.READ) + + +@responses.activate +def test_github_auth_request_none(app: flask.Flask) -> None: + auth = gh.factory() + mock_user(auth, json=DEFAULT_USER_DICT) + mock_perm(auth, json={"permission": "none"}) + + identity = auth_request(app, auth) + assert identity is not None + assert not identity.is_authorized(ORG, REPO, Permission.WRITE) + assert not identity.is_authorized(ORG, REPO, Permission.READ) + + +@responses.activate +def test_github_auth_request_cached(app: 
flask.Flask) -> None: + auth = gh.factory() + user_resp = mock_user(auth, json=DEFAULT_USER_DICT) + perm_resp = mock_perm(auth, json={"permission": "admin"}) + + auth_request(app, auth) + # second cached call + identity = auth_request(app, auth) + assert identity is not None + assert identity.is_authorized(ORG, REPO, Permission.WRITE) + assert user_resp.call_count == 1 + assert perm_resp.call_count == 1