diff --git a/diracx-db/src/diracx/db/sql/auth/db.py b/diracx-db/src/diracx/db/sql/auth/db.py index a35a0ad19..dd9726e62 100644 --- a/diracx-db/src/diracx/db/sql/auth/db.py +++ b/diracx-db/src/diracx/db/sql/auth/db.py @@ -5,6 +5,7 @@ import secrets from datetime import UTC, datetime from itertools import pairwise +from typing import Any from dateutil.relativedelta import relativedelta from dateutil.rrule import MONTHLY, rrule @@ -324,10 +325,7 @@ async def update_authorization_flow_status( ) async def insert_refresh_token( - self, - jti: UUID, - subject: str, - scope: str, + self, jti: UUID, subject: str, scope: str, policies: dict[str, Any] ) -> None: """Insert a refresh token in the DB. @@ -338,6 +336,7 @@ async def insert_refresh_token( jti=str(jti), sub=subject, scope=scope, + policies=policies, ) await self.conn.execute(stmt) diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 0c7554318..ec3a60ba6 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -116,6 +116,7 @@ class RefreshTokens(Base): "Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name ) scope: Mapped[str1024] = mapped_column("Scope") + policies: Mapped[dict[str, Any]] = mapped_column("Policies") # User attributes bound to the refresh token sub: Mapped[str] = mapped_column("Sub", String(256), index=True) diff --git a/diracx-db/tests/auth/test_refresh_token.py b/diracx-db/tests/auth/test_refresh_token.py index 28d6dfe9d..498db60b4 100644 --- a/diracx-db/tests/auth/test_refresh_token.py +++ b/diracx-db/tests/auth/test_refresh_token.py @@ -25,6 +25,7 @@ async def test_insert(auth_db: AuthDB): jti1, "subject", "vo:lhcb property:NormalUser", + {"PolicySpecific": "OpenRefreshForTest"}, ) # Insert a second refresh token @@ -34,6 +35,7 @@ async def test_insert(auth_db: AuthDB): jti2, "subject", "vo:lhcb property:NormalUser", + {"PolicySpecific": "OpenRefreshForTest"}, ) # Make sure they don't have the same JWT ID @@ -46,6 +48,7 @@ async def test_get(auth_db: AuthDB): refresh_token_details = { "sub": "12345", "scope": "vo:lhcb property:NormalUser", + "policies": {"PolicySpecific": "OpenRefreshForTest"}, } # Insert refresh token details @@ -55,6 +58,7 @@ async def test_get(auth_db: AuthDB): jti, refresh_token_details["sub"], refresh_token_details["scope"], + refresh_token_details["policies"], ) # Enrich the dict with the generated refresh token attributes @@ -63,6 +67,7 @@ async def test_get(auth_db: AuthDB): "Scope": refresh_token_details["scope"], "JTI": jti, "Status": RefreshTokenStatus.CREATED, + "Policies": refresh_token_details["policies"], } # Get refresh token details @@ -90,9 +95,7 @@ async def test_get_user_refresh_tokens(auth_db: AuthDB): async with auth_db as auth_db: for sub in subjects: await auth_db.insert_refresh_token( - uuid7(), - sub, - "scope", + uuid7(), sub, "scope", {"PolicySpecific": "OpenRefreshForTest"} ) # Get the refresh tokens of each user @@ -117,9 +120,7 @@ async def test_revoke(auth_db: AuthDB): async with auth_db as auth_db: jti = uuid7() await auth_db.insert_refresh_token( - jti, - "subject", - "scope", + jti, "subject", "scope", {"PolicySpecific": "OpenRefreshForTest"} ) # Revoke the token @@ -146,9 +147,7 @@ async def test_revoke_user_refresh_tokens(auth_db: AuthDB): async with auth_db as auth_db: for sub in subjects: await auth_db.insert_refresh_token( - uuid7(), - sub, - "scope", + uuid7(), sub, "scope", {"PolicySpecific": "OpenRefreshForTest"} ) # Revoke the tokens of sub1 @@ -191,9 +190,7 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB): for _ in range(nb_tokens): jti = uuid7() await auth_db.insert_refresh_token( - jti, - sub, - "scope", + jti, sub, "scope", {"PolicySpecific": "OpenRefreshForTest"} ) jtis.append(jti) @@ -239,9 +236,7 @@ async def test_get_refresh_tokens(auth_db: AuthDB): async with auth_db as auth_db: for sub in subjects: await auth_db.insert_refresh_token( - uuid7(), - sub, - "scope", + uuid7(), sub, "scope", {"PolicySpecific": "OpenRefreshForTest"} ) # Get all refresh tokens (Admin) diff --git a/diracx-logic/src/diracx/logic/auth/token.py b/diracx-logic/src/diracx/logic/auth/token.py index c9b0a7e9c..264bd773b 100644 --- a/diracx-logic/src/diracx/logic/auth/token.py +++ b/diracx-logic/src/diracx/logic/auth/token.py @@ -4,9 +4,10 @@ import base64 import hashlib +import logging import re from datetime import datetime, timedelta, timezone -from typing import cast +from typing import Any, cast from joserfc import jwt from joserfc.jwt import Claims @@ -37,6 +38,8 @@ verify_dirac_refresh_token, ) +logger = logging.getLogger(__name__) + async def get_oidc_token( grant_type: GrantType, @@ -45,6 +48,7 @@ async def get_oidc_token( config: Config, settings: AuthSettings, available_properties: set[SecurityProperty], + policies: dict[str, Any], device_code: str | None = None, code: str | None = None, redirect_uri: str | None = None, @@ -87,6 +91,7 @@ async def get_oidc_token( return await exchange_token( auth_db, scope, + policies, oidc_token_info, config, settings, @@ -235,6 +240,7 @@ async def perform_legacy_exchange( expected_api_key: str, preferred_username: str, scope: str, + policies: dict[str, Any], authorization: str, auth_db: AuthDB, available_properties: set[SecurityProperty], @@ -261,6 +267,7 @@ async def perform_legacy_exchange( return await exchange_token( auth_db, scope, + policies, {"sub": sub, "preferred_username": preferred_username}, config, settings, @@ -273,6 +280,7 @@ async def perform_legacy_exchange( async def exchange_token( auth_db: AuthDB, scope: str, + policies: dict[str, Any], oidc_token_info: dict, config: Config, settings: AuthSettings, @@ -316,14 +324,23 @@ async def exchange_token( # Merge the VO with the subject to get a unique DIRAC sub sub = f"{vo}:{sub}" + # Enrich the token with policy specific content + dirac_access_policies = {} + dirac_refresh_policies = {} + for policy_name, policy in policies.items(): + access_extra, refresh_extra = policy.enrich_tokens() + if access_extra: + dirac_access_policies[policy_name] = access_extra + if refresh_extra: + dirac_refresh_policies[policy_name] = refresh_extra + refresh_payload: RefreshTokenPayload | None = None + if include_refresh_token: # Insert the refresh token with user details into the RefreshTokens table # User details are needed to regenerate access tokens later refresh_jti = await insert_refresh_token( - auth_db=auth_db, - subject=sub, - scope=scope, + auth_db=auth_db, subject=sub, scope=scope, policies=dirac_refresh_policies ) # Generate refresh token payload @@ -338,7 +355,7 @@ async def exchange_token( # legacy_exchange is used to indicate that the original refresh token # was obtained from the legacy_exchange endpoint legacy_exchange=legacy_exchange, - dirac_policies={}, + dirac_policies=dirac_refresh_policies, ) # Generate access token payload @@ -357,7 +374,7 @@ async def exchange_token( preferred_username=preferred_username, dirac_group=dirac_group, exp=access_exp, - dirac_policies={}, + dirac_policies=dirac_access_policies, ) return access_payload, refresh_payload @@ -391,9 +408,7 @@ def _sign_token_payload(claims: dict, settings: AuthSettings) -> str: async def insert_refresh_token( - auth_db: AuthDB, - subject: str, - scope: str, + auth_db: AuthDB, subject: str, scope: str, policies: dict[str, Any] ) -> UUID: """Insert a refresh token into the database and return the JWT ID.""" # Generate a JWT ID @@ -404,6 +419,7 @@ async def insert_refresh_token( jti=jti, subject=subject, scope=scope, + policies=policies, ) return jti diff --git a/diracx-routers/src/diracx/routers/access_policies.py b/diracx-routers/src/diracx/routers/access_policies.py index c2e498d74..4bb0fe8a3 100644 --- a/diracx-routers/src/diracx/routers/access_policies.py +++ b/diracx-routers/src/diracx/routers/access_policies.py @@ -31,10 +31,6 @@ from fastapi import Depends from diracx.core.extensions import DiracEntryPoint, select_from_extension -from diracx.core.models import ( - AccessTokenPayload, - RefreshTokenPayload, -) from diracx.core.settings import DevelopmentSettings from diracx.routers.dependencies import auto_inject from diracx.routers.utils import AuthorizedUserInfo, verify_dirac_access_token @@ -90,15 +86,10 @@ async def policy(policy_name: str, user_info: AuthorizedUserInfo, /): return @staticmethod - def enrich_tokens( - access_payload: AccessTokenPayload, refresh_payload: RefreshTokenPayload | None - ) -> tuple[dict, dict]: + def enrich_tokens() -> tuple[dict, dict]: """Add content to access or refresh payload when issuing a token. Content can be whatever is desired inside the access or refresh payload. - - :param access_payload: access token payload - :param refresh_payload: refresh token payload :returns: extra content for both payload """ return {}, {} diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index b61222fa0..e42cc6f71 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -2,7 +2,6 @@ from __future__ import annotations -import logging import os from http import HTTPStatus from typing import Annotated, Literal @@ -35,42 +34,26 @@ from ..fastapi_classes import DiracxRouter router = DiracxRouter(require_auth=False) -logger = logging.getLogger(__name__) async def mint_token( access_payload: AccessTokenPayload, refresh_payload: RefreshTokenPayload | None, existing_refresh_token: str | None, - all_access_policies: dict[str, BaseAccessPolicy], settings: AuthSettings, ) -> TokenResponse: - """Enrich the token with policy specific content and mint it.""" + """Mint the token.""" if not refresh_payload and not existing_refresh_token: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Refresh token is not set and no refresh token was provided", ) - # Enrich the token with policy specific content - dirac_access_policies = {} - dirac_refresh_policies = {} - for policy_name, policy in all_access_policies.items(): - access_extra, refresh_extra = policy.enrich_tokens( - access_payload, refresh_payload - ) - if access_extra: - dirac_access_policies[policy_name] = access_extra - if refresh_extra: - dirac_refresh_policies[policy_name] = refresh_extra - # Create the access token - access_payload.dirac_policies = dirac_access_policies access_token = create_token(access_payload, settings) # Create the refresh token if refresh_payload: - refresh_payload.dirac_policies = dirac_refresh_policies refresh_token = create_token(refresh_payload, settings) elif existing_refresh_token: refresh_token = existing_refresh_token @@ -133,6 +116,7 @@ async def get_oidc_token( config, settings, available_properties, + policies=all_access_policies, device_code=device_code, code=code, redirect_uri=redirect_uri, @@ -163,9 +147,7 @@ async def get_oidc_token( status_code=HTTPStatus.FORBIDDEN, detail=str(e), ) from e - return await mint_token( - access_payload, refresh_payload, refresh_token, all_access_policies, settings - ) + return await mint_token(access_payload, refresh_payload, refresh_token, settings) BASE_64_URL_SAFE_PATTERN = ( @@ -206,6 +188,7 @@ async def perform_legacy_exchange( expected_api_key=expected_api_key, preferred_username=preferred_username, scope=scope, + policies=all_access_policies, authorization=authorization, auth_db=auth_db, available_properties=available_properties, @@ -229,6 +212,4 @@ async def perform_legacy_exchange( status_code=HTTPStatus.FORBIDDEN, detail=str(e), ) from e - return await mint_token( - access_payload, refresh_payload, None, all_access_policies, settings - ) + return await mint_token(access_payload, refresh_payload, None, settings) diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py index 4e02e3140..c14afb340 100644 --- a/diracx-testing/src/diracx/testing/utils.py +++ b/diracx-testing/src/diracx/testing/utils.py @@ -45,7 +45,6 @@ from uuid_utils import uuid7 from diracx.core.extensions import DiracEntryPoint -from diracx.core.models import AccessTokenPayload, RefreshTokenPayload if TYPE_CHECKING: from diracx.core.settings import ( @@ -188,11 +187,10 @@ async def policy( pass @staticmethod - def enrich_tokens( - access_payload: AccessTokenPayload, - refresh_payload: RefreshTokenPayload | None, - ): - return {"PolicySpecific": "OpenAccessForTest"}, {} + def enrich_tokens(): + return {"PolicySpecific": "OpenAccessForTest"}, { + "PolicySpecific": "OpenRefreshForTest" + } enabled_systems = { e.name for e in select_from_extension(group=DiracEntryPoint.SERVICES)