Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions diracx-db/src/diracx/db/sql/auth/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import secrets
from datetime import UTC, datetime
from itertools import pairwise
from typing import Any

from dateutil.relativedelta import relativedelta
from dateutil.rrule import MONTHLY, rrule
Expand Down Expand Up @@ -324,10 +325,7 @@ async def update_authorization_flow_status(
)

async def insert_refresh_token(
self,
jti: UUID,
subject: str,
scope: str,
self, jti: UUID, subject: str, scope: str, policies: dict[str, Any]
) -> None:
"""Insert a refresh token in the DB.

Expand All @@ -338,6 +336,7 @@ async def insert_refresh_token(
jti=str(jti),
sub=subject,
scope=scope,
policies=policies,
)
await self.conn.execute(stmt)

Expand Down
1 change: 1 addition & 0 deletions diracx-db/src/diracx/db/sql/auth/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ class RefreshTokens(Base):
"Status", RefreshTokenStatus, server_default=RefreshTokenStatus.CREATED.name
)
scope: Mapped[str1024] = mapped_column("Scope")
policies: Mapped[dict[str, Any]] = mapped_column("Policies")

# User attributes bound to the refresh token
sub: Mapped[str] = mapped_column("Sub", String(256), index=True)
Expand Down
25 changes: 10 additions & 15 deletions diracx-db/tests/auth/test_refresh_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def test_insert(auth_db: AuthDB):
jti1,
"subject",
"vo:lhcb property:NormalUser",
{"PolicySpecific": "OpenRefreshForTest"},
)

# Insert a second refresh token
Expand All @@ -34,6 +35,7 @@ async def test_insert(auth_db: AuthDB):
jti2,
"subject",
"vo:lhcb property:NormalUser",
{"PolicySpecific": "OpenRefreshForTest"},
)

# Make sure they don't have the same JWT ID
Expand All @@ -46,6 +48,7 @@ async def test_get(auth_db: AuthDB):
refresh_token_details = {
"sub": "12345",
"scope": "vo:lhcb property:NormalUser",
"policies": {"PolicySpecific": "OpenRefreshForTest"},
}

# Insert refresh token details
Expand All @@ -55,6 +58,7 @@ async def test_get(auth_db: AuthDB):
jti,
refresh_token_details["sub"],
refresh_token_details["scope"],
refresh_token_details["policies"],
)

# Enrich the dict with the generated refresh token attributes
Expand All @@ -63,6 +67,7 @@ async def test_get(auth_db: AuthDB):
"Scope": refresh_token_details["scope"],
"JTI": jti,
"Status": RefreshTokenStatus.CREATED,
"Policies": refresh_token_details["policies"],
}

# Get refresh token details
Expand Down Expand Up @@ -90,9 +95,7 @@ async def test_get_user_refresh_tokens(auth_db: AuthDB):
async with auth_db as auth_db:
for sub in subjects:
await auth_db.insert_refresh_token(
uuid7(),
sub,
"scope",
uuid7(), sub, "scope", {"PolicySpecific": "OpenRefreshForTest"}
)

# Get the refresh tokens of each user
Expand All @@ -117,9 +120,7 @@ async def test_revoke(auth_db: AuthDB):
async with auth_db as auth_db:
jti = uuid7()
await auth_db.insert_refresh_token(
jti,
"subject",
"scope",
jti, "subject", "scope", {"PolicySpecific": "OpenRefreshForTest"}
)

# Revoke the token
Expand All @@ -146,9 +147,7 @@ async def test_revoke_user_refresh_tokens(auth_db: AuthDB):
async with auth_db as auth_db:
for sub in subjects:
await auth_db.insert_refresh_token(
uuid7(),
sub,
"scope",
uuid7(), sub, "scope", {"PolicySpecific": "OpenRefreshForTest"}
)

# Revoke the tokens of sub1
Expand Down Expand Up @@ -191,9 +190,7 @@ async def test_revoke_and_get_user_refresh_tokens(auth_db: AuthDB):
for _ in range(nb_tokens):
jti = uuid7()
await auth_db.insert_refresh_token(
jti,
sub,
"scope",
jti, sub, "scope", {"PolicySpecific": "OpenRefreshForTest"}
)
jtis.append(jti)

Expand Down Expand Up @@ -239,9 +236,7 @@ async def test_get_refresh_tokens(auth_db: AuthDB):
async with auth_db as auth_db:
for sub in subjects:
await auth_db.insert_refresh_token(
uuid7(),
sub,
"scope",
uuid7(), sub, "scope", {"PolicySpecific": "OpenRefreshForTest"}
)

# Get all refresh tokens (Admin)
Expand Down
34 changes: 25 additions & 9 deletions diracx-logic/src/diracx/logic/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import base64
import hashlib
import logging
import re
from datetime import datetime, timedelta, timezone
from typing import cast
from typing import Any, cast

from joserfc import jwt
from joserfc.jwt import Claims
Expand Down Expand Up @@ -37,6 +38,8 @@
verify_dirac_refresh_token,
)

logger = logging.getLogger(__name__)


async def get_oidc_token(
grant_type: GrantType,
Expand All @@ -45,6 +48,7 @@ async def get_oidc_token(
config: Config,
settings: AuthSettings,
available_properties: set[SecurityProperty],
policies: dict[str, Any],
device_code: str | None = None,
code: str | None = None,
redirect_uri: str | None = None,
Expand Down Expand Up @@ -87,6 +91,7 @@ async def get_oidc_token(
return await exchange_token(
auth_db,
scope,
policies,
oidc_token_info,
config,
settings,
Expand Down Expand Up @@ -235,6 +240,7 @@ async def perform_legacy_exchange(
expected_api_key: str,
preferred_username: str,
scope: str,
policies: dict[str, Any],
authorization: str,
auth_db: AuthDB,
available_properties: set[SecurityProperty],
Expand All @@ -261,6 +267,7 @@ async def perform_legacy_exchange(
return await exchange_token(
auth_db,
scope,
policies,
{"sub": sub, "preferred_username": preferred_username},
config,
settings,
Expand All @@ -273,6 +280,7 @@ async def perform_legacy_exchange(
async def exchange_token(
auth_db: AuthDB,
scope: str,
policies: dict[str, Any],
oidc_token_info: dict,
config: Config,
settings: AuthSettings,
Expand Down Expand Up @@ -316,14 +324,23 @@ async def exchange_token(
# Merge the VO with the subject to get a unique DIRAC sub
sub = f"{vo}:{sub}"

# Enrich the token with policy specific content
dirac_access_policies = {}
dirac_refresh_policies = {}
for policy_name, policy in policies.items():
access_extra, refresh_extra = policy.enrich_tokens()
if access_extra:
dirac_access_policies[policy_name] = access_extra
if refresh_extra:
dirac_refresh_policies[policy_name] = refresh_extra

refresh_payload: RefreshTokenPayload | None = None

if include_refresh_token:
# Insert the refresh token with user details into the RefreshTokens table
# User details are needed to regenerate access tokens later
refresh_jti = await insert_refresh_token(
auth_db=auth_db,
subject=sub,
scope=scope,
auth_db=auth_db, subject=sub, scope=scope, policies=dirac_refresh_policies
)

# Generate refresh token payload
Expand All @@ -338,7 +355,7 @@ async def exchange_token(
# legacy_exchange is used to indicate that the original refresh token
# was obtained from the legacy_exchange endpoint
legacy_exchange=legacy_exchange,
dirac_policies={},
dirac_policies=dirac_refresh_policies,
)

# Generate access token payload
Expand All @@ -357,7 +374,7 @@ async def exchange_token(
preferred_username=preferred_username,
dirac_group=dirac_group,
exp=access_exp,
dirac_policies={},
dirac_policies=dirac_access_policies,
)

return access_payload, refresh_payload
Expand Down Expand Up @@ -391,9 +408,7 @@ def _sign_token_payload(claims: dict, settings: AuthSettings) -> str:


async def insert_refresh_token(
auth_db: AuthDB,
subject: str,
scope: str,
auth_db: AuthDB, subject: str, scope: str, policies: dict[str, Any]
) -> UUID:
"""Insert a refresh token into the database and return the JWT ID."""
# Generate a JWT ID
Expand All @@ -404,6 +419,7 @@ async def insert_refresh_token(
jti=jti,
subject=subject,
scope=scope,
policies=policies,
)
return jti

Expand Down
11 changes: 1 addition & 10 deletions diracx-routers/src/diracx/routers/access_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@
from fastapi import Depends

from diracx.core.extensions import DiracEntryPoint, select_from_extension
from diracx.core.models import (
AccessTokenPayload,
RefreshTokenPayload,
)
from diracx.core.settings import DevelopmentSettings
from diracx.routers.dependencies import auto_inject
from diracx.routers.utils import AuthorizedUserInfo, verify_dirac_access_token
Expand Down Expand Up @@ -90,15 +86,10 @@ async def policy(policy_name: str, user_info: AuthorizedUserInfo, /):
return

@staticmethod
def enrich_tokens(
access_payload: AccessTokenPayload, refresh_payload: RefreshTokenPayload | None
) -> tuple[dict, dict]:
def enrich_tokens() -> tuple[dict, dict]:
"""Add content to access or refresh payload when issuing a token.

Content can be whatever is desired inside the access or refresh payload.

:param access_payload: access token payload
:param refresh_payload: refresh token payload
:returns: extra content for both payload
"""
return {}, {}
Expand Down
29 changes: 5 additions & 24 deletions diracx-routers/src/diracx/routers/auth/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import logging
import os
from http import HTTPStatus
from typing import Annotated, Literal
Expand Down Expand Up @@ -35,42 +34,26 @@
from ..fastapi_classes import DiracxRouter

router = DiracxRouter(require_auth=False)
logger = logging.getLogger(__name__)


async def mint_token(
access_payload: AccessTokenPayload,
refresh_payload: RefreshTokenPayload | None,
existing_refresh_token: str | None,
all_access_policies: dict[str, BaseAccessPolicy],
settings: AuthSettings,
) -> TokenResponse:
"""Enrich the token with policy specific content and mint it."""
"""Mint the token."""
if not refresh_payload and not existing_refresh_token:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Refresh token is not set and no refresh token was provided",
)

# Enrich the token with policy specific content
dirac_access_policies = {}
dirac_refresh_policies = {}
for policy_name, policy in all_access_policies.items():
access_extra, refresh_extra = policy.enrich_tokens(
access_payload, refresh_payload
)
if access_extra:
dirac_access_policies[policy_name] = access_extra
if refresh_extra:
dirac_refresh_policies[policy_name] = refresh_extra

# Create the access token
access_payload.dirac_policies = dirac_access_policies
access_token = create_token(access_payload, settings)

# Create the refresh token
if refresh_payload:
refresh_payload.dirac_policies = dirac_refresh_policies
refresh_token = create_token(refresh_payload, settings)
elif existing_refresh_token:
refresh_token = existing_refresh_token
Expand Down Expand Up @@ -133,6 +116,7 @@ async def get_oidc_token(
config,
settings,
available_properties,
policies=all_access_policies,
device_code=device_code,
code=code,
redirect_uri=redirect_uri,
Expand Down Expand Up @@ -163,9 +147,7 @@ async def get_oidc_token(
status_code=HTTPStatus.FORBIDDEN,
detail=str(e),
) from e
return await mint_token(
access_payload, refresh_payload, refresh_token, all_access_policies, settings
)
return await mint_token(access_payload, refresh_payload, refresh_token, settings)


BASE_64_URL_SAFE_PATTERN = (
Expand Down Expand Up @@ -206,6 +188,7 @@ async def perform_legacy_exchange(
expected_api_key=expected_api_key,
preferred_username=preferred_username,
scope=scope,
policies=all_access_policies,
authorization=authorization,
auth_db=auth_db,
available_properties=available_properties,
Expand All @@ -229,6 +212,4 @@ async def perform_legacy_exchange(
status_code=HTTPStatus.FORBIDDEN,
detail=str(e),
) from e
return await mint_token(
access_payload, refresh_payload, None, all_access_policies, settings
)
return await mint_token(access_payload, refresh_payload, None, settings)
10 changes: 4 additions & 6 deletions diracx-testing/src/diracx/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from uuid_utils import uuid7

from diracx.core.extensions import DiracEntryPoint
from diracx.core.models import AccessTokenPayload, RefreshTokenPayload

if TYPE_CHECKING:
from diracx.core.settings import (
Expand Down Expand Up @@ -188,11 +187,10 @@ async def policy(
pass

@staticmethod
def enrich_tokens(
access_payload: AccessTokenPayload,
refresh_payload: RefreshTokenPayload | None,
):
return {"PolicySpecific": "OpenAccessForTest"}, {}
def enrich_tokens():
return {"PolicySpecific": "OpenAccessForTest"}, {
"PolicySpecific": "OpenRefreshForTest"
}

enabled_systems = {
e.name for e in select_from_extension(group=DiracEntryPoint.SERVICES)
Expand Down
Loading