diff --git a/DashAI/back/api/api_v1/endpoints/dataset_source.py b/DashAI/back/api/api_v1/endpoints/dataset_source.py index 18c89403b..d8d7a45d6 100644 --- a/DashAI/back/api/api_v1/endpoints/dataset_source.py +++ b/DashAI/back/api/api_v1/endpoints/dataset_source.py @@ -52,6 +52,7 @@ async def search_datasets( q: str = Query(default="", description="Search query"), limit: int = Query(default=20, ge=1, le=100), cursor: str = Query(default="", description="Pagination cursor from previous page"), + tags: list[str] = Query(default=[]), registry: "ComponentRegistry" = Depends(lambda: di["component_registry"]), ) -> Dict[str, Any]: """Search for datasets in a registered source. @@ -67,6 +68,9 @@ async def search_datasets( cursor : str Opaque pagination token returned by the previous call. Empty string means first page. + tags : list[str] + Repeated tag filter strings (e.g. ``?tags=nlp&tags=tabular``). Passed + through to the datasource via ``**filters``. registry : ComponentRegistry Injected component registry. @@ -76,7 +80,7 @@ async def search_datasets( ``{"results": [...], "next_cursor": str | null}`` """ source = _get_source(source_name, registry) - page = source.search(q, limit=limit, cursor=cursor or None) + page = source.search(q, limit=limit, cursor=cursor or None, tags=tags) return { "results": [ { diff --git a/DashAI/back/dataset_sources/huggingface_dataset_source.py b/DashAI/back/dataset_sources/huggingface_dataset_source.py index aceeebb29..a0ee6bec5 100644 --- a/DashAI/back/dataset_sources/huggingface_dataset_source.py +++ b/DashAI/back/dataset_sources/huggingface_dataset_source.py @@ -99,7 +99,9 @@ def search( Pagination cursor returned by the previous call (encoded numeric offset). ``None`` fetches the first page. **filters : Any - Unused; reserved for future tag/task filters. + Supported keys: + tags (list[str]): Filter by HuggingFace tag strings. + Passed directly as the ``filter`` argument to ``list_datasets``. 
Returns ------- @@ -110,11 +112,13 @@ def search( try: offset = int(cursor) if cursor else 0 + tags: list[str] = filters.get("tags") or [] iterator = HfApi().list_datasets( search=query or None, full=True, limit=offset + limit + 1, + filter=tags if tags else None, ) window = list(islice(iterator, offset, offset + limit + 1)) has_next = len(window) > limit diff --git a/DashAI/back/dataset_sources/openml_dataset_source.py b/DashAI/back/dataset_sources/openml_dataset_source.py index a0326447b..202e398c6 100644 --- a/DashAI/back/dataset_sources/openml_dataset_source.py +++ b/DashAI/back/dataset_sources/openml_dataset_source.py @@ -146,7 +146,9 @@ def search( Opaque pagination token (encodes the numeric offset). ``None`` fetches the first page. **filters : Any - Unused; reserved for future filters. + Supported keys: + tags (list[str]): Filter by OpenML tag. Only the first tag is + used; OpenML's API accepts a single ``tag`` parameter. Returns ------- @@ -157,12 +159,21 @@ def search( offset = int(cursor) if cursor else 0 list_kwargs: dict[str, Any] = { "offset": offset, - "size": limit, + "size": limit + 1, "status": "active", "output_format": "dataframe", } if query: list_kwargs["data_name"] = query + tags: list[str] = filters.get("tags") or [] + if tags: + list_kwargs["tag"] = tags[0] + if len(tags) > 1: + log.warning( + "OpenML supports only one tag filter; using %r, ignoring %r", + tags[0], + tags[1:], + ) result = openml.datasets.list_datasets(**list_kwargs) @@ -188,22 +199,25 @@ def _meta(did: str) -> tuple[str, tuple[str, ...]] | None: entries = [] for row, did, meta in zip(rows, ids, metas): description = "" - tags: list[str] = [] + dataset_tags: list[str] = [] if meta is not None: description, real_tags = meta - tags = list(real_tags) + dataset_tags = list(real_tags) entries.append( DatasetEntry( id=did, name=row.get("name", "") or "", description=description, - tags=tags, + tags=dataset_tags, size_bytes=None, url=f"https://www.openml.org/d/{did}", 
source=self.__class__.__name__, ) ) - next_cursor = str(offset + limit) if len(entries) == limit else None + has_next = len(entries) > limit + if has_next: + entries = entries[:limit] + next_cursor = str(offset + limit) if has_next else None return SearchPage(entries=entries, next_cursor=next_cursor) except Exception: log.exception("Error searching OpenML datasets") diff --git a/DashAI/back/dataset_sources/zenodo_dataset_source.py b/DashAI/back/dataset_sources/zenodo_dataset_source.py index c20fff7fc..2796fba19 100644 --- a/DashAI/back/dataset_sources/zenodo_dataset_source.py +++ b/DashAI/back/dataset_sources/zenodo_dataset_source.py @@ -19,6 +19,11 @@ _ZENODO_API = "https://zenodo.org/api" +def _escape_lucene_phrase(s: str) -> str: + """Escape backslash and double-quote inside a Lucene quoted-string phrase.""" + return s.replace("\\", "\\\\").replace('"', '\\"') + + def _strip_html(text: str) -> str: """Remove HTML tags from a string. @@ -122,7 +127,10 @@ def search( Opaque pagination token (encodes the page number as a string). Pass ``None`` to fetch the first page. **filters : Any - Unused; reserved for future filters. + Supported keys: + tags (list[str]): Filter by Zenodo keywords. Each tag is + appended to the query using Lucene ``keywords:""`` + syntax (AND logic). 
Returns ------- @@ -131,8 +139,19 @@ def search( """ try: page = int(cursor) if cursor else 1 + # Skip blank tags so no empty keywords:"" clause is emitted. + tags: list[str] = [t for t in (filters.get("tags") or []) if t.strip()] + keyword_clauses = " AND ".join( + f'keywords:"{_escape_lucene_phrase(t)}"' for t in tags + ) + if query and tags: + zenodo_q = f"({query}) AND {keyword_clauses}" + elif tags: + zenodo_q = keyword_clauses + else: + zenodo_q = query params: dict[str, Any] = { - "q": query, + "q": zenodo_q, "type": "dataset", "page": page, "size": limit, diff --git a/DashAI/front/src/api/api.ts b/DashAI/front/src/api/api.ts index 71955268c..5035172a9 100644 --- a/DashAI/front/src/api/api.ts +++ b/DashAI/front/src/api/api.ts @@ -4,6 +4,17 @@ import i18n from "i18next"; const api: AxiosInstance = axios.create({ baseURL: process.env.REACT_APP_API_URL, + paramsSerializer: (params) => { + const sp = new URLSearchParams(); + for (const [key, val] of Object.entries(params)) { + if (Array.isArray(val)) { + val.forEach((v) => sp.append(key, String(v))); + } else if (val !== null && val !== undefined) { + sp.append(key, String(val)); + } + } + return sp.toString(); + }, }); api.interceptors.request.use((config) => { diff --git a/DashAI/front/src/api/hub.ts b/DashAI/front/src/api/hub.ts index 72d96c2dd..eb9b13c8c 100644 --- a/DashAI/front/src/api/hub.ts +++ b/DashAI/front/src/api/hub.ts @@ -35,10 +35,11 @@ export const searchDatasets = async ( query: string, limit = 20, cursor: string | null = null, + tags: string[] = [], ): Promise => { const response = await api.get( `${hubEndpoint}/${sourceName}/search`, - { params: { q: query, limit, cursor: cursor ?? 
"", tags } }, ); return response.data; }; diff --git a/DashAI/front/src/components/hub/DatasetGrid.jsx b/DashAI/front/src/components/hub/DatasetGrid.jsx index edbd4fc5a..949e4b378 100644 --- a/DashAI/front/src/components/hub/DatasetGrid.jsx +++ b/DashAI/front/src/components/hub/DatasetGrid.jsx @@ -2,11 +2,13 @@ import { useCallback, useEffect, useRef, useState } from "react"; import { Box, Button, + Chip, CircularProgress, InputAdornment, TextField, Typography, } from "@mui/material"; +import LocalOfferIcon from "@mui/icons-material/LocalOffer"; import SearchIcon from "@mui/icons-material/Search"; import { useTranslation } from "react-i18next"; import { searchDatasets } from "../../api/hub"; @@ -31,21 +33,26 @@ export default function DatasetGrid({ }) { const { t } = useTranslation(["hub", "common"]); const [query, setQuery] = useState(""); + const [tags, setTags] = useState([]); + const [tagInput, setTagInput] = useState(""); const [datasets, setDatasets] = useState([]); const [nextCursor, setNextCursor] = useState(null); const [hasMore, setHasMore] = useState(false); const [loading, setLoading] = useState(false); const [loadingMore, setLoadingMore] = useState(false); const debounceRef = useRef(null); + const reqIdRef = useRef(0); const loadPage = useCallback( - (q, cursor, append) => { + (q, activeTags, cursor, append) => { if (!sourceName) return; if (append) setLoadingMore(true); else setLoading(true); + const reqId = ++reqIdRef.current; - searchDatasets(sourceName, q, PAGE_SIZE, cursor) + searchDatasets(sourceName, q, PAGE_SIZE, cursor, activeTags) .then(({ results, next_cursor }) => { + if (reqId !== reqIdRef.current) return; setDatasets((prev) => append ? 
[ @@ -59,11 +66,13 @@ export default function DatasetGrid({ setHasMore(next_cursor !== null); }) .catch(() => { + if (reqId !== reqIdRef.current) return; if (!append) setDatasets([]); setNextCursor(null); setHasMore(false); }) .finally(() => { + if (reqId !== reqIdRef.current) return; - if (append) setLoadingMore(false); - else setLoading(false); + setLoadingMore(false); /* latest request settles both flags: a superseded request skips this handler */ + setLoading(false); }); @@ -75,8 +84,10 @@ export default function DatasetGrid({ setDatasets([]); setNextCursor(null); setQuery(""); + setTags([]); + setTagInput(""); setHasMore(false); - if (sourceName) loadPage("", null, false); + if (sourceName) loadPage("", [], null, false); }, [sourceName]); const handleQueryChange = (e) => { @@ -84,11 +95,39 @@ export default function DatasetGrid({ setQuery(val); setNextCursor(null); clearTimeout(debounceRef.current); - debounceRef.current = setTimeout(() => loadPage(val, null, false), 400); + debounceRef.current = setTimeout( + () => loadPage(val, tags, null, false), + 400, + ); + }; + + const handleTagInputKeyDown = (e) => { + if (e.key === "Enter" || e.key === ",") { + e.preventDefault(); + const trimmed = tagInput.trim(); + if (trimmed && !tags.includes(trimmed)) { + const newTags = [...tags, trimmed]; + setTags(newTags); + setTagInput(""); + setNextCursor(null); + clearTimeout(debounceRef.current); + loadPage(query, newTags, null, false); + } else { + setTagInput(""); + } + } + }; + + const handleRemoveTag = (tagToRemove) => { + const newTags = tags.filter((t) => t !== tagToRemove); + setTags(newTags); + setNextCursor(null); + clearTimeout(debounceRef.current); + loadPage(query, newTags, null, false); }; const handleLoadMore = () => { - loadPage(query, nextCursor, true); + loadPage(query, tags, nextCursor, true); }; if (!sourceName) { @@ -137,6 +176,39 @@ export default function DatasetGrid({ }} /> + + setTagInput(e.target.value)} + onKeyDown={handleTagInputKeyDown} + slotProps={{ + input: { + startAdornment: ( + + + + ), + }, + }} + /> + {tags.length > 0 && ( + + {tags.map((tag) => ( + 
handleRemoveTag(tag)} + /> + ))} + + )} + + {loading ? ( @@ -160,7 +232,7 @@ export default function DatasetGrid({