diff --git a/DashAI/alembic/versions/d4e8a2c6f0b1_add_credential_table.py b/DashAI/alembic/versions/d4e8a2c6f0b1_add_credential_table.py new file mode 100644 index 000000000..692d87700 --- /dev/null +++ b/DashAI/alembic/versions/d4e8a2c6f0b1_add_credential_table.py @@ -0,0 +1,41 @@ +"""Add credential table + +Revision ID: d4e8a2c6f0b1 +Revises: f1a2b3c4d5e6 +Create Date: 2026-06-15 00:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "d4e8a2c6f0b1" +down_revision: Union[str, None] = "f1a2b3c4d5e6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "credential", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("encrypted_key", sa.Text(), nullable=False), + sa.Column( + "verified", + sa.Boolean(), + server_default="0", + nullable=False, + ), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("last_modified", sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint("id", name="pk_credential"), + sa.UniqueConstraint("name", name="uq_credential_name"), + ) + + +def downgrade() -> None: + op.drop_table("credential") diff --git a/DashAI/back/api/api_v1/api.py b/DashAI/back/api/api_v1/api.py index 1beb875a5..50a11bed7 100644 --- a/DashAI/back/api/api_v1/api.py +++ b/DashAI/back/api/api_v1/api.py @@ -2,6 +2,7 @@ from DashAI.back.api.api_v1.endpoints.components import router as components from DashAI.back.api.api_v1.endpoints.converters import router as converters +from DashAI.back.api.api_v1.endpoints.credentials import router as credentials from DashAI.back.api.api_v1.endpoints.datafile import router as datafile_router from DashAI.back.api.api_v1.endpoints.dataset_source import router as dataset_source from DashAI.back.api.api_v1.endpoints.datasets import router as datasets @@ -46,3 +47,4 @@ api_router_v1.include_router(dataset_source, prefix="/dataset-source") api_router_v1.include_router(datafile_router, prefix="/datafile") api_router_v1.include_router(folders, prefix="/folder") +api_router_v1.include_router(credentials, prefix="/credential") diff --git a/DashAI/back/api/api_v1/endpoints/credentials.py b/DashAI/back/api/api_v1/endpoints/credentials.py new file mode 100644 index 000000000..b99ad1eab --- /dev/null +++ b/DashAI/back/api/api_v1/endpoints/credentials.py @@ -0,0 +1,265 @@ +"""Credential API endpoints.""" + +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Union + +from fastapi import APIRouter, Depends, Header, status +from fastapi.exceptions import HTTPException +from kink import di +from pydantic import BaseModel + +from DashAI.back.credentials.sync import sync_credentials_status + +if TYPE_CHECKING: + from DashAI.back.dependencies.registry import ComponentRegistry + +log = logging.getLogger(__name__) +router = APIRouter() + + +class AuthRequest(BaseModel): + """Request body for authenticating a credential. + + Parameters + ---------- + key : str + The platform key/token to validate and store. + """ + + key: str + + +def _credential_components(registry: "ComponentRegistry") -> Dict[str, Dict[str, Any]]: + """Return the registry's Credential-type components. + + Parameters + ---------- + registry : ComponentRegistry + The component registry. + + Returns + ------- + Dict[str, Dict[str, Any]] + Mapping of credential name to component dict. + """ + return registry._registry.get("Credential", {}) + + +def _localize(value: Any, language: Union[str, None]) -> Union[str, None]: + """Resolve a possibly-multilingual value to a plain string. + + Parameters + ---------- + value : Any + A ``MultilingualString`` or plain value. + language : Union[str, None] + The ``Accept-Language`` header value, or None. + + Returns + ------- + Union[str, None] + The localized string, or the value unchanged when not multilingual. + """ + if hasattr(value, "get"): + lang_code = language.split("-")[0].lower() if language else "en" + return value.get(lang_code) + return value + + +def _status_payload( + name: str, + component_dict: Dict[str, Any], + is_authenticated: bool, + key: Union[str, None], + language: Union[str, None] = None, +) -> Dict[str, Any]: + """Build the status payload for a credential. + + The payload bundles the catalog fields (display name, description) with the + authentication state in a single object so the configuration modal can be + populated with one request. The stored key is included so the modal can + display it, which is acceptable for DashAI's local-first, single-user + desktop model where the database already lives on the user's machine. + + Parameters + ---------- + name : str + Credential component name. + component_dict : Dict[str, Any] + The registry component dict. + is_authenticated : bool + Whether the credential is currently verified. + key : Union[str, None] + The stored decrypted key, or None if nothing is stored. + language : Union[str, None] + The ``Accept-Language`` header used to localize text fields. + + Returns + ------- + Dict[str, Any] + Status payload including localized display name, description and key. + """ + display_name = _localize(component_dict.get("display_name"), language) + description = _localize(component_dict.get("description"), language) + return { + "name": name, + "display_name": display_name or name, + "description": description or "", + "is_authenticated": is_authenticated, + "key": key, + } + + +@router.get("/") +async def list_credentials( + accept_language: Union[str, None] = Header(default=None), + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> List[Dict[str, Any]]: + """List all credential components with their authentication status. + + Returns catalog metadata and auth state together in a single response so + the configuration modal does not need one request per credential. + + Parameters + ---------- + accept_language : Union[str, None] + The 'Accept-Language' header used to localize text fields. + registry : ComponentRegistry + Injected component registry. + + Returns + ------- + list[dict] + Credential status payloads. + """ + creds = _credential_components(registry) + store = di["credential_store"] + statuses = store.all_statuses() + return [ + _status_payload( + name, cdict, statuses.get(name, False), store.load(name), accept_language + ) + for name, cdict in creds.items() + ] + + +@router.get("/{name}") +async def get_credential_status( + name: str, + accept_language: Union[str, None] = Header(default=None), + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> Dict[str, Any]: + """Return the status of a single credential. + + Parameters + ---------- + name : str + Credential component name. + accept_language : Union[str, None] + The 'Accept-Language' header used to localize text fields. + registry : ComponentRegistry + Injected component registry. + + Returns + ------- + dict + Status payload. + + Raises + ------ + HTTPException + 404 if the credential is not registered. + """ + creds = _credential_components(registry) + if name not in creds: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Credential '{name}' not found.", + ) + store = di["credential_store"] + return _status_payload( + name, creds[name], store.is_verified(name), store.load(name), accept_language + ) + + +@router.post("/{name}/auth") +async def authenticate_credential( + name: str, + body: AuthRequest, + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> Dict[str, Any]: + """Verify and store a credential key. + + Parameters + ---------- + name : str + Credential component name. + body : AuthRequest + Contains the key to authenticate with. + registry : ComponentRegistry + Injected component registry. + + Returns + ------- + dict + ``{"is_authenticated": True}`` on success. + + Raises + ------ + HTTPException + 404 if the credential is unknown, 400 if the key is invalid. + """ + creds = _credential_components(registry) + if name not in creds: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Credential '{name}' not found.", + ) + credential = creds[name]["class"]() + try: + credential.auth(body.key) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid credential key.", + ) from exc + + affected = registry.get_required_credentials(name) + sync_credentials_status(only=affected) + return {"is_authenticated": True} + + +@router.delete("/{name}") +async def delete_credential( + name: str, + registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), +) -> Dict[str, Any]: + """Remove a stored credential key. + + Parameters + ---------- + name : str + Credential component name. + registry : ComponentRegistry + Injected component registry. + + Returns + ------- + dict + ``{"is_authenticated": False}``. + + Raises + ------ + HTTPException + 404 if the credential is not registered. + """ + creds = _credential_components(registry) + if name not in creds: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Credential '{name}' not found.", + ) + di["credential_store"].delete(name) + affected = registry.get_required_credentials(name) + sync_credentials_status(only=affected) + return {"is_authenticated": False} diff --git a/DashAI/back/api/api_v1/endpoints/plugins.py b/DashAI/back/api/api_v1/endpoints/plugins.py index e8cb1cd67..0f35d0783 100644 --- a/DashAI/back/api/api_v1/endpoints/plugins.py +++ b/DashAI/back/api/api_v1/endpoints/plugins.py @@ -243,6 +243,7 @@ async def update_plugin( Plugin The updated plugin. """ + from DashAI.back.credentials.sync import sync_credentials_status from DashAI.back.plugins.utils import ( install_plugin, register_plugin_components, @@ -278,6 +279,7 @@ async def update_plugin( # else the new components should be registered else: register_plugin_components(installed_components, component_registry) + sync_credentials_status() job_queue.put(SyncComponentsJob()) elif ( plugin.status == PluginStatus.INSTALLED diff --git a/DashAI/back/config.py b/DashAI/back/config.py index 728c8a760..04931fc65 100644 --- a/DashAI/back/config.py +++ b/DashAI/back/config.py @@ -22,3 +22,4 @@ class DefaultSettings(BaseSettings): EXPLANATIONS_PATH: str = "explanations" NOTEBOOK_PATH: str = "notebook" DATAFILE_PATH: str = "datafiles" + CREDENTIALS_KEY_PATH: str = ".credentials_key" diff --git a/DashAI/back/config_object.py b/DashAI/back/config_object.py index b994e8e93..8ec29fcee 100644 --- a/DashAI/back/config_object.py +++ b/DashAI/back/config_object.py @@ -1,3 +1,5 @@ +from kink import di + from DashAI.back.core.schema_fields.base_schema import ( BaseSchema, replace_defs_in_schema, @@ -49,3 +51,24 @@ def validate_and_transform(self, raw_data: dict) -> dict: """ schema_instance = self.SCHEMA.model_validate(raw_data) return fill_objects(schema_instance) + + def get_credential(self, name: str): + """Resolve a registered credential component by name. + + The returned instance exposes ``get_key``, ``is_authenticated`` and + ``apply``. When nothing is stored, ``get_key`` returns None and + ``apply`` is a no-op, so optional credentials degrade gracefully. + + Parameters + ---------- + name : str + Credential component class name (e.g. "HuggingFaceCredential"). + + Returns + ------- + BaseCredential + An instance of the requested credential component. + """ + registry = di["component_registry"] + credential_class = registry[name]["class"] + return credential_class() diff --git a/DashAI/back/container.py b/DashAI/back/container.py index 23a4274ce..49b8d1a27 100644 --- a/DashAI/back/container.py +++ b/DashAI/back/container.py @@ -1,8 +1,12 @@ import logging +import os +import pathlib from typing import Dict from kink import Container, di +from DashAI.back.credentials.encryptor import CredentialEncryptor, load_or_create_key +from DashAI.back.credentials.store import CredentialStore from DashAI.back.dependencies.database import setup_sqlite_db from DashAI.back.dependencies.job_queues.huey_job_queue import HueyJobQueue from DashAI.back.dependencies.registry import ComponentRegistry @@ -38,6 +42,13 @@ def build_container(config: Dict[str, str]) -> Container: di["component_registry"] = ComponentRegistry( initial_components=config["INITIAL_COMPONENTS"] ) + credentials_key = load_or_create_key( + pathlib.Path(config["CREDENTIALS_KEY_PATH"]), + env_value=os.getenv("DASHAI_CREDENTIALS_SECRET"), + ) + encryptor = CredentialEncryptor(credentials_key) + di["credential_encryptor"] = encryptor + di["credential_store"] = CredentialStore(session_factory, encryptor) job_queue = HueyJobQueue("job_queue", path_db=config["LOCAL_PATH"]) di["job_queue"] = job_queue diff --git a/DashAI/back/credentials/__init__.py b/DashAI/back/credentials/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/DashAI/back/credentials/base_credential.py b/DashAI/back/credentials/base_credential.py new file mode 100644 index 000000000..0f1574030 --- /dev/null +++ b/DashAI/back/credentials/base_credential.py @@ -0,0 +1,96 @@ +"""Base class for DashAI platform credentials.""" + +from abc import ABC, abstractmethod +from typing import Final, Union + +from kink import di + +from DashAI.back.config_object import ConfigObject +from DashAI.back.core.utils import MultilingualString + + +class BaseCredential(ConfigObject, ABC): + """Abstract base class for all DashAI credentials. + + A credential authenticates against an external platform with a key and + persists it (encrypted) so components that declare it in + ``REQUIRED_CREDENTIALS`` or ``OPTIONAL_CREDENTIALS`` can use it. + + Subclasses only implement :meth:`verify` (the platform-specific network + check) and optionally :meth:`apply` (push the key into the platform SDK). + """ + + TYPE: Final[str] = "Credential" + DISPLAY_NAME: Union[str, MultilingualString] = "" + DESCRIPTION: Union[str, MultilingualString] = "" + ICON: str = "Key" + + @abstractmethod + def verify(self, key: str) -> bool: + """Check a key against the platform. + + Parameters + ---------- + key : str + The key to validate. + + Returns + ------- + bool + True if the key is valid. + """ + raise NotImplementedError + + def auth(self, key: str) -> bool: + """Validate and persist a key. + + Parameters + ---------- + key : str + The key to authenticate with. + + Returns + ------- + bool + True on success. + + Raises + ------ + ValueError + If the key fails verification. + """ + if not self.verify(key): + raise ValueError( + f"Invalid credential for {type(self).__name__}: verification failed." + ) + di["credential_store"].save(type(self).__name__, key) + return True + + def get_key(self) -> Union[str, None]: + """Return the stored decrypted key, or None. + + Returns + ------- + Union[str, None] + The decrypted key, or None if not authenticated. + """ + return di["credential_store"].load(type(self).__name__) + + def is_authenticated(self) -> bool: + """Return whether a verified key is stored. + + Returns + ------- + bool + True if authenticated. + """ + return di["credential_store"].is_verified(type(self).__name__) + + def apply(self) -> None: + """Push the stored key into the platform SDK if present. + + The default implementation is a no-op, which makes a credential safe to + ``apply()`` even when unauthenticated (used by ``OPTIONAL_CREDENTIALS``). + Override in subclasses that need to log in to an SDK. + """ + return None diff --git a/DashAI/back/credentials/encryptor.py b/DashAI/back/credentials/encryptor.py new file mode 100644 index 000000000..1efd4fed2 --- /dev/null +++ b/DashAI/back/credentials/encryptor.py @@ -0,0 +1,102 @@ +"""Symmetric encryption for stored credential keys.""" + +import logging +import os +import stat +from pathlib import Path +from typing import Union + +from cryptography.fernet import Fernet + +logger = logging.getLogger(__name__) + + +def load_or_create_key( + key_path: Path, + env_value: Union[str, None] = None, + persist: bool = True, +) -> bytes: + """Resolve the Fernet secret key. + + Resolution order: explicit ``env_value`` first, then an existing file at + ``key_path``, otherwise a freshly generated key (persisted to ``key_path`` + when ``persist`` is True). + + Parameters + ---------- + key_path : Path + Location of the on-disk key file. + env_value : Union[str, None] + Key provided via environment variable, if any. + persist : bool + Whether to write a newly generated key to disk, by default True. + + Returns + ------- + bytes + The Fernet key as bytes. + """ + if env_value: + return env_value.encode() + + if key_path.exists(): + return key_path.read_bytes() + + key = Fernet.generate_key() + if persist: + key_path.parent.mkdir(parents=True, exist_ok=True) + key_path.write_bytes(key) + try: + os.chmod(key_path, stat.S_IRUSR | stat.S_IWUSR) + except OSError: + logger.warning("Could not restrict permissions on %s", key_path) + return key + + +class CredentialEncryptor: + """Encrypts and decrypts credential keys with Fernet.""" + + def __init__(self, key: bytes) -> None: + """Initialize the encryptor. + + Parameters + ---------- + key : bytes + A valid Fernet key. + """ + self._fernet = Fernet(key) + + def encrypt(self, plaintext: str) -> str: + """Encrypt a plaintext secret. + + Parameters + ---------- + plaintext : str + The secret to encrypt. + + Returns + ------- + str + The encrypted token. + """ + return self._fernet.encrypt(plaintext.encode()).decode() + + def decrypt(self, token: str) -> str: + """Decrypt a token produced by :meth:`encrypt`. + + Parameters + ---------- + token : str + The encrypted token. + + Returns + ------- + str + The decrypted plaintext. + + Raises + ------ + cryptography.fernet.InvalidToken + If the token is invalid or was encrypted with a different key. + """ + return self._fernet.decrypt(token.encode()).decode() diff --git a/DashAI/back/credentials/github_credential.py b/DashAI/back/credentials/github_credential.py new file mode 100644 index 000000000..eb18eafc2 --- /dev/null +++ b/DashAI/back/credentials/github_credential.py @@ -0,0 +1,55 @@ +"""GitHub credential.""" + +import logging +from typing import Final + +from DashAI.back.core.utils import MultilingualString +from DashAI.back.credentials.base_credential import BaseCredential + +logger = logging.getLogger(__name__) + + +class GithubCredential(BaseCredential): + """Credential for the GitHub API.""" + + DISPLAY_NAME: Final = MultilingualString( + en="GitHub", + es="GitHub", + pt="GitHub", + de="GitHub", + zh="GitHub", + ) + DESCRIPTION: Final = MultilingualString( + en="Personal access token for the GitHub API.", + es="Token de acceso personal para la API de GitHub.", + pt="Token de acesso pessoal para a API do GitHub.", + de="Persönliches Zugriffstoken für die GitHub API.", + zh="用于 GitHub API 的个人访问令牌。", + ) + ICON: str = "Key" + + def verify(self, key: str) -> bool: + """Validate a GitHub token via the ``/user`` endpoint. + + Parameters + ---------- + key : str + GitHub personal access token. + + Returns + ------- + bool + True if the token is valid. + """ + import requests + + try: + response = requests.get( + "https://api.github.com/user", + headers={"Authorization": f"Bearer {key}"}, + timeout=10, + ) + return response.status_code == 200 + except Exception as exc: + logger.info("GitHub credential verification failed: %s", exc) + return False diff --git a/DashAI/back/credentials/huggingface_credential.py b/DashAI/back/credentials/huggingface_credential.py new file mode 100644 index 000000000..e3aea0015 --- /dev/null +++ b/DashAI/back/credentials/huggingface_credential.py @@ -0,0 +1,65 @@ +"""HuggingFace Hub credential.""" + +import logging +from typing import Final + +from DashAI.back.core.utils import MultilingualString +from DashAI.back.credentials.base_credential import BaseCredential + +logger = logging.getLogger(__name__) + + +class HuggingFaceCredential(BaseCredential): + """Credential for the HuggingFace Hub.""" + + DISPLAY_NAME: Final = MultilingualString( + en="HuggingFace", + es="HuggingFace", + pt="HuggingFace", + de="HuggingFace", + zh="HuggingFace", + ) + DESCRIPTION: Final = MultilingualString( + en="Access token for the HuggingFace Hub. Required for gated models and " + "datasets.", + es="Token de acceso para el HuggingFace Hub. Necesario para modelos y " + "conjuntos de datos restringidos.", + pt="Token de acesso para o HuggingFace Hub. Necessário para modelos e " + "conjuntos de dados restritos.", + de="Zugriffstoken für den HuggingFace Hub. Erforderlich für " + "eingeschränkte Modelle und Datensätze.", + zh="用于 HuggingFace Hub 的访问令牌。访问受限模型和数据集时需要。", + ) + ICON: str = "Key" + + def verify(self, key: str) -> bool: + """Validate a HuggingFace token via ``whoami``. + + Parameters + ---------- + key : str + HuggingFace access token. + + Returns + ------- + bool + True if the token is valid. + """ + from huggingface_hub import HfApi + + try: + HfApi().whoami(token=key) + return True + except Exception as exc: + logger.info("HuggingFace credential verification failed: %s", exc) + return False + + def apply(self) -> None: + """Log in to the HuggingFace Hub if a key is stored.""" + key = self.get_key() + if not key: + return None + from huggingface_hub import login + + login(token=key) + return None diff --git a/DashAI/back/credentials/kaggle_credential.py b/DashAI/back/credentials/kaggle_credential.py new file mode 100644 index 000000000..b1e6bca34 --- /dev/null +++ b/DashAI/back/credentials/kaggle_credential.py @@ -0,0 +1,108 @@ +"""Kaggle credential.""" + +import logging +import os +from typing import Final + +from DashAI.back.core.utils import MultilingualString +from DashAI.back.credentials.base_credential import BaseCredential + +logger = logging.getLogger(__name__) + + +class KaggleCredential(BaseCredential): + """Credential for the Kaggle API. + + The key is expected in the form ``"username:api_key"``. + """ + + DISPLAY_NAME: Final = MultilingualString( + en="Kaggle", + es="Kaggle", + pt="Kaggle", + de="Kaggle", + zh="Kaggle", + ) + DESCRIPTION: Final = MultilingualString( + en="Kaggle API credential in the form 'username:key'.", + es="Credencial de la API de Kaggle en el formato 'usuario:clave'.", + pt="Credencial da API do Kaggle no formato 'usuario:chave'.", + de="Zugangsdaten für die Kaggle API im Format 'benutzername:schluessel'.", + zh="Kaggle API 凭证,格式为 'username:key'。", + ) + ICON: str = "Key" + + @staticmethod + def _split_key(key: str): + """Split a ``"username:api_key"`` credential into its parts. + + Parameters + ---------- + key : str + Kaggle credential in the form ``"username:api_key"``. + + Returns + ------- + tuple[str, str] or None + ``(username, api_key)`` if well formed, otherwise None. + """ + username, separator, api_key = key.partition(":") + if not separator or not username or not api_key: + return None + return username, api_key + + def verify(self, key: str) -> bool: + """Validate a Kaggle credential with the official ``kaggle`` library. + + The credentials are exported to the environment before importing + ``kaggle``, because the package authenticates at import time and + terminates the process when no credentials are available. + + Parameters + ---------- + key : str + Kaggle credential in the form ``"username:api_key"``. + + Returns + ------- + bool + True if the credential authenticates successfully. + """ + parts = self._split_key(key) + if parts is None: + return False + username, api_key = parts + + os.environ["KAGGLE_USERNAME"] = username + os.environ["KAGGLE_KEY"] = api_key + try: + from kaggle.api.kaggle_api_extended import KaggleApi + + api = KaggleApi() + api.authenticate() + # Perform an authenticated call to confirm the key is valid. + api.competitions_list() + return True + except SystemExit: + return False + except Exception as exc: + logger.info("Kaggle credential verification failed: %s", exc) + return False + + def apply(self) -> None: + """Export the stored Kaggle credentials to the environment. + + The official ``kaggle`` library reads ``KAGGLE_USERNAME`` and + ``KAGGLE_KEY`` from the environment, so exporting them makes any later + use of the library authenticated. No-op when nothing is stored. + """ + key = self.get_key() + if not key: + return None + parts = self._split_key(key) + if parts is None: + return None + username, api_key = parts + os.environ["KAGGLE_USERNAME"] = username + os.environ["KAGGLE_KEY"] = api_key + return None diff --git a/DashAI/back/credentials/store.py b/DashAI/back/credentials/store.py new file mode 100644 index 000000000..1a3047f5a --- /dev/null +++ b/DashAI/back/credentials/store.py @@ -0,0 +1,115 @@ +"""Persistence boundary for encrypted credentials.""" + +import logging +from datetime import datetime +from typing import Dict, Union + +from DashAI.back.credentials.encryptor import CredentialEncryptor +from DashAI.back.dependencies.database.models import Credential + +logger = logging.getLogger(__name__) + + +class CredentialStore: + """Reads and writes encrypted credentials in the database. + + This is the only component that touches the credential table and the + encryptor. + """ + + def __init__(self, session_factory, encryptor: CredentialEncryptor) -> None: + """Initialize the store. + + Parameters + ---------- + session_factory + SQLAlchemy session factory (callable returning a session). + encryptor : CredentialEncryptor + Encryptor used to protect keys at rest. + """ + self._session_factory = session_factory + self._encryptor = encryptor + + def save(self, name: str, key: str) -> None: + """Encrypt and persist a credential key, marking it verified. + + Parameters + ---------- + name : str + Credential component name. + key : str + Plaintext key to store. + """ + encrypted = self._encryptor.encrypt(key) + with self._session_factory() as db: + row = db.query(Credential).filter_by(name=name).first() + if row is None: + row = Credential(name=name, encrypted_key=encrypted, verified=True) + db.add(row) + else: + row.encrypted_key = encrypted + row.verified = True + row.last_modified = datetime.now() + db.commit() + + def load(self, name: str) -> Union[str, None]: + """Return the decrypted key for a credential, or None. + + Parameters + ---------- + name : str + Credential component name. + + Returns + ------- + Union[str, None] + Decrypted key, or None if not stored. + """ + with self._session_factory() as db: + row = db.query(Credential).filter_by(name=name).first() + if row is None: + return None + return self._encryptor.decrypt(row.encrypted_key) + + def is_verified(self, name: str) -> bool: + """Return whether a credential is stored and verified. + + Parameters + ---------- + name : str + Credential component name. + + Returns + ------- + bool + True if a verified key exists. + """ + with self._session_factory() as db: + row = db.query(Credential).filter_by(name=name).first() + return bool(row and row.verified) + + def delete(self, name: str) -> None: + """Remove a stored credential. + + Parameters + ---------- + name : str + Credential component name. + """ + with self._session_factory() as db: + row = db.query(Credential).filter_by(name=name).first() + if row is not None: + db.delete(row) + db.commit() + + def all_statuses(self) -> Dict[str, bool]: + """Return the verified status of every stored credential. + + Returns + ------- + Dict[str, bool] + Mapping of credential name to verified status. + """ + with self._session_factory() as db: + rows = db.query(Credential).all() + return {row.name: bool(row.verified) for row in rows} diff --git a/DashAI/back/credentials/sync.py b/DashAI/back/credentials/sync.py new file mode 100644 index 000000000..27fbce8f2 --- /dev/null +++ b/DashAI/back/credentials/sync.py @@ -0,0 +1,22 @@ +"""Synchronize credential availability flags on the component registry.""" + +import logging +from typing import List, Union + +from kink import di + +logger = logging.getLogger(__name__) + + +def sync_credentials_status(only: Union[List[str], None] = None) -> None: + """Refresh ``credentials_satisfied`` flags from stored credential statuses. + + Parameters + ---------- + only : Union[List[str], None] + If provided, only these component names are recomputed. If None, all + components are recomputed. + """ + store = di["credential_store"] + registry = di["component_registry"] + registry.refresh_credentials_status(store.all_statuses(), only=only) diff --git a/DashAI/back/dataset_sources/huggingface_dataset_source.py b/DashAI/back/dataset_sources/huggingface_dataset_source.py index aceeebb29..e026213a2 100644 --- a/DashAI/back/dataset_sources/huggingface_dataset_source.py +++ b/DashAI/back/dataset_sources/huggingface_dataset_source.py @@ -23,6 +23,8 @@ class HuggingFaceDatasetSource(BaseDatasetSource): offset and slicing the iterator. """ + OPTIONAL_CREDENTIALS = ["HuggingFaceCredential"] + DISPLAY_NAME: Final = MultilingualString( en="HuggingFace Hub", es="HuggingFace Hub", @@ -186,6 +188,10 @@ def download_dataset(self, dataset_id: str, temp_path: str) -> str: """ from huggingface_hub import snapshot_download + # Apply the HuggingFace credential if one is stored; this is a no-op for + # public datasets and only matters for gated/private ones. + self.get_credential("HuggingFaceCredential").apply() + snapshot_download( repo_id=dataset_id, repo_type="dataset", diff --git a/DashAI/back/dependencies/config_builder.py b/DashAI/back/dependencies/config_builder.py index 3323340b9..a8afe53b5 100644 --- a/DashAI/back/dependencies/config_builder.py +++ b/DashAI/back/dependencies/config_builder.py @@ -59,6 +59,7 @@ def build_config_dict( config["RUNS_PATH"] = local_path / config["RUNS_PATH"] config["IMAGES_PATH"] = local_path / config["IMAGES_PATH"] config["DATAFILE_PATH"] = local_path / config["DATAFILE_PATH"] + config["CREDENTIALS_KEY_PATH"] = local_path / config["CREDENTIALS_KEY_PATH"] config["FRONT_BUILD_PATH"] = pathlib.Path(config["FRONT_BUILD_PATH"]).absolute() config["BACK_PATH"] = pathlib.Path(config["BACK_PATH"]).absolute() config["LOGGING_LEVEL"] = getattr(logging, logging_level) diff --git a/DashAI/back/dependencies/database/models.py b/DashAI/back/dependencies/database/models.py index cdf57d968..3ec231584 100644 --- a/DashAI/back/dependencies/database/models.py +++ b/DashAI/back/dependencies/database/models.py @@ -778,3 +778,22 @@ class Datafile(Base): name="uq_datafile_source_dataset", ), ) + + +class Credential(Base): + __tablename__ = "credential" + """ + Table to store encrypted credentials for external platforms. + """ + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String, unique=True, nullable=False) + encrypted_key: Mapped[str] = mapped_column(Text, nullable=False) + verified: Mapped[bool] = mapped_column( + Boolean, nullable=False, default=False, server_default="0" + ) + created: Mapped[DateTime] = mapped_column(DateTime, default=datetime.now) + last_modified: Mapped[DateTime] = mapped_column( + DateTime, + default=datetime.now, + onupdate=datetime.now, + ) diff --git a/DashAI/back/dependencies/registry/component_registry.py b/DashAI/back/dependencies/registry/component_registry.py index d99145aaf..39cc71704 100644 --- a/DashAI/back/dependencies/registry/component_registry.py +++ b/DashAI/back/dependencies/registry/component_registry.py @@ -208,6 +208,13 @@ def register_component(self, new_component: Type) -> None: if isinstance(display_name, str): new_component.DISPLAY_NAME = MultilingualString(en=display_name) + required_credentials = list( + getattr(new_component, "REQUIRED_CREDENTIALS", []) or [] + ) + optional_credentials = list( + getattr(new_component, "OPTIONAL_CREDENTIALS", []) or [] + ) + new_register_component = { "name": new_component.__name__, "type": base_type, @@ -218,6 +225,9 @@ def register_component(self, new_component: Type) -> None: "description": getattr(new_component, "DESCRIPTION", None), "display_name": getattr(new_component, "DISPLAY_NAME", None), "color": getattr(new_component, "COLOR", None), + "required_credentials": required_credentials, + "optional_credentials": optional_credentials, + "credentials_satisfied": len(required_credentials) == 0, } if base_type not in self._registry: @@ -230,8 +240,19 @@ def register_component(self, new_component: Type) -> None: self._relationship_manager.add_relationship( new_component.__name__, compatible_component, + "compatible_components", ) + for credential_name in required_credentials: + self._relationship_manager.add_relationship( + new_component.__name__, credential_name, "required_credentials" + ) + + for credential_name in optional_credentials: + self._relationship_manager.add_relationship( + new_component.__name__, credential_name, "optional_credentials" + ) + @beartype def unregister_component(self, component: Type) -> None: """Remove a component from the registry. @@ -259,8 +280,19 @@ def unregister_component(self, component: Type) -> None: self._relationship_manager.remove_relationship( component.__name__, compatible_component, + "compatible_components", ) + for credential_name in getattr(component, "REQUIRED_CREDENTIALS", []) or []: + self._relationship_manager.remove_relationship( + component.__name__, credential_name, "required_credentials" + ) + + for credential_name in getattr(component, "OPTIONAL_CREDENTIALS", []) or []: + self._relationship_manager.remove_relationship( + component.__name__, credential_name, "optional_credentials" + ) + @beartype def get_components_by_types( self, @@ -453,5 +485,68 @@ def get_related_components(self, component_id: str) -> List[Dict[str, Any]]: return [ self.__getitem__(related_component_id) - for related_component_id in self._relationship_manager[component_id] + for related_component_id in self._relationship_manager.get( + component_id, "compatible_components" + ) ] + + @beartype + def get_required_credentials(self, component_id: str) -> List[str]: + """Return the names of credentials a component requires. + + Parameters + ---------- + component_id : str + A registered component name. + + Returns + ------- + List[str] + Names of required credential components (empty if none). + """ + return self._relationship_manager.get(component_id, "required_credentials") + + @beartype + def get_optional_credentials(self, component_id: str) -> List[str]: + """Return the names of credentials a component can optionally use. + + Parameters + ---------- + component_id : str + A registered component name. + + Returns + ------- + List[str] + Names of optional credential components (empty if none). + """ + return self._relationship_manager.get(component_id, "optional_credentials") + + @beartype + def refresh_credentials_status( + self, + statuses: Dict[str, bool], + only: Union[List[str], None] = None, + ) -> None: + """Recompute the ``credentials_satisfied`` flag of components. + + A component is satisfied when every credential in its + ``required_credentials`` is verified. Components with no required + credentials are always satisfied. + + Parameters + ---------- + statuses : Dict[str, bool] + Mapping of credential component name to verified status. + only : Union[List[str], None] + If provided, only these component names are recomputed. If None, + all components are recomputed. + """ + for type_registry in self._registry.values(): + for name, component_dict in type_registry.items(): + if only is not None and name not in only: + continue + required = component_dict.get("required_credentials", []) + component_dict["credentials_satisfied"] = all( + statuses.get(credential_name, False) for credential_name in required + ) diff --git a/DashAI/back/dependencies/registry/relationship_manager.py b/DashAI/back/dependencies/registry/relationship_manager.py index 03f21d8e6..d024c7428 100644 --- a/DashAI/back/dependencies/registry/relationship_manager.py +++ b/DashAI/back/dependencies/registry/relationship_manager.py @@ -8,37 +8,31 @@ logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +DEFAULT_RELATION_TYPE = "compatible_components" + class RelationshipManager: - """Class that implements a relationship registry between DashAI components. - - The registry is a pair of dicts (defaultdicts) that stores the relationships as a - dictionary where its keys are some class and its values a list of classes that are - related with the class. - - For example, a `_relation`that stores relations between - "TabularClassificationTask" and "SVM", "KNN" models and "CSVDataloader" loader - could be: - - ``` - { - "TabularClassificationTask": ["SVC", "KNN", "CSVDataloader", ...], - "SVC": ["TabularClassificationTask"], - "KNN": ["TabularClassificationTask"], - "CSVDataloader": ["TabularClassificationTask"], - } - ``` - Note that the relations are duplicated and hopefully, consistent between them. + """Registry of typed relationships between DashAI components. + Relations are stored as a nested mapping + ``{component_id: {relation_type: [related_component_id, ...]}}``. + Each relation is stored bidirectionally and scoped by ``relation_type`` + (for example ``"compatible_components"``, ``"required_credentials"`` or + ``"optional_credentials"``). """ def __init__(self) -> None: """Initialize the relationship manager.""" - self._relations: DefaultDict[str, List[str]] = defaultdict(list) + self._relations: DefaultDict[str, DefaultDict[str, List[str]]] = defaultdict( + lambda: defaultdict(list) + ) @property - def relations(self) -> Dict[str, List[str]]: - return dict(self._relations) + def relations(self) -> Dict[str, Dict[str, List[str]]]: + return { + component_id: dict(relations) + for component_id, relations in self._relations.items() + } @relations.setter def relations(self, _: Any) -> None: @@ -54,11 +48,12 @@ def relations(self, _: Any) -> None: @beartype def add_relationship( - self, first_component_id: str, second_component_id: str + self, + first_component_id: str, + second_component_id: str, + relation_type: str = DEFAULT_RELATION_TYPE, ) -> None: - """Add a new relation to the relationship manager. - - Note that the relation is bidirectional. + """Add a new bidirectional relation of the given type. Parameters ---------- @@ -66,16 +61,20 @@ def add_relationship( First component id or name. second_component_id : str Second component id or name. - + relation_type : str + The relation category, by default ``"compatible_components"``. """ - self._relations[first_component_id].append(second_component_id) - self._relations[second_component_id].append(first_component_id) + self._relations[first_component_id][relation_type].append(second_component_id) + self._relations[second_component_id][relation_type].append(first_component_id) @beartype def remove_relationship( - self, first_component_id: str, second_component_id: str + self, + first_component_id: str, + second_component_id: str, + relation_type: str = DEFAULT_RELATION_TYPE, ) -> None: - """Remove an existing relation to the relationship manager. + """Remove an existing relation of the given type. Parameters ---------- @@ -83,24 +82,26 @@ def remove_relationship( First component id or name. second_component_id : str Second component id or name. + relation_type : str + The relation category, by default ``"compatible_components"``. + Raises + ------ + ValueError + If the relation does not exist. """ try: - self._relations[first_component_id].remove(second_component_id) - except KeyError as e: + self._relations[first_component_id][relation_type].remove( + second_component_id + ) + self._relations[second_component_id][relation_type].remove( + first_component_id + ) + except ValueError as e: raise ValueError( - f"Error: Relationship between {first_component_id} and does " - f"not exist {second_component_id} in the registry. Exception: " - f"{e}" - ) from e - - try: - self._relations[second_component_id].remove(first_component_id) - except KeyError as e: - raise ValueError( - f"Error: Relationship between {second_component_id} and does " - f"not exist {first_component_id} in the registry. Exception: " - f"{e}" + f"Error: Relationship of type '{relation_type}' between " + f"{first_component_id} and {second_component_id} does not exist " + f"in the registry. Exception: {e}" ) from e logger.info( @@ -108,28 +109,47 @@ def remove_relationship( f"{first_component_id}, {second_component_id}" ) + @beartype + def get( + self, component_id: str, relation_type: str = DEFAULT_RELATION_TYPE + ) -> List[str]: + """Return the related component ids of a given type. + + Parameters + ---------- + component_id : str + A component name or id. + relation_type : str + The relation category, by default ``"compatible_components"``. + + Returns + ------- + list[str] + Related component ids, or an empty list if none exist. + """ + if component_id in self._relations: + return list(self._relations[component_id].get(relation_type, [])) + return [] + @beartype def __contains__(self, component_id: str) -> bool: - """Indicate if the relation manager contains a relationship. + """Indicate if the relation manager contains a component. Parameters ---------- component_id : str - The id of the component to be checked if a relationship exists or not. + The id of the component to check. Returns ------- bool - True if the relation exists, False otherwise. + True if the component has any relation, False otherwise. """ return component_id in self._relations @beartype - def __getitem__(self, component_id: str) -> List[str]: - """Obtain all stored relationships from a specific component. - - Return an empty list if the component id does not exists in the relationship - manager. + def __getitem__(self, component_id: str) -> Dict[str, List[str]]: + """Obtain all stored relationships for a component, grouped by type. Parameters ---------- @@ -138,10 +158,10 @@ def __getitem__(self, component_id: str) -> List[str]: Returns ------- - list[str] - A list with the related components. + dict[str, list[str]] + Mapping of relation type to related component ids. Empty dict if + the component is unknown. """ if component_id in self._relations: - return self._relations[component_id] - - return [] + return dict(self._relations[component_id]) + return {} diff --git a/DashAI/back/initial_components.py b/DashAI/back/initial_components.py index 49ea258f1..29e62d866 100644 --- a/DashAI/back/initial_components.py +++ b/DashAI/back/initial_components.py @@ -61,6 +61,10 @@ from DashAI.back.converters.simple_converters.column_remover import ColumnRemover from DashAI.back.converters.simple_converters.nan_remover import NanRemover +# Credentials +from DashAI.back.credentials.huggingface_credential import HuggingFaceCredential +from DashAI.back.credentials.kaggle_credential import KaggleCredential + # DataLoaders from DashAI.back.dataloaders.classes.arff_dataloader import ARFFDataLoader from DashAI.back.dataloaders.classes.csv_dataloader import CSVDataLoader @@ -402,6 +406,9 @@ def get_initial_components(): HuggingFaceDatasetSource, OpenMLDatasetSource, ZenodoDatasetSource, + # Credentials + HuggingFaceCredential, + KaggleCredential, # Metrics F1, Accuracy, diff --git a/DashAI/back/models/hugging_face/stable_diffusion_v3_model.py b/DashAI/back/models/hugging_face/stable_diffusion_v3_model.py index 34685fd51..e4ab2b8bd 100644 --- a/DashAI/back/models/hugging_face/stable_diffusion_v3_model.py +++ b/DashAI/back/models/hugging_face/stable_diffusion_v3_model.py @@ -18,8 +18,8 @@ class StableDiffusionSchema(BaseSchema): """Configuration schema for Stable Diffusion V3 text-to-image generation. - Configures the checkpoint variant (``model_name``), HuggingFace access key - (``huggingface_key``), prompt conditioning (``negative_prompt``), + Configures the checkpoint variant (``model_name``), + prompt conditioning (``negative_prompt``), denoising schedule (``num_inference_steps``), prompt adherence (``guidance_scale``), output dimensions (``width``, ``height``), reproducibility (``seed``), hardware target (``device``), and batch size @@ -87,52 +87,6 @@ class StableDiffusionSchema(BaseSchema): ), ) # type: ignore - huggingface_key: schema_field( - string_field(), - placeholder="", - description=MultilingualString( - en=( - "Hugging Face read-access token required to download these gated " - "models. To obtain one: accept the model license on " - "huggingface.co/stabilityai, then go to Settings → Access Tokens " - "and generate a token with 'Read' scope." - ), - es=( - "Token de acceso de lectura de Hugging Face necesario para descargar " - "estos modelos protegidos. Para obtenerlo: acepte la licencia del " - "modelo en huggingface.co/stabilityai, luego vaya a " - "Configuración → Tokens de Acceso y genere un token con alcance " - "'Read'." - ), - pt=( - "Token de acesso de leitura do Hugging Face necessário para baixar " - "esses modelos protegidos. Para obtê-lo: aceite a licença do " - "modelo em huggingface.co/stabilityai, depois vá em " - "Configurações → Tokens de Acesso e gere um token com escopo " - "'Read'." - ), - de=( - "Hugging Face Lesezugriffs-Token, der zum Herunterladen dieser " - "geschützten Modelle erforderlich ist. So erhalten Sie ihn: Akzeptieren" - "Sie die Modell-Lizenz auf huggingface.co/stabilityai, dann gehen Sie " - "zu Einstellungen → Zugriffstoken und generieren Sie einen Token " - "mit 'Read'-Umfang." - ), - zh=( - "下载受限模型所需的 Hugging Face 只读访问令牌。获取方式:在 " - "huggingface.co/stabilityai 接受模型许可证,然后进入" - "设置 → 访问令牌,生成具有 'Read' 权限的令牌。" - ), - ), - alias=MultilingualString( - en="Hugging Face key", - es="Clave Hugging Face", - pt="Chave Hugging Face", - de="Hugging Face-Schlüssel", - zh="Hugging Face 密钥", - ), - ) # type: ignore - negative_prompt: Optional[ schema_field( string_field(), @@ -477,6 +431,7 @@ class StableDiffusionV3Model(TextToImageGenerationTaskModel): """ SCHEMA = StableDiffusionSchema + REQUIRED_CREDENTIALS = ["HuggingFaceCredential"] COLOR: str = "#6a1b9a" DISPLAY_NAME: str = MultilingualString( en="Stable Diffusion V3", @@ -537,7 +492,6 @@ def __init__(self, **kwargs): import torch from diffusers import DiffusionPipeline - from huggingface_hub import login kwargs = self.validate_and_transform(kwargs) use_gpu = DEVICE_TO_IDX.get(kwargs.get("device")) >= 0 @@ -547,15 +501,9 @@ def __init__(self, **kwargs): self.model_name = kwargs.get( "model_name", "stabilityai/stable-diffusion-3-medium-diffusers" ) - self.huggingface_key = kwargs.get("huggingface_key") - - if self.huggingface_key: - try: - login(token=self.huggingface_key) - except Exception as e: - raise ValueError( - "Failed to login to Hugging Face. Please check your API key." - ) from e + # Log in to HuggingFace using the stored credential so the gated + # checkpoints can be downloaded. + self.get_credential("HuggingFaceCredential").apply() try: self.model = DiffusionPipeline.from_pretrained( diff --git a/DashAI/front/src/api/credentials.ts b/DashAI/front/src/api/credentials.ts new file mode 100644 index 000000000..f3e3a43ed --- /dev/null +++ b/DashAI/front/src/api/credentials.ts @@ -0,0 +1,27 @@ +import api from "./api"; +import type { ICredential } from "../types/credential"; + +export const getCredentials = async (): Promise => { + const response = await api.get("/v1/credential/"); + return response.data; +}; + +export const authenticateCredential = async ( + name: string, + key: string, +): Promise<{ is_authenticated: boolean }> => { + const response = await api.post<{ is_authenticated: boolean }>( + `/v1/credential/${name}/auth`, + { key }, + ); + return response.data; +}; + +export const deleteCredential = async ( + name: string, +): Promise<{ is_authenticated: boolean }> => { + const response = await api.delete<{ is_authenticated: boolean }>( + `/v1/credential/${name}`, + ); + return response.data; +}; diff --git a/DashAI/front/src/components/ResponsiveAppBar.jsx b/DashAI/front/src/components/ResponsiveAppBar.jsx index 47a16d296..10c06ec2d 100644 --- a/DashAI/front/src/components/ResponsiveAppBar.jsx +++ b/DashAI/front/src/components/ResponsiveAppBar.jsx @@ -18,6 +18,7 @@ import LightModeOutlinedIcon from "@mui/icons-material/LightModeOutlined"; import Tooltip from "@mui/material/Tooltip"; import HardwareMonitorButton from "./hardware/HardwareMonitorButton"; import NavbarTourButton from "./tour/NavbarTourButton"; +import CredentialsButton from "./credentials/CredentialsButton"; function ResponsiveAppBar() { const theme = useTheme(); @@ -220,6 +221,7 @@ function ResponsiveAppBar() { + () => null); + describe("ResponsiveAppBar", () => { it("renders without crashing", () => { renderWithProviders(); diff --git a/DashAI/front/src/components/credentials/CredentialsButton.jsx b/DashAI/front/src/components/credentials/CredentialsButton.jsx new file mode 100644 index 000000000..048116863 --- /dev/null +++ b/DashAI/front/src/components/credentials/CredentialsButton.jsx @@ -0,0 +1,40 @@ +import React, { useState } from "react"; +import IconButton from "@mui/material/IconButton"; +import Tooltip from "@mui/material/Tooltip"; +import VpnKeyOutlinedIcon from "@mui/icons-material/VpnKeyOutlined"; +import { useTheme } from "@mui/material/styles"; +import { useTranslation } from "react-i18next"; +import CredentialsDialog from "./CredentialsDialog"; + +export default function CredentialsButton() { + const theme = useTheme(); + const { t } = useTranslation("credentials"); + const [open, setOpen] = useState(false); + + const iconBtnSx = { + width: 32, + height: 32, + borderRadius: "4px", + border: `1px solid ${theme.palette.divider}`, + color: theme.palette.text.secondary, + "&:hover": { + background: theme.palette.ui.hover, + color: theme.palette.text.primary, + }, + }; + + return ( + <> + + setOpen(true)} + aria-label="credentials" + sx={iconBtnSx} + > + + + + setOpen(false)} /> + + ); +} diff --git a/DashAI/front/src/components/credentials/CredentialsDialog.jsx b/DashAI/front/src/components/credentials/CredentialsDialog.jsx new file mode 100644 index 000000000..e1b864236 --- /dev/null +++ b/DashAI/front/src/components/credentials/CredentialsDialog.jsx @@ -0,0 +1,266 @@ +import React, { useEffect, useState } from "react"; +import PropTypes from "prop-types"; +import { + Dialog, + DialogTitle, + DialogContent, + Stack, + TextField, + Button, + Typography, + Box, + IconButton, + InputAdornment, + Tooltip, +} from "@mui/material"; +import VisibilityOutlinedIcon from "@mui/icons-material/VisibilityOutlined"; +import VisibilityOffOutlinedIcon from "@mui/icons-material/VisibilityOffOutlined"; +import DeleteOutlineIcon from "@mui/icons-material/DeleteOutline"; +import CloseIcon from "@mui/icons-material/Close"; +import VpnKeyOutlinedIcon from "@mui/icons-material/VpnKeyOutlined"; +import { useTranslation } from "react-i18next"; +import { useSnackbar } from "notistack"; +import { + getCredentials, + authenticateCredential, + deleteCredential, +} from "../../api/credentials"; + +function CredentialRow({ credential, onChanged }) { + const { t } = useTranslation("credentials"); + const { enqueueSnackbar } = useSnackbar(); + const [key, setKey] = useState(credential.key ?? ""); + const [showKey, setShowKey] = useState(false); + const [busy, setBusy] = useState(false); + + const authed = credential.is_authenticated; + + const handleVerify = async () => { + setBusy(true); + try { + await authenticateCredential(credential.name, key); + enqueueSnackbar(t("verifySuccess"), { variant: "success" }); + onChanged(); + } catch { + enqueueSnackbar(t("verifyError"), { variant: "error" }); + } finally { + setBusy(false); + } + }; + + const handleRemove = async () => { + setBusy(true); + try { + await deleteCredential(credential.name); + setKey(""); + onChanged(); + } finally { + setBusy(false); + } + }; + + return ( + ({ + border: `1px solid ${theme.palette.divider}`, + borderRadius: 2, + p: 2, + transition: "border-color 0.15s, background 0.15s", + "&:hover": { borderColor: theme.palette.text.disabled }, + })} + > + {/* Identity + status */} + + + + {credential.display_name} + + {credential.description && ( + + {credential.description} + + )} + + + ({ + width: 8, + height: 8, + borderRadius: "50%", + backgroundColor: authed + ? theme.palette.success.main + : theme.palette.text.disabled, + })} + /> + ({ + color: authed + ? theme.palette.success.main + : theme.palette.text.secondary, + fontWeight: 500, + })} + > + {authed ? t("authenticated") : t("notAuthenticated")} + + + + + {/* Key input + actions */} + + setKey(e.target.value)} + sx={{ "& input": { fontFamily: "monospace", fontSize: 13 } }} + InputProps={{ + endAdornment: ( + + setShowKey((prev) => !prev)} + edge="end" + > + {showKey ? ( + + ) : ( + + )} + + + ), + }} + /> + + {authed && ( + + + + + + + + )} + + + ); +} + +CredentialRow.propTypes = { + credential: PropTypes.object.isRequired, + onChanged: PropTypes.func.isRequired, +}; + +export default function CredentialsDialog({ open, onClose }) { + const { t } = useTranslation("credentials"); + const [credentials, setCredentials] = useState([]); + + const refresh = async () => { + try { + const data = await getCredentials(); + setCredentials(Array.isArray(data) ? data : []); + } catch { + // silently ignore fetch errors (e.g. in test environments) + } + }; + + useEffect(() => { + if (open) { + refresh(); + } + }, [open]); + + return ( + + + + ({ + width: 36, + height: 36, + borderRadius: 1.5, + flexShrink: 0, + display: "flex", + alignItems: "center", + justifyContent: "center", + color: theme.palette.primary.main, + backgroundColor: theme.palette.ui.hover, + })} + > + + + + + {t("title")} + + + {t("subtitle")} + + + + + + + + + + {credentials.map((credential) => ( + + ))} + + + + ); +} + +CredentialsDialog.propTypes = { + open: PropTypes.bool.isRequired, + onClose: PropTypes.func.isRequired, +}; diff --git a/DashAI/front/src/components/custom/ComponentSelector.jsx b/DashAI/front/src/components/custom/ComponentSelector.jsx index 17b088f2d..dd64b2370 100644 --- a/DashAI/front/src/components/custom/ComponentSelector.jsx +++ b/DashAI/front/src/components/custom/ComponentSelector.jsx @@ -10,12 +10,15 @@ import { Typography, Collapse, Paper, + Tooltip, } from "@mui/material"; import { Search as SearchIcon, Clear as ClearIcon, ExpandMore as ExpandMoreIcon, Check as CheckIcon, + LockOutlined as LockIcon, + VpnKeyOutlined as KeyIcon, } from "@mui/icons-material"; import { useTranslation } from "react-i18next"; @@ -30,6 +33,13 @@ function getDescription(component, fallback = "") { return component.description ?? fallback; } +// Derive a human label from a credential component name, e.g. +// "HuggingFaceCredential" -> "HuggingFace". Falls back to the raw name so it +// works for any backend-registered credential without frontend changes. +function credentialLabel(name) { + return name.replace(/Credential$/, "") || name; +} + function ComponentSelector({ components, selected = null, @@ -42,7 +52,7 @@ function ComponentSelector({ tourDataFor = null, tourDataMatchFn = null, }) { - const { t } = useTranslation("custom"); + const { t } = useTranslation(["custom", "credentials"]); const [search, setSearch] = useState(""); const [activeCategory, setActiveCategory] = useState(ALL_CATEGORY); @@ -107,6 +117,17 @@ function ComponentSelector({ const renderCard = (component) => { const isSelected = selected?.name === component.name; const icon = getIcon?.(component); + const requiredCredentials = component.required_credentials ?? []; + const optionalCredentials = component.optional_credentials ?? []; + const credentialsSatisfied = component.credentials_satisfied !== false; + // Unmet required credentials make the component unusable: lock and disable. + const locked = !credentialsSatisfied && requiredCredentials.length > 0; + const requiredPlatforms = requiredCredentials + .map(credentialLabel) + .join(", "); + const optionalPlatforms = optionalCredentials + .map(credentialLabel) + .join(", "); const isCsvComponent = tourDataFor && (tourDataMatchFn @@ -116,19 +137,23 @@ function ComponentSelector({ handleSelect(component)} + onClick={() => { + if (!locked) handleSelect(component); + }} + aria-disabled={locked} data-tour={isCsvComponent ? tourDataFor : undefined} sx={{ p: 3, display: "flex", gap: 3, alignItems: "flex-start", - cursor: "pointer", + cursor: locked ? "not-allowed" : "pointer", + opacity: locked ? 0.55 : 1, border: 1, borderColor: isSelected ? "primary.main" : "divider", bgcolor: isSelected ? "action.selected" : "background.paper", transition: "border-color 0.15s, background 0.15s", - "&:hover": { borderColor: "primary.light" }, + "&:hover": locked ? undefined : { borderColor: "primary.light" }, }} > {icon && ( @@ -165,13 +190,32 @@ function ComponentSelector({ {getDescription(component, t("noDescriptionAvailable"))} - {isSelected && ( - - )} + + {locked && ( + + + + )} + {!locked && optionalCredentials.length > 0 && ( + + + + )} + {isSelected && } + ); }; diff --git a/DashAI/front/src/hooks/useComponentAvailability.js b/DashAI/front/src/hooks/useComponentAvailability.js new file mode 100644 index 000000000..50a31c7b2 --- /dev/null +++ b/DashAI/front/src/hooks/useComponentAvailability.js @@ -0,0 +1,21 @@ +/** + * Compute availability of a component from its credential requirements. + * + * @param {object} component - Component dict with credential fields. + * @param {Object} credentialStatuses - name -> authenticated. + * @returns {{available: boolean, missingRequired: string[], missingOptional: string[]}} + */ +export function getComponentAvailability(component, credentialStatuses = {}) { + const required = component?.required_credentials ?? []; + const optional = component?.optional_credentials ?? []; + + const missingRequired = required.filter((name) => !credentialStatuses[name]); + const missingOptional = optional.filter((name) => !credentialStatuses[name]); + + const available = + typeof component?.credentials_satisfied === "boolean" + ? component.credentials_satisfied + : missingRequired.length === 0; + + return { available, missingRequired, missingOptional }; +} diff --git a/DashAI/front/src/types/component.ts b/DashAI/front/src/types/component.ts index 9bedf444b..231e2d6b2 100644 --- a/DashAI/front/src/types/component.ts +++ b/DashAI/front/src/types/component.ts @@ -10,4 +10,7 @@ export interface IComponent { description: string; display_name?: string; color?: string; + required_credentials?: string[]; + optional_credentials?: string[]; + credentials_satisfied?: boolean; } diff --git a/DashAI/front/src/types/credential.ts b/DashAI/front/src/types/credential.ts new file mode 100644 index 000000000..9426156f6 --- /dev/null +++ b/DashAI/front/src/types/credential.ts @@ -0,0 +1,7 @@ +export interface ICredential { + name: string; + display_name: string; + description: string; + is_authenticated: boolean; + key: string | null; +} diff --git a/DashAI/front/src/utils/i18n/index.js b/DashAI/front/src/utils/i18n/index.js index 2e89068eb..493f4fdbe 100644 --- a/DashAI/front/src/utils/i18n/index.js +++ b/DashAI/front/src/utils/i18n/index.js @@ -37,6 +37,11 @@ import generativeTourEN from "./locales/en/generativeTour.json"; import generativeTourES from "./locales/es/generativeTour.json"; import hubEN from "./locales/en/hub.json"; import hubES from "./locales/es/hub.json"; +import credentialsEN from "./locales/en/credentials.json"; +import credentialsES from "./locales/es/credentials.json"; +import credentialsPT from "./locales/pt/credentials.json"; +import credentialsDE from "./locales/de/credentials.json"; +import credentialsZH from "./locales/zh/credentials.json"; import configurableObjectPT from "./locales/pt/configurableObject.json"; import commonPT from "./locales/pt/common.json"; import customPT from "./locales/pt/custom.json"; @@ -113,6 +118,7 @@ const resources = { modelsSessionTour: modelsSessionTourEN, generativeTour: generativeTourEN, hub: hubEN, + credentials: credentialsEN, }, es: { configurableObject: configurableObjectES, @@ -133,6 +139,7 @@ const resources = { modelsSessionTour: modelsSessionTourES, generativeTour: generativeTourES, hub: hubES, + credentials: credentialsES, }, pt: { configurableObject: configurableObjectPT, @@ -152,6 +159,7 @@ const resources = { modelsTour: modelsTourPT, modelsSessionTour: modelsSessionTourPT, generativeTour: generativeTourPT, + credentials: credentialsPT, }, de: { configurableObject: configurableObjectDE, @@ -171,6 +179,7 @@ const resources = { modelsTour: modelsTourDE, modelsSessionTour: modelsSessionTourDE, generativeTour: generativeTourDE, + credentials: credentialsDE, }, zh: { configurableObject: configurableObjectZH, @@ -191,6 +200,7 @@ const resources = { modelsSessionTour: modelsSessionTourZH, generativeTour: generativeTourZH, hub: hubZH, + credentials: credentialsZH, }, }; @@ -222,6 +232,7 @@ i18n "plugins", "generativeTour", "hub", + "credentials", ], defaultNS: "common", diff --git a/DashAI/front/src/utils/i18n/locales/de/credentials.json b/DashAI/front/src/utils/i18n/locales/de/credentials.json new file mode 100644 index 000000000..f1c54f253 --- /dev/null +++ b/DashAI/front/src/utils/i18n/locales/de/credentials.json @@ -0,0 +1,13 @@ +{ + "title": "Anmeldedaten", + "subtitle": "Bei externen Plattformen authentifizieren", + "keyPlaceholder": "API-Schluessel eingeben", + "verify": "Pruefen", + "remove": "Entfernen", + "authenticated": "Authentifiziert", + "notAuthenticated": "Nicht authentifiziert", + "verifySuccess": "Anmeldedaten geprueft", + "verifyError": "Ungueltiger Schluessel", + "requiredTooltip": "Bei {{platform}} authentifizieren, um dies zu nutzen", + "optionalTooltip": "Bei {{platform}} authentifizieren fuer zusaetzlichen Zugriff" +} diff --git a/DashAI/front/src/utils/i18n/locales/en/credentials.json b/DashAI/front/src/utils/i18n/locales/en/credentials.json new file mode 100644 index 000000000..fb120e96e --- /dev/null +++ b/DashAI/front/src/utils/i18n/locales/en/credentials.json @@ -0,0 +1,13 @@ +{ + "title": "Credentials", + "subtitle": "Authenticate to external platforms", + "keyPlaceholder": "Enter your API key", + "verify": "Verify", + "remove": "Remove", + "authenticated": "Authenticated", + "notAuthenticated": "Not authenticated", + "verifySuccess": "Credential verified", + "verifyError": "Invalid key", + "requiredTooltip": "Authenticate {{platform}} to use this", + "optionalTooltip": "Authenticate {{platform}} for extra access" +} diff --git a/DashAI/front/src/utils/i18n/locales/es/credentials.json b/DashAI/front/src/utils/i18n/locales/es/credentials.json new file mode 100644 index 000000000..f7c5a8116 --- /dev/null +++ b/DashAI/front/src/utils/i18n/locales/es/credentials.json @@ -0,0 +1,13 @@ +{ + "title": "Credenciales", + "subtitle": "Autenticarse en plataformas externas", + "keyPlaceholder": "Ingrese su clave de API", + "verify": "Verificar", + "remove": "Eliminar", + "authenticated": "Autenticado", + "notAuthenticated": "No autenticado", + "verifySuccess": "Credencial verificada", + "verifyError": "Clave invalida", + "requiredTooltip": "Autentiquese en {{platform}} para usar esto", + "optionalTooltip": "Autentiquese en {{platform}} para acceso adicional" +} diff --git a/DashAI/front/src/utils/i18n/locales/pt/credentials.json b/DashAI/front/src/utils/i18n/locales/pt/credentials.json new file mode 100644 index 000000000..cbb0d442a --- /dev/null +++ b/DashAI/front/src/utils/i18n/locales/pt/credentials.json @@ -0,0 +1,13 @@ +{ + "title": "Credenciais", + "subtitle": "Autenticar em plataformas externas", + "keyPlaceholder": "Insira sua chave de API", + "verify": "Verificar", + "remove": "Remover", + "authenticated": "Autenticado", + "notAuthenticated": "Nao autenticado", + "verifySuccess": "Credencial verificada", + "verifyError": "Chave invalida", + "requiredTooltip": "Autentique-se no {{platform}} para usar isto", + "optionalTooltip": "Autentique-se no {{platform}} para acesso adicional" +} diff --git a/DashAI/front/src/utils/i18n/locales/zh/credentials.json b/DashAI/front/src/utils/i18n/locales/zh/credentials.json new file mode 100644 index 000000000..7a7f3d7f4 --- /dev/null +++ b/DashAI/front/src/utils/i18n/locales/zh/credentials.json @@ -0,0 +1,13 @@ +{ + "title": "凭证", + "subtitle": "对外部平台进行身份验证", + "keyPlaceholder": "输入您的 API 密钥", + "verify": "验证", + "remove": "移除", + "authenticated": "已验证", + "notAuthenticated": "未验证", + "verifySuccess": "凭证已验证", + "verifyError": "无效的密钥", + "requiredTooltip": "验证 {{platform}} 以使用此功能", + "optionalTooltip": "验证 {{platform}} 以获得额外访问权限" +} diff --git a/requirements.txt b/requirements.txt index 2fc5b6b99..9a3cfe627 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,3 +49,5 @@ torchmetrics pywebview openml oslo.concurrency +cryptography +kaggle diff --git a/tests/back/api/test_components_api.py b/tests/back/api/test_components_api.py index 26cd14699..ddaca3b18 100644 --- a/tests/back/api/test_components_api.py +++ b/tests/back/api/test_components_api.py @@ -164,6 +164,9 @@ def test_get_component_by_id(client: TestClient): "description": "Task 1.", "display_name": "Test Task 1", "color": "#795548", + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } response = client.get("/api/v1/component/TestTask2/") @@ -182,6 +185,9 @@ def test_get_component_by_id(client: TestClient): "description": "Task 2.", "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } response = client.get("/api/v1/component/TestDataloader1/") @@ -199,6 +205,9 @@ def test_get_component_by_id(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } @@ -292,6 +301,9 @@ def test_get_components_select_only_tasks(client: TestClient): "description": "Task 1.", "display_name": "Test Task 1", "color": "#795548", + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestTask2", @@ -307,6 +319,9 @@ def test_get_components_select_only_tasks(client: TestClient): "description": "Task 2.", "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -330,6 +345,9 @@ def test_get_components_select_only_dataloaders(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -344,6 +362,9 @@ def test_get_components_select_only_dataloaders(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader3", @@ -358,6 +379,9 @@ def test_get_components_select_only_dataloaders(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -438,6 +462,9 @@ def test_get_components_ignore_models(client: TestClient): "description": "Task 1.", "display_name": "Test Task 1", "color": "#795548", + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestTask2", @@ -453,6 +480,9 @@ def test_get_components_ignore_models(client: TestClient): "description": "Task 2.", "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader1", @@ -467,6 +497,9 @@ def test_get_components_ignore_models(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -481,6 +514,9 @@ def test_get_components_ignore_models(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader3", @@ -495,6 +531,9 @@ def test_get_components_ignore_models(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -517,6 +556,9 @@ def test_get_components_ignore_tasks_and_models(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -531,6 +573,9 @@ def test_get_components_ignore_tasks_and_models(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader3", @@ -545,6 +590,9 @@ def test_get_components_ignore_tasks_and_models(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -608,6 +656,9 @@ def test_get_components_related_inverse_relation(client: TestClient): "description": "Task 1.", "display_name": "Test Task 1", "color": "#795548", + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } ] @@ -653,6 +704,9 @@ def test_get_components_dataloader_component_parent(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -667,6 +721,9 @@ def test_get_components_dataloader_component_parent(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -706,6 +763,9 @@ def test_get_components_by_type_and_task(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -720,6 +780,9 @@ def test_get_components_by_type_and_task(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -758,6 +821,9 @@ def test_get_components_select_and_ignore_by_type(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -772,6 +838,9 @@ def test_get_components_select_and_ignore_by_type(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader3", @@ -786,6 +855,9 @@ def test_get_components_select_and_ignore_by_type(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] @@ -811,6 +883,9 @@ def test_get_components_select_type_and_parent(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, { "name": "TestDataloader2", @@ -825,5 +900,8 @@ def test_get_components_select_type_and_parent(client: TestClient): "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, }, ] diff --git a/tests/back/api/test_credentials_api.py b/tests/back/api/test_credentials_api.py new file mode 100644 index 000000000..bc318eb5f --- /dev/null +++ b/tests/back/api/test_credentials_api.py @@ -0,0 +1,74 @@ +from unittest.mock import patch + + +def test_list_credentials(client): + response = client.get("/api/v1/credential/") + assert response.status_code == 200 + names = {c["name"] for c in response.json()} + assert "HuggingFaceCredential" in names + for cred in response.json(): + assert "is_authenticated" in cred + # catalog + status returned together in one request + assert "display_name" in cred + assert "description" in cred + assert "key" in cred + + +def test_auth_success_marks_authenticated(client): + with patch( + "DashAI.back.credentials.huggingface_credential.HuggingFaceCredential.verify", + return_value=True, + ): + response = client.post( + "/api/v1/credential/HuggingFaceCredential/auth", + json={"key": "hf_token"}, + ) + assert response.status_code == 200 + assert response.json()["is_authenticated"] is True + + status = client.get("/api/v1/credential/HuggingFaceCredential") + assert status.json()["is_authenticated"] is True + + +def test_auth_invalid_key_returns_400(client): + with patch( + "DashAI.back.credentials.huggingface_credential.HuggingFaceCredential.verify", + return_value=False, + ): + response = client.post( + "/api/v1/credential/HuggingFaceCredential/auth", + json={"key": "bad-secret-key"}, + ) + assert response.status_code == 400 + # the error must never echo the submitted key + assert "bad-secret-key" not in response.text + + +def test_auth_unknown_credential_returns_404(client): + response = client.post("/api/v1/credential/NotACredential/auth", json={"key": "x"}) + assert response.status_code == 404 + + +def test_get_unknown_credential_returns_404(client): + response = client.get("/api/v1/credential/NotACredential") + assert response.status_code == 404 + + +def test_delete_unknown_credential_returns_404(client): + response = client.delete("/api/v1/credential/NotACredential") + assert response.status_code == 404 + + +def test_delete_credential(client): + with patch( + "DashAI.back.credentials.huggingface_credential.HuggingFaceCredential.verify", + return_value=True, + ): + client.post( + "/api/v1/credential/HuggingFaceCredential/auth", + json={"key": "hf_token"}, + ) + response = client.delete("/api/v1/credential/HuggingFaceCredential") + assert response.status_code == 200 + status = client.get("/api/v1/credential/HuggingFaceCredential") + assert status.json()["is_authenticated"] is False diff --git a/tests/back/credentials/__init__.py b/tests/back/credentials/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/back/credentials/test_base_credential.py b/tests/back/credentials/test_base_credential.py new file mode 100644 index 000000000..6220a3a0e --- /dev/null +++ b/tests/back/credentials/test_base_credential.py @@ -0,0 +1,55 @@ +import pytest +from cryptography.fernet import Fernet +from kink import di +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from DashAI.back.credentials.base_credential import BaseCredential +from DashAI.back.credentials.encryptor import CredentialEncryptor +from DashAI.back.credentials.store import CredentialStore +from DashAI.back.dependencies.database.models import Base + + +class FakeCredential(BaseCredential): + DISPLAY_NAME = "Fake" + DESCRIPTION = "Fake credential for tests" + last_seen_key = None + + def verify(self, key: str) -> bool: + return key == "good" + + +@pytest.fixture(autouse=True) +def credential_store_in_di(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + session_factory = sessionmaker(bind=engine) + encryptor = CredentialEncryptor(Fernet.generate_key()) + di["credential_store"] = CredentialStore(session_factory, encryptor) + yield + del di["credential_store"] + + +def test_auth_stores_valid_key(): + cred = FakeCredential() + assert cred.auth("good") is True + assert cred.get_key() == "good" + assert cred.is_authenticated() is True + + +def test_auth_rejects_invalid_key(): + cred = FakeCredential() + with pytest.raises(ValueError, match="Invalid credential"): + cred.auth("bad") + assert cred.get_key() is None + assert cred.is_authenticated() is False + + +def test_apply_is_noop_without_key(): + cred = FakeCredential() + # should not raise even though nothing is stored + cred.apply() + + +def test_type_is_credential(): + assert FakeCredential.TYPE == "Credential" diff --git a/tests/back/credentials/test_concrete_credentials.py b/tests/back/credentials/test_concrete_credentials.py new file mode 100644 index 000000000..fd193fc7d --- /dev/null +++ b/tests/back/credentials/test_concrete_credentials.py @@ -0,0 +1,83 @@ +import sys +import types +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +from DashAI.back.credentials.github_credential import GithubCredential +from DashAI.back.credentials.huggingface_credential import HuggingFaceCredential +from DashAI.back.credentials.kaggle_credential import KaggleCredential + + +@contextmanager +def fake_kaggle(api_instance): + """Install a stub ``kaggle`` package so the real one is never imported. + + The official ``kaggle`` package authenticates at import time and exits the + process without credentials, so tests inject a fake module tree exposing a + ``KaggleApi`` that returns ``api_instance``. + """ + module_names = ("kaggle", "kaggle.api", "kaggle.api.kaggle_api_extended") + saved = {name: sys.modules.get(name) for name in module_names} + sys.modules["kaggle"] = types.ModuleType("kaggle") + sys.modules["kaggle.api"] = types.ModuleType("kaggle.api") + extended = types.ModuleType("kaggle.api.kaggle_api_extended") + extended.KaggleApi = MagicMock(return_value=api_instance) + sys.modules["kaggle.api.kaggle_api_extended"] = extended + try: + yield + finally: + for name, module in saved.items(): + if module is None: + sys.modules.pop(name, None) + else: + sys.modules[name] = module + + +def test_huggingface_verify_success(): + cred = HuggingFaceCredential() + with patch("huggingface_hub.HfApi") as hf_api: + hf_api.return_value.whoami.return_value = {"name": "user"} + assert cred.verify("hf_good") is True + + +def test_huggingface_verify_failure(): + cred = HuggingFaceCredential() + with patch("huggingface_hub.HfApi") as hf_api: + hf_api.return_value.whoami.side_effect = Exception("401") + assert cred.verify("hf_bad") is False + + +def test_github_verify_success(): + cred = GithubCredential() + with patch("requests.get") as get: + get.return_value = MagicMock(status_code=200) + assert cred.verify("ghp_good") is True + + +def test_github_verify_failure(): + cred = GithubCredential() + with patch("requests.get") as get: + get.return_value = MagicMock(status_code=401) + assert cred.verify("ghp_bad") is False + + +def test_kaggle_verify_success(): + cred = KaggleCredential() + api = MagicMock() + api.authenticate.return_value = None + api.competitions_list.return_value = [] + with fake_kaggle(api): + assert cred.verify("user:key") is True + + +def test_kaggle_verify_failure(): + cred = KaggleCredential() + api = MagicMock() + api.competitions_list.side_effect = Exception("401") + with fake_kaggle(api): + assert cred.verify("user:badkey") is False + + +def test_kaggle_verify_malformed_key(): + cred = KaggleCredential() + assert cred.verify("no-separator") is False diff --git a/tests/back/credentials/test_encryptor.py b/tests/back/credentials/test_encryptor.py new file mode 100644 index 000000000..686762465 --- /dev/null +++ b/tests/back/credentials/test_encryptor.py @@ -0,0 +1,30 @@ +from pathlib import Path + +from DashAI.back.credentials.encryptor import ( + CredentialEncryptor, + load_or_create_key, +) + + +def test_encrypt_decrypt_roundtrip(): + key = load_or_create_key(Path("/nonexistent/path"), env_value=None, persist=False) + enc = CredentialEncryptor(key) + token = enc.encrypt("hf_secret_token") + assert token != "hf_secret_token" + assert enc.decrypt(token) == "hf_secret_token" + + +def test_load_or_create_key_uses_env_value(tmp_path): + from cryptography.fernet import Fernet + + env_key = Fernet.generate_key().decode() + key = load_or_create_key(tmp_path / "key", env_value=env_key) + assert key == env_key.encode() + + +def test_load_or_create_key_persists_and_reuses(tmp_path): + key_path = tmp_path / ".credentials_key" + first = load_or_create_key(key_path, env_value=None) + assert key_path.exists() + second = load_or_create_key(key_path, env_value=None) + assert first == second diff --git a/tests/back/credentials/test_get_credential.py b/tests/back/credentials/test_get_credential.py new file mode 100644 index 000000000..f08b7105d --- /dev/null +++ b/tests/back/credentials/test_get_credential.py @@ -0,0 +1,32 @@ +from kink import di + +from DashAI.back.config_object import ConfigObject +from DashAI.back.credentials.base_credential import BaseCredential +from DashAI.back.dependencies.registry import ComponentRegistry + + +class DummyCredential(BaseCredential): + DISPLAY_NAME = "Dummy" + DESCRIPTION = "dummy" + + def verify(self, key: str) -> bool: + return True + + +class DummyComponentBase: + TYPE = "DummyType" + + +class DummyComponent(ConfigObject, DummyComponentBase): + REQUIRED_CREDENTIALS = ["DummyCredential"] + + +def test_get_credential_returns_instance(): + registry = ComponentRegistry(initial_components=[DummyCredential]) + di["component_registry"] = registry + try: + component = DummyComponent() + cred = component.get_credential("DummyCredential") + assert isinstance(cred, DummyCredential) + finally: + del di["component_registry"] diff --git a/tests/back/credentials/test_store.py b/tests/back/credentials/test_store.py new file mode 100644 index 000000000..837ef4b56 --- /dev/null +++ b/tests/back/credentials/test_store.py @@ -0,0 +1,58 @@ +import pytest +from cryptography.fernet import Fernet +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from DashAI.back.credentials.encryptor import CredentialEncryptor +from DashAI.back.credentials.store import CredentialStore +from DashAI.back.dependencies.database.models import Base + + +@pytest.fixture +def session_factory(): + engine = create_engine("sqlite:///:memory:") + Base.metadata.create_all(engine) + return sessionmaker(bind=engine) + + +@pytest.fixture +def store(session_factory): + encryptor = CredentialEncryptor(Fernet.generate_key()) + return CredentialStore(session_factory, encryptor) + + +def test_save_and_load_roundtrip(store): + store.save("HuggingFaceCredential", "hf_token_123") + assert store.load("HuggingFaceCredential") == "hf_token_123" + + +def test_save_marks_verified(store): + store.save("HuggingFaceCredential", "hf_token_123") + assert store.is_verified("HuggingFaceCredential") is True + + +def test_load_missing_returns_none(store): + assert store.load("Missing") is None + assert store.is_verified("Missing") is False + + +def test_save_is_upsert(store): + store.save("HuggingFaceCredential", "old") + store.save("HuggingFaceCredential", "new") + assert store.load("HuggingFaceCredential") == "new" + + +def test_delete_removes_key(store): + store.save("HuggingFaceCredential", "tok") + store.delete("HuggingFaceCredential") + assert store.load("HuggingFaceCredential") is None + assert store.is_verified("HuggingFaceCredential") is False + + +def test_all_statuses(store): + store.save("HuggingFaceCredential", "tok") + store.save("GithubCredential", "ghp") + assert store.all_statuses() == { + "HuggingFaceCredential": True, + "GithubCredential": True, + } diff --git a/tests/back/registries/test_registry.py b/tests/back/registries/test_registry.py index c15b17238..4a5f374fa 100644 --- a/tests/back/registries/test_registry.py +++ b/tests/back/registries/test_registry.py @@ -80,6 +80,9 @@ class NoComponent: ... "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } COMPONENT2_DICT = { "name": "Component2", @@ -91,6 +94,9 @@ class NoComponent: ... "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } SUBCOMPONENT1_DICT = { "name": "SubComponent1", @@ -102,6 +108,9 @@ class NoComponent: ... "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } COMPONENT3_DICT = { "name": "Component3", @@ -113,6 +122,9 @@ class NoComponent: ... "description": "Some static component", "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } COMPONENT3_DICT_MS = COMPONENT3_DICT.copy() COMPONENT3_DICT_MS["description"] = MultilingualString(en="Some static component") @@ -127,6 +139,9 @@ class NoComponent: ... "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } RELATED_COMPONENT2_DICT = { "name": "RelatedComponent2", @@ -138,6 +153,9 @@ class NoComponent: ... "description": None, "display_name": None, "color": None, + "required_credentials": [], + "optional_credentials": [], + "credentials_satisfied": True, } @@ -457,10 +475,12 @@ def test_relationships_module(): test_registry.register_component(RelatedComponent2) assert test_registry._relationship_manager.relations == { - "RelatedComponent1": ["Component1"], - "Component1": ["RelatedComponent1", "RelatedComponent2"], - "RelatedComponent2": ["Component1", "Component2"], - "Component2": ["RelatedComponent2"], + "RelatedComponent1": {"compatible_components": ["Component1"]}, + "Component1": { + "compatible_components": ["RelatedComponent1", "RelatedComponent2"] + }, + "RelatedComponent2": {"compatible_components": ["Component1", "Component2"]}, + "Component2": {"compatible_components": ["RelatedComponent2"]}, } assert test_registry.get_related_components("Component1") == [ @@ -482,3 +502,59 @@ def test_relationships_module(): COMPONENT1_DICT, COMPONENT2_DICT, ] + + +class CredentialComponentA: + TYPE = "Credential" + + +class ComponentNeedsCred(BaseStaticComponent): + REQUIRED_CREDENTIALS = ["CredentialComponentA"] + + +class ComponentOptionalCred(BaseStaticComponent): + OPTIONAL_CREDENTIALS = ["CredentialComponentA"] + + +def test_register_records_credential_relations_and_flag(): + reg = ComponentRegistry( + initial_components=[ComponentNeedsCred, ComponentOptionalCred] + ) + + # required credential present -> not satisfied until verified + assert reg["ComponentNeedsCred"]["required_credentials"] == ["CredentialComponentA"] + assert reg["ComponentNeedsCred"]["optional_credentials"] == [] + assert reg["ComponentNeedsCred"]["credentials_satisfied"] is False + + # only optional credential -> always satisfied + assert reg["ComponentOptionalCred"]["optional_credentials"] == [ + "CredentialComponentA" + ] + assert reg["ComponentOptionalCred"]["credentials_satisfied"] is True + + assert reg.get_required_credentials("ComponentNeedsCred") == [ + "CredentialComponentA" + ] + assert reg.get_optional_credentials("ComponentOptionalCred") == [ + "CredentialComponentA" + ] + + +def test_refresh_credentials_status_updates_flag(): + reg = ComponentRegistry(initial_components=[ComponentNeedsCred]) + assert reg["ComponentNeedsCred"]["credentials_satisfied"] is False + + reg.refresh_credentials_status({"CredentialComponentA": True}) + assert reg["ComponentNeedsCred"]["credentials_satisfied"] is True + + reg.refresh_credentials_status({"CredentialComponentA": False}) + assert reg["ComponentNeedsCred"]["credentials_satisfied"] is False + + +def test_refresh_credentials_status_only_targets_subset(): + reg = ComponentRegistry(initial_components=[ComponentNeedsCred]) + reg.refresh_credentials_status( + {"CredentialComponentA": True}, only=["SomeOtherComponent"] + ) + # not in `only` -> unchanged + assert reg["ComponentNeedsCred"]["credentials_satisfied"] is False diff --git a/tests/back/registries/test_relationship_manager.py b/tests/back/registries/test_relationship_manager.py index 7109f2ad9..f88e5e2f3 100644 --- a/tests/back/registries/test_relationship_manager.py +++ b/tests/back/registries/test_relationship_manager.py @@ -1,43 +1,40 @@ -from collections import defaultdict - from DashAI.back.dependencies.registry.relationship_manager import RelationshipManager -def test_relationship_manager_add_relations(): - test_relationship_manager = RelationshipManager() - - assert isinstance(test_relationship_manager.relations, dict) - assert isinstance(test_relationship_manager._relations, defaultdict) - assert test_relationship_manager.relations == {} - assert test_relationship_manager._relations == defaultdict(list) +def test_add_default_relation_type_is_bidirectional(): + rm = RelationshipManager() + rm.add_relationship("A", "B") + assert rm.get("A", "compatible_components") == ["B"] + assert rm.get("B", "compatible_components") == ["A"] - test_relationship_manager.add_relationship("Component1", "Task1") - test_relationship_manager.add_relationship("Component2", "Task1") - test_relationship_manager.add_relationship("Component3", "Task2") - assert test_relationship_manager.relations == { - "Component1": ["Task1"], - "Task1": ["Component1", "Component2"], - "Component2": ["Task1"], - "Component3": ["Task2"], - "Task2": ["Component3"], - } +def test_relation_types_are_isolated(): + rm = RelationshipManager() + rm.add_relationship("Model", "Task", "compatible_components") + rm.add_relationship("Model", "HFCred", "required_credentials") + assert rm.get("Model", "compatible_components") == ["Task"] + assert rm.get("Model", "required_credentials") == ["HFCred"] + # reverse lookup of who requires the credential + assert rm.get("HFCred", "required_credentials") == ["Model"] -def test_relationship_manager__getitem__(): - test_relationship_manager = RelationshipManager() +def test_get_missing_returns_empty_list(): + rm = RelationshipManager() + assert rm.get("X", "compatible_components") == [] - test_relationship_manager.add_relationship("Component1", "Task1") - test_relationship_manager.add_relationship("Component2", "Task1") - test_relationship_manager.add_relationship("Component3", "Task2") - assert test_relationship_manager["Component1"] == ["Task1"] - assert test_relationship_manager["Component2"] == ["Task1"] - assert test_relationship_manager["Component3"] == ["Task2"] - assert test_relationship_manager["Task1"] == ["Component1", "Component2"] - assert test_relationship_manager["Task2"] == ["Component3"] +def test_relations_property_is_nested(): + rm = RelationshipManager() + rm.add_relationship("A", "B") + assert rm.relations == { + "A": {"compatible_components": ["B"]}, + "B": {"compatible_components": ["A"]}, + } -def test_relationship_manager__getitem__unexistant_component(): - test_relationship_manager = RelationshipManager() - assert test_relationship_manager["UnexistantComponent"] == [] +def test_remove_relationship_with_type(): + rm = RelationshipManager() + rm.add_relationship("A", "B", "required_credentials") + rm.remove_relationship("A", "B", "required_credentials") + assert rm.get("A", "required_credentials") == [] + assert rm.get("B", "required_credentials") == []