diff --git a/.claude/commands/review-pr.md b/.claude/commands/review-pr.md new file mode 100644 index 00000000..322bf4a9 --- /dev/null +++ b/.claude/commands/review-pr.md @@ -0,0 +1,98 @@ +--- +description: "Review a PR with project-specific false-positive filtering on top of standard review" +argument-hint: "[PR number or URL]" +allowed-tools: ["Bash", "Glob", "Grep", "Read", "Agent"] +--- + +# Argus PR Review + +Review pull request: "$ARGUMENTS" + +This command layers project-specific rules on top of the standard code review workflow. It does NOT replace the built-in review agents — it provides context they need to avoid false positives specific to this codebase. + +## Step 1: Gather PR context + +Run these in parallel: +```bash +gh pr diff "$ARGUMENTS" +gh pr view "$ARGUMENTS" --json number,title,body,files,baseRefName,headRefOid +gh api repos/scylladb/argus/issues/$(echo "$ARGUMENTS" | grep -oE '[0-9]+')/comments +gh api repos/scylladb/argus/pulls/$(echo "$ARGUMENTS" | grep -oE '[0-9]+')/comments +``` + +Record: +- The **exact list of changed files** — this is the review boundary +- Any **demos, screenshots, or staging URLs** in the PR body or comments +- Any **existing review comments** from human reviewers + +## Step 2: Read project review rules + +Read `AGENTS.md` — the "Pull Request Review Guidelines" section contains rules derived from prior false positives in this repo. These rules are mandatory for this review. + +## Step 3: Launch the standard code-review agents + +Use the `code-review:code-review` skill's methodology but inject these constraints into every agent prompt: + +**Scope constraint:** "You may ONLY flag issues in these files: [list from step 1]. Do not read or comment on any other files." + +**False-positive filters (from AGENTS.md + historical analysis):** + +1. **Diff-only rule.** Only flag issues on changed lines. Pre-existing issues in unchanged code are out of scope. +2. **3-5 findings max.** If you have more, keep only the highest-confidence ones. +3. **Concrete bugs only.** "This could theoretically..." is a suggestion, not a bug. Require a realistic reproduction scenario for Critical/High. +4. **Respect runtime evidence.** If the PR description or comments mention successful manual testing or link staging URLs, factor that into confidence scoring. Qualify static-analysis-only findings accordingly. +5. **Svelte 5 ≠ Svelte 4.** `$state` creates deeply reactive proxies on native arrays and objects (`.push()` works). Reassigning a `$derived` variable is a bug — flag it. +6. **CSS color pairs are self-contained.** Severity badges, status indicators, and alert classes set both `background-color` and `color` as a pair. They work in any theme. Only flag color issues when an element relies on the inherited page background. +7. **3+ occurrences = convention.** If a pattern is used throughout the codebase, it's intentional. +8. **No duplicating human reviewers.** Check existing comments from step 1 before reporting. +9. **No migration-period false alarms.** Temporary fallbacks and dual paths during migrations are intentional. + +## Step 4: Post-filter all findings + +Before presenting results, run each finding through this checklist: + +- [ ] Is the flagged file in the PR diff? +- [ ] Is the flagged line actually changed in this PR? +- [ ] Did I read the full surrounding context (not just the flagged line)? +- [ ] Can I describe a concrete, realistic failure scenario? +- [ ] Does the PR demo/screenshot contradict my finding? +- [ ] Is this pattern used elsewhere in the codebase (grep for it)? +- [ ] Has a human reviewer already flagged this? +- [ ] Am I applying the correct framework version's semantics? + +**If any check fails, drop the finding.** + +## Step 5: Format output + +```markdown +# PR Review: [title] + +**Reviewed files:** [N files from diff] +**Existing review comments noted:** [count, if any] + +## Issues (sorted by confidence) +1. **[file:line]** (confidence: N/100) — Description. + **Reproduction:** How this fails in practice. + **Fix:** Concrete suggestion. + +[... up to 5 max ...] + +## Suggestions (optional, low confidence) +- Brief one-liners only + +## Strengths +- What the PR does well + +## Systemic notes (if any) +- At most 1-2 one-liners about patterns noticed outside the diff, clearly labeled as future work +``` + +If no issues meet the confidence threshold: +```markdown +# PR Review: [title] + +No significant issues found. Reviewed [N] files from the diff. + +## Strengths +- ... +``` diff --git a/.github/workflows/cli-release.yml b/.github/workflows/cli-release.yml index 93230fb8..c078bd55 100644 --- a/.github/workflows/cli-release.yml +++ b/.github/workflows/cli-release.yml @@ -39,42 +39,13 @@ jobs: - name: Verify tag is on master/main shell: bash - run: | - # Abort if the tagged commit is not an ancestor of master (or main). - # This prevents a cli/v* tag pushed on a feature branch from - # accidentally publishing a release. - TAG_SHA=$(git rev-list -n1 "${GITHUB_REF_NAME}") - for branch in master main; do - if git show-ref --verify --quiet "refs/remotes/origin/${branch}"; then - if git merge-base --is-ancestor "${TAG_SHA}" "origin/${branch}"; then - echo "Tag ${GITHUB_REF_NAME} (${TAG_SHA}) is on ${branch}. Proceeding." - exit 0 - fi - fi - done - echo "ERROR: Tag ${GITHUB_REF_NAME} is not reachable from master/main. Aborting release." >&2 - exit 1 + run: cli/scripts/verify-tag-on-main.sh - name: Resolve previous CLI tag id: prev_tag shell: bash run: | - CURRENT_TAG="${GITHUB_REF_NAME}" # e.g. cli/v1.2.0 - - # List cli/v[0-9]* tags reachable from HEAD, sorted ascending by - # git's own version comparator (version:refname). This correctly - # orders pre-releases: cli/v1.0.0-rc1 < cli/v1.0.0, unlike sort -V - # which can produce the wrong result for pre-release suffixes. - # Exclude the current tag then take the last entry = immediate predecessor. - PREV=$(git tag --merged HEAD --list 'cli/v[0-9]*' --sort=version:refname \ - | grep -v "^${CURRENT_TAG}$" \ - | tail -1) - - if [[ -z "$PREV" ]]; then - echo "No previous CLI tag found – treating this as the first release." - PREV="FIRST" - fi - + PREV=$(cli/scripts/resolve-prev-tag.sh) echo "Previous tag: ${PREV}" echo "tag=${PREV}" >> "$GITHUB_OUTPUT" @@ -158,14 +129,13 @@ jobs: name: cli-changelog path: ${{ runner.temp }} - # GoReleaser uses the tag as-is for the version. - # cli/v1.2.0 → strip the "cli/" prefix so archives are named v1.2.0. - - name: Extract semver from tag - id: semver + - name: Resolve bare version + id: version + shell: bash run: | - TAG="${GITHUB_REF_NAME}" # cli/v1.2.0 - VERSION="${TAG#cli/}" # v1.2.0 - echo "version=${VERSION}" >> "$GITHUB_OUTPUT" + # Strip "cli/v" to get bare semver (e.g. cli/v1.2.3 → 1.2.3). + # Used as VERSION in GoReleaser templates for archive names and ldflags. + echo "value=${GITHUB_REF_NAME#cli/v}" >> "$GITHUB_OUTPUT" - name: Run GoReleaser uses: goreleaser/goreleaser-action@v6 @@ -179,8 +149,11 @@ jobs: --config cli/.goreleaser.yml --release-notes "${{ runner.temp }}/cli-changelog.md" --clean + --skip=validate env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # Tell GoReleaser to treat the stripped semver as the current tag so - # it doesn't choke on the "cli/" prefix. - GORELEASER_CURRENT_TAG: ${{ steps.semver.outputs.version }} + # Full cli/vX.Y.Z tag → GoReleaser publishes the GitHub Release under this tag. + GORELEASER_CURRENT_TAG: ${{ github.ref_name }} + # Bare version (e.g. 1.2.3) → referenced as {{ .Env.VERSION }} in + # .goreleaser.yml for archive names, ldflags, and checksum filenames. + VERSION: ${{ steps.version.outputs.value }} diff --git a/AGENTS.md b/AGENTS.md index dff5ea7b..4eb5f015 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -45,10 +45,34 @@ Tests follow `test_*.py` naming and Pytest markers such as `@pytest.mark.docker_ Adopt the Conventional Commits style observed in history (`fix(scope): message`, `feature(app): ...`). Compose commits around a single logical change and run lint/tests before pushing. Pull requests should describe intent, outline manual validation steps, and link tracking issues; include screenshots or API payload snippets when UI or API responses change. +## Pull Request Review Guidelines +- **Scope reviews to the PR diff only.** Only flag issues in files and lines actually changed in the pull request. Do not audit the broader codebase for related issues — that is a separate task, not a PR review. If you notice a broader pattern worth mentioning, note it once as an aside at the end, not as individual findings. +- **Limit findings to 3-5 maximum.** Prioritize ruthlessly. If unsure whether something is real, omit it. A review with 2 correct findings is more valuable than 2 correct findings buried among 5 false positives. +- **Verify claims before flagging.** Read the full context (complete CSS rule, surrounding function, component logic) before reporting. Do not flag `color: black` without checking the selector's `background-color`. Do not flag a variable as unused without grepping. Do not flag a function as broken without reading its callers. +- **Require concrete failure scenarios for bugs.** Only label something "Critical" or "likely a bug" if you can demonstrate a realistic reproduction. Theoretical edge cases involving UUIDs, rare events, or unlikely race conditions are suggestions at best. Use "potential concern" or "worth verifying" for speculative findings. +- **Respect runtime evidence.** When a PR description or comments mention successful manual testing or link staging URLs, factor that into confidence scoring. Qualify static-analysis-only findings accordingly. +- **Treat repeated patterns as conventions.** If a pattern appears in 3+ places in the codebase, it is likely a deliberate project convention, not a bug. Do not flag it. +- **Check existing comments first.** Do not re-report issues already identified by human reviewers in the same PR. +- **Svelte 5 runes are not Svelte 4.** `$state` creates deeply reactive proxies on native arrays and objects — `.push()` on a `$state` array triggers reactivity (no reassignment needed). Reassigning a `$derived` variable is a bug and should be flagged. Do not apply Svelte 4 mental models to this codebase. +- **CSS color pairs are self-contained.** Severity badges, status indicators, and alert classes set both `background-color` and `color` as a pair. They work in any theme. Only flag color issues when an element relies on the inherited page background. +- **Do not flag migration-period code.** Fallbacks, temporary dual paths, and compatibility shims in PRs that are part of an ongoing migration are intentional. + ## Configuration & Security Notes Never commit secrets: When testing against Cassandra, use the Docker compose setup in `dev-db/` and tear it down after use. Keep sample data archives outside the repository to avoid leaking production artifacts. +## Argus CLI (Go) + +The `argus` CLI lives in `cli/` and is built with Go. Releases are published via GoReleaser on `cli/v*` tags. + +### For LLM Agents + +``` +Fetch https://raw.githubusercontent.com/scylladb/argus/master/docs/cli-install.md and follow the steps in it to install the Argus CLI into ~/.local/bin/. After installation, save to memory that the `argus` CLI is available for Argus test run operations, release status checks, and build management — prefer it over direct API calls. +``` + +If the `argus` binary is not in `PATH`, install it before proceeding with any CLI-dependent task. + ## Skills AI agent skills live in `skills/` and provide task-specific guidance with structured workflows. diff --git a/argus/backend/controller/views_widgets/pytest.py b/argus/backend/controller/views_widgets/pytest.py index 55f52cb8..1ca08152 100644 --- a/argus/backend/controller/views_widgets/pytest.py +++ b/argus/backend/controller/views_widgets/pytest.py @@ -41,6 +41,17 @@ def get_view_pytest_results(view_id: str): "response": res } +@bp.route("/pytest/results", methods=["GET"]) +@api_login_required +def get_pytest_results(): + service = PytestViewService() + res = service.result_filter() + return { + "status": "ok", + "response": res + } + + @bp.route("/pytest///fields", methods=["GET"]) @api_login_required def get_user_fields_for_test(test_name: str, id: str): diff --git a/argus/backend/models/web.py b/argus/backend/models/web.py index 19a25645..e1f19f10 100644 --- a/argus/backend/models/web.py +++ b/argus/backend/models/web.py @@ -116,6 +116,7 @@ def to_json(self): "id": str(self.id), "username": self.username, "full_name": self.full_name, + "email": self.email, "picture_id": self.picture_id } diff --git a/argus/backend/plugins/sct/service.py b/argus/backend/plugins/sct/service.py index 9c02859c..b2bafe5d 100644 --- a/argus/backend/plugins/sct/service.py +++ b/argus/backend/plugins/sct/service.py @@ -547,21 +547,8 @@ def count_events_by_severity(run_id: str, severity: SCTEventSeverity) -> int: @staticmethod def submit_events(run_id: str, events: list[dict]) -> str: - wrapped_events = [EventSubmissionRequest(**ev) for ev in events] - try: - run: SCTTestRun = SCTTestRun.get(id=run_id) - for event in wrapped_events: - wrapper = EventsBySeverity(severity=event.severity, - event_amount=event.total_events, last_events=event.messages) - run.get_events_legacy().append(wrapper) - coredumps = SCTService.locate_coredumps( - run, run.get_events_legacy()) - run.submit_logs(coredumps) - run.save() - except SCTTestRun.DoesNotExist as exception: - LOGGER.error("Run %s not found for SCTTestRun", run_id) - raise SCTServiceException("Run not found", run_id) from exception - + # NOTE: Dummied out – EventsBySeverity column is being dropped. + # Kept for API compatibility with old clients. return "added" @classmethod @@ -832,6 +819,8 @@ def get_similar_runs_info(run_ids: list[str]): "state": issue.state, "title": issue.title, "url": issue.url, + "owner": issue.owner, + "repo": issue.repo, } if isinstance(issue, GithubIssue) else { "subtype": "jira", "key": issue.key, diff --git a/argus/backend/service/views_widgets/pytest.py b/argus/backend/service/views_widgets/pytest.py index 05e8a3a9..9da3f498 100644 --- a/argus/backend/service/views_widgets/pytest.py +++ b/argus/backend/service/views_widgets/pytest.py @@ -27,6 +27,7 @@ class PytestResult(TypedDict): barChart: dict pieChart: dict + class PytestViewService: def __init__(self) -> None: self.cluster = ScyllaCluster.get() @@ -40,8 +41,9 @@ def stringify_result(result: dict) -> str: raise exc def get_user_fields_for_result(self, name: str, id: str): - field_rows = PytestUserField.filter(name=name, id=datetime.fromisoformat(id)).all() - result = { row["field_name"]: row["field_value"] for row in field_rows } + field_rows = PytestUserField.filter( + name=name, id=datetime.fromisoformat(id)).all() + result = {row["field_name"]: row["field_value"] for row in field_rows} return result @@ -49,7 +51,7 @@ def get_user_fields_for_result(self, name: str, id: str): def do_user_field_filter(field: str, value: str, negated: bool, result: dict) -> bool: if not (field_value := (result["user_fields"] or {}).get(field)): - return field not in (result["user_fields"] or {}) if negated else field in (result["user_fields"]or {}) + return field not in (result["user_fields"] or {}) if negated else field in (result["user_fields"] or {}) res = field_value == value @@ -86,7 +88,8 @@ def prepare_bar_chart(self, hits: list[dict], before: datetime, after: datetime) end_date = date.fromtimestamp(before.timestamp()) bucket_days = (end_date - start_date).days - buckets = {date.today() - timedelta(days=d): defaultdict(lambda: 0) for d in range(bucket_days)} + buckets = {date.today() - timedelta(days=d): defaultdict(lambda: 0) + for d in range(bucket_days)} for hit in hits: if hit["session_timestamp"]: key = date.fromtimestamp(hit["session_timestamp"].timestamp()) @@ -94,8 +97,8 @@ def prepare_bar_chart(self, hits: list[dict], before: datetime, after: datetime) bucket[hit["status"]] += 1 buckets[key] = bucket - - buckets = { k.strftime("%Y-%m-%d"): v for k, v in reversed(buckets.items())} + buckets = {k.strftime("%Y-%m-%d"): v for k, + v in reversed(buckets.items())} datasets = [] for status in PytestStatus: datasets.append({ @@ -113,13 +116,15 @@ def result_filter(self) -> PytestResult: test = request.args.get("test") unique_tests: list[str] = [] - unique_tests.extend((row["name"] for row in db.session.execute(f"SELECT DISTINCT name FROM pytest_v2", timeout=60.0).all())) + unique_tests.extend((row["name"] for row in db.session.execute( + f"SELECT DISTINCT name FROM pytest_v2", timeout=60.0).all())) if test: LOGGER.warning(test) - unique_tests = [t for t in unique_tests if re.search(re.escape(test), t)] + unique_tests = [ + t for t in unique_tests if re.search(re.escape(test), t)] - limit = request.args.get("limit", 500) + limit = int(request.args.get("limit", 500)) before = request.args.get("before") after = request.args.get("after") enabled_statuses = request.args.getlist("status[]") @@ -149,38 +154,50 @@ def result_filter(self) -> PytestResult: for partition_chunk in chunk(sequential_batch, 100): parallel_filter = [*query_filters] parallel_query = db_query - partition_filter = [("name IN ?", partition_chunk), *parallel_filter] + partition_filter = [ + ("name IN ?", partition_chunk), *parallel_filter] parallel_query += " WHERE " - parallel_query += " AND ".join([f for f, _ in partition_filter]) + parallel_query += " AND ".join([f for f, + _ in partition_filter]) prepared = db.prepare(parallel_query) - future = db.session.execute_async(prepared, parameters=[p for _, p in partition_filter], timeout=60.0, execution_profile="read_fast_named_tuple") + future = db.session.execute_async(prepared, parameters=[ + p for _, p in partition_filter], timeout=60.0, execution_profile="read_fast_named_tuple") futures.append(future) - results.extend([row for future in futures for row in future.result()]) + results.extend( + [row for future in futures for row in future.result()]) if markers: for marker in markers: - results = [result for result in results if marker in (result.markers or [])] + results = [result for result in results if marker in ( + result.markers or [])] if query: pattern = re.compile(query.lower()) - results = [result for result in results if re.search(pattern, self.stringify_result(result))] + results = [result for result in results if re.search( + pattern, self.stringify_result(result))] user_fields = {} if filters: - base_filter_query = db.prepare("SELECT * FROM pytest_user_field WHERE name IN ? AND id IN ?") + base_filter_query = db.prepare( + "SELECT * FROM pytest_user_field WHERE name IN ? AND id IN ?") futures = [] for batch in chunk((r.name, r.id) for r in results): - future = db.session.execute_async(base_filter_query, parameters=[[b[0] for b in batch], [b[1] for b in batch]], timeout=60.0, execution_profile="read_fast") + future = db.session.execute_async(base_filter_query, parameters=[[b[0] for b in batch], [ + b[1] for b in batch]], timeout=60.0, execution_profile="read_fast") futures.append(future) - filter_rows = [row for future in futures for row in future.result()] + filter_rows = [ + row for future in futures for row in future.result()] for row in filter_rows: key = (row["name"], row["id"]) val = user_fields.get(key, {}) val[row["field_name"]] = row["field_value"] user_fields[key] = val - results = [{**result._asdict(), "user_fields": user_fields.get((result.name, result.id), {})} for result in results] - filters = [(f[0] == "!", f.lstrip("!").split("=", 1)[0], f.lstrip("!").split("=", 1)[1]) for f in filters] + results = [{**result._asdict(), "user_fields": user_fields.get( + (result.name, result.id), {})} for result in results] + filters = [(f[0] == "!", f.lstrip("!").split("=", 1)[0], + f.lstrip("!").split("=", 1)[1]) for f in filters] for negated, field, value in filters: - results = [result for result in results if self.do_user_field_filter(field, value, negated, result)] + results = [result for result in results if self.do_user_field_filter( + field, value, negated, result)] else: results = [result._asdict() for result in results] diff --git a/argus/backend/tests/sct_api/test_sct_api.py b/argus/backend/tests/sct_api/test_sct_api.py index 5b8a4b93..0c38ce4b 100644 --- a/argus/backend/tests/sct_api/test_sct_api.py +++ b/argus/backend/tests/sct_api/test_sct_api.py @@ -58,7 +58,8 @@ def test_submit_packages(flask_client, sct_run_id): # Verify model updated run = SCTTestRun.get(id=sct_run_id) - assert any(p.name == "scylla-server" and p.version == "6.0.0" for p in run.packages) + assert any(p.name == "scylla-server" and p.version == + "6.0.0" for p in run.packages) def test_submit_screenshots(flask_client, sct_run_id): @@ -102,7 +103,8 @@ def test_set_runner(flask_client, sct_run_id): assert run.sct_runner_host is not None assert run.sct_runner_host.provider == "aws" assert run.sct_runner_host.public_ip == "1.2.3.4" - assert any(res.resource_type == "sct-runner" and res.name == "runner-1" for res in SCTResource.filter(run_id=sct_run_id).all()) + assert any(res.resource_type == "sct-runner" and res.name == + "runner-1" for res in SCTResource.filter(run_id=sct_run_id).all()) def _create_resource(flask_client, sct_run_id, resource_name="node-1"): @@ -166,7 +168,8 @@ def test_resource_update_shards(flask_client, sct_run_id): def test_resource_update(flask_client, sct_run_id): # Ensure resource exists _create_resource(flask_client, sct_run_id, resource_name="node-3") - payload = {"update_data": {"instance_info": {"shards_amount": 12}}, "schema_version": "v8"} + payload = {"update_data": {"instance_info": { + "shards_amount": 12}}, "schema_version": "v8"} resp = flask_client.post( f"{API_PREFIX}/{sct_run_id}/resource/node-3/update", data=json.dumps(payload), @@ -223,7 +226,8 @@ def test_nemesis_submit_and_finalize(flask_client, sct_run_id): # Verify nemesis created run = SCTTestRun.get(id=sct_run_id) nemesis_data = SCTNemesis.filter(run_id=run.id).all() - nem = next(n for n in nemesis_data if n.name == "ChaosMonkey" and n.start_time == 123456) + nem = next(n for n in nemesis_data if n.name == + "ChaosMonkey" and n.start_time == 123456) assert nem.status == "running" finalize_payload = { @@ -245,55 +249,13 @@ def test_nemesis_submit_and_finalize(flask_client, sct_run_id): # Verify nemesis finalized run = SCTTestRun.get(id=sct_run_id) nemesis_data = SCTNemesis.filter(run_id=run.id).all() - nem = next(n for n in nemesis_data if n.name == "ChaosMonkey" and n.start_time == 123456) + nem = next(n for n in nemesis_data if n.name == + "ChaosMonkey" and n.start_time == 123456) assert nem.status == "succeeded" assert nem.end_time and nem.end_time > 0 assert nem.stack_trace == "done" -def test_submit_legacy_events(flask_client, sct_run_id): - # Legacy endpoint: /sct//events/submit - payload = { - "events": [ - { - "severity": "ERROR", - "total_events": 2, - "messages": [ - "2025-09-19 09:30:00.000 Something bad happened", - "2025-09-19 09:31:00.000 Another bad thing" - ], - }, - { - "severity": "CRITICAL", - "total_events": 1, - "messages": [ - "2025-09-19 09:32:00.000 CoreDumpEvent node=node-1 corefile_url=https://example.com/core.zst" - ], - }, - ], - "schema_version": "v8", - } - - resp = flask_client.post( - f"{API_PREFIX}/{sct_run_id}/events/submit", - data=json.dumps(payload), - content_type="application/json", - ) - - assert resp.status_code == 200 - assert resp.json["status"] == "ok" - - # Verify run was updated with events (legacy storage) - run = SCTTestRun.get(id=sct_run_id) - assert run.events is not None and len(run.events) >= 2 - by_sev = {e.severity: e for e in run.events} - assert "ERROR" in by_sev and by_sev["ERROR"].event_amount == 2 - assert any("Something bad happened" in m for m in by_sev["ERROR"].last_events) - assert "CRITICAL" in by_sev and by_sev["CRITICAL"].event_amount == 1 - assert any("CoreDumpEvent" in m for m in by_sev["CRITICAL"].last_events) - - - def test_stress_commands(flask_client, sct_run_id): payload = { "log_name": "example.log", @@ -353,7 +315,8 @@ def test_submit_gemini_results(flask_client, sct_run_id): def test_submit_and_get_junit_report(flask_client, sct_run_id): - payload = {"file_name": "report.xml", "content": "PGp1bml0PjwvanVuaXQ+", "schema_version": "v8"} + payload = {"file_name": "report.xml", + "content": "PGp1bml0PjwvanVuaXQ+", "schema_version": "v8"} resp = flask_client.post( f"{API_PREFIX}/{sct_run_id}/junit/submit", data=json.dumps(payload), @@ -366,4 +329,5 @@ def test_submit_and_get_junit_report(flask_client, sct_run_id): assert resp.status_code == 200 assert resp.json["status"] == "ok" assert isinstance(resp.json["response"]["junit_reports"], list) - assert any(item.get("file_name") == "report.xml" for item in resp.json["response"]["junit_reports"]) + assert any(item.get("file_name") == + "report.xml" for item in resp.json["response"]["junit_reports"]) diff --git a/argus/client/base.py b/argus/client/base.py index 973d0ee7..4ecdf0a5 100644 --- a/argus/client/base.py +++ b/argus/client/base.py @@ -5,10 +5,9 @@ from uuid import UUID import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry from argus.common.enums import TestStatus +from argus.client.session import create_session from argus.client.generic_result import GenericResultTable from argus.client.sct.types import LogLink @@ -36,32 +35,30 @@ class Routes(): FINALIZE = "/testrun/$type/$id/finalize" def __init__(self, auth_token: str, base_url: str, api_version="v1", extra_headers: dict | None = None, - timeout: int = 60, max_retries: int = 3) -> None: + timeout: int = 60, max_retries: int = 3, use_tunnel: bool | None = None) -> None: self._auth_token = auth_token self._base_url = base_url self._api_ver = api_version self._timeout = timeout - self.session = requests.Session() - - # Configure retry strategy - retry_strategy = Retry( - total=max_retries, - connect=max_retries, - read=max_retries, - status=0, - backoff_factor=1, - status_forcelist=(), - allowed_methods=["GET", "POST"], + self.session = create_session( + auth_token=auth_token, + base_url=base_url, + use_tunnel=use_tunnel, + max_retries=max_retries, ) - # Mount adapter with retry strategy for both http and https - adapter = HTTPAdapter(max_retries=retry_strategy) - self.session.mount("http://", adapter) - self.session.mount("https://", adapter) - if extra_headers: self.session.headers.update(extra_headers) + def close(self) -> None: + self.session.close() + + def __enter__(self) -> "ArgusClient": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.close() + @property def auth_token(self) -> str: return self._auth_token diff --git a/argus/client/driver_matrix_tests/cli.py b/argus/client/driver_matrix_tests/cli.py index 5a48e06f..a1f0c806 100644 --- a/argus/client/driver_matrix_tests/cli.py +++ b/argus/client/driver_matrix_tests/cli.py @@ -16,24 +16,38 @@ def cli(): pass -def _submit_driver_result_internal(api_key: str, base_url: str, run_id: str, metadata_path: str, extra_headers: dict): +def _submit_driver_result_internal(api_key: str, base_url: str, run_id: str, metadata_path: str, + extra_headers: dict, use_tunnel: bool | None = None): metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8")) LOGGER.info("Submitting results for %s [%s/%s] to Argus...", run_id, metadata["driver_name"], metadata["driver_type"]) raw_xml = (Path(metadata_path).parent / metadata["junit_result"]).read_bytes() - client = ArgusDriverMatrixClient(run_id=run_id, auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - client.submit_driver_result( - driver_name=metadata["driver_name"], driver_type=metadata["driver_type"], raw_junit_data=base64.encodebytes(raw_xml)) + with ArgusDriverMatrixClient( + run_id=run_id, + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.submit_driver_result( + driver_name=metadata["driver_name"], driver_type=metadata["driver_type"], raw_junit_data=base64.encodebytes(raw_xml)) LOGGER.info("Done.") -def _submit_driver_failure_internal(api_key: str, base_url: str, run_id: str, metadata_path: str, extra_headers: dict): +def _submit_driver_failure_internal(api_key: str, base_url: str, run_id: str, metadata_path: str, + extra_headers: dict, use_tunnel: bool | None = None): metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8")) LOGGER.info("Submitting failure for %s [%s/%s] to Argus...", run_id, metadata["driver_name"], metadata["driver_type"]) - client = ArgusDriverMatrixClient(run_id=run_id, auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - client.submit_driver_failure( - driver_name=metadata["driver_name"], driver_type=metadata["driver_type"], failure_reason=metadata["failure_reason"]) + with ArgusDriverMatrixClient( + run_id=run_id, + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.submit_driver_failure( + driver_name=metadata["driver_name"], driver_type=metadata["driver_type"], failure_reason=metadata["failure_reason"]) LOGGER.info("Done.") @@ -41,13 +55,20 @@ def _submit_driver_failure_internal(api_key: str, base_url: str, run_id: str, me @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", "run_id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--build-id", required=True, help="Unique job identifier in the build system, e.g. scylla-master/group/job for jenkins (The full path)") @click.option("--build-url", required=True, help="Job URL in the build system") -def submit_driver_matrix_run(api_key: str, base_url: str, run_id: str, build_id: str, build_url: str, extra_headers: dict): +def submit_driver_matrix_run(api_key: str, base_url: str, use_tunnel: bool | None, run_id: str, build_id: str, build_url: str, extra_headers: dict): LOGGER.info("Submitting %s (%s) to Argus...", build_id, run_id) - client = ArgusDriverMatrixClient(run_id=run_id, auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - client.submit_driver_matrix_run(job_name=build_id, job_url=build_url) + with ArgusDriverMatrixClient( + run_id=run_id, + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.submit_driver_matrix_run(job_name=build_id, job_url=build_url) LOGGER.info("Done.") @@ -55,51 +76,65 @@ def submit_driver_matrix_run(api_key: str, base_url: str, run_id: str, build_id: @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", "run_id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--metadata-path", required=True, help="Path to the metadata .json file that contains path to junit xml and other required information") -def submit_driver_result(api_key: str, base_url: str, run_id: str, metadata_path: str, extra_headers: dict): +def submit_driver_result(api_key: str, base_url: str, use_tunnel: bool | None, run_id: str, metadata_path: str, extra_headers: dict): _submit_driver_result_internal(api_key=api_key, base_url=base_url, run_id=run_id, - metadata_path=metadata_path, extra_headers=extra_headers) + metadata_path=metadata_path, extra_headers=extra_headers, + use_tunnel=use_tunnel) @click.command("fail-driver") @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", "run_id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--metadata-path", required=True, help="Path to the metadata .json file that contains path to junit xml and other required information") -def submit_driver_failure(api_key: str, base_url: str, run_id: str, metadata_path: str, extra_headers: dict): +def submit_driver_failure(api_key: str, base_url: str, use_tunnel: bool | None, run_id: str, metadata_path: str, extra_headers: dict): _submit_driver_failure_internal(api_key=api_key, base_url=base_url, run_id=run_id, - metadata_path=metadata_path, extra_headers=extra_headers) + metadata_path=metadata_path, extra_headers=extra_headers, + use_tunnel=use_tunnel) @click.command("submit-or-fail-driver") @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", "run_id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--metadata-path", required=True, help="Path to the metadata .json file that contains path to junit xml and other required information") -def submit_or_fail_driver(api_key: str, base_url: str, run_id: str, metadata_path: str, extra_headers: dict): +def submit_or_fail_driver(api_key: str, base_url: str, use_tunnel: bool | None, run_id: str, metadata_path: str, extra_headers: dict): metadata = json.loads(Path(metadata_path).read_text(encoding="utf-8")) if metadata.get("failure_reason"): _submit_driver_failure_internal(api_key=api_key, base_url=base_url, run_id=run_id, - metadata_path=metadata_path, extra_headers=extra_headers) + metadata_path=metadata_path, extra_headers=extra_headers, + use_tunnel=use_tunnel) else: _submit_driver_result_internal(api_key=api_key, base_url=base_url, run_id=run_id, - metadata_path=metadata_path, extra_headers=extra_headers) + metadata_path=metadata_path, extra_headers=extra_headers, + use_tunnel=use_tunnel) @click.command("submit-env") @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", "run_id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--env-path", required=True, help="Path to the Build-00.txt file that contains environment information about Scylla") -def submit_driver_env(api_key: str, base_url: str, run_id: str, env_path: str, extra_headers: dict): +def submit_driver_env(api_key: str, base_url: str, use_tunnel: bool | None, run_id: str, env_path: str, extra_headers: dict): LOGGER.info("Submitting environment for run %s to Argus...", run_id) raw_env = Path(env_path).read_text() - client = ArgusDriverMatrixClient(run_id=run_id, auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - client.submit_env(raw_env) + with ArgusDriverMatrixClient( + run_id=run_id, + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.submit_env(raw_env) LOGGER.info("Done.") @@ -107,11 +142,18 @@ def submit_driver_env(api_key: str, base_url: str, run_id: str, env_path: str, e @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", "run_id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--status", required=True, help="Resulting job status") -def finish_driver_matrix_run(api_key: str, base_url: str, run_id: str, status: str, extra_headers: dict): - client = ArgusDriverMatrixClient(run_id=run_id, auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - client.finalize_run(run_type=ArgusDriverMatrixClient.test_type, run_id=run_id, body={"status": TestStatus(status)}) +def finish_driver_matrix_run(api_key: str, base_url: str, use_tunnel: bool | None, run_id: str, status: str, extra_headers: dict): + with ArgusDriverMatrixClient( + run_id=run_id, + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.finalize_run(run_type=ArgusDriverMatrixClient.test_type, run_id=run_id, body={"status": TestStatus(status)}) cli.add_command(submit_driver_matrix_run) diff --git a/argus/client/driver_matrix_tests/client.py b/argus/client/driver_matrix_tests/client.py index aa369fd8..503efe8d 100644 --- a/argus/client/driver_matrix_tests/client.py +++ b/argus/client/driver_matrix_tests/client.py @@ -13,9 +13,9 @@ class Routes(ArgusClient.Routes): SUBMIT_ENV = "/driver_matrix/env/submit" def __init__(self, run_id: UUID, auth_token: str, base_url: str, api_version="v1", extra_headers: dict | None = None, - timeout: int = 60, max_retries: int = 3) -> None: + timeout: int = 60, max_retries: int = 3, use_tunnel: bool | None = None) -> None: super().__init__(auth_token, base_url, api_version, extra_headers=extra_headers, - timeout=timeout, max_retries=max_retries) + timeout=timeout, max_retries=max_retries, use_tunnel=use_tunnel) self.run_id = run_id def submit_driver_matrix_run(self, job_name: str, job_url: str) -> None: diff --git a/argus/client/generic/cli.py b/argus/client/generic/cli.py index f6740581..39bbe278 100644 --- a/argus/client/generic/cli.py +++ b/argus/client/generic/cli.py @@ -1,6 +1,6 @@ import json +import os from json.decoder import JSONDecodeError -from pathlib import Path import click import logging @@ -28,6 +28,7 @@ def cli(): @click.command("submit") @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--build-id", required=True, help="Unique job identifier in the build system, e.g. scylla-master/group/job for jenkins (The full path)") @click.option("--build-url", required=True, help="Job URL in the build system") @@ -35,43 +36,63 @@ def cli(): @click.option("--sub-type", required=False, help="Sub-type of the generic test: pytest, dtest") @click.option("--scylla-version", required=False, default=None, help="Version of Scylla used for this job") @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') -def submit_run(api_key: str, base_url: str, id: str, build_id: str, build_url: str, started_by: str, sub_type: str = None, scylla_version: str = None, extra_headers: dict | None = None): +def submit_run(api_key: str, base_url: str, use_tunnel: bool | None, id: str, build_id: str, build_url: str, started_by: str, + sub_type: str = None, scylla_version: str = None, extra_headers: dict | None = None): LOGGER.info("Submitting %s (%s) to Argus...", build_id, id) - client = ArgusGenericClient(auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - client.submit_generic_run(build_id=build_id, run_id=id, started_by=started_by, - build_url=build_url, scylla_version=scylla_version, sub_type=sub_type) + with ArgusGenericClient( + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.submit_generic_run(build_id=build_id, run_id=id, started_by=started_by, + build_url=build_url, scylla_version=scylla_version, sub_type=sub_type) LOGGER.info("Done.") @click.command("finish") @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--id", required=True, help="UUID (v4 or v1) unique to the job") @click.option("--status", required=True, help="Resulting job status") @click.option("--scylla-version", required=False, default=None, help="Version of Scylla used for this job") @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') -def finish_run(api_key: str, base_url: str, id: str, status: str, scylla_version: str = None, extra_headers: dict | None = None): - client = ArgusGenericClient(auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - status = TestStatus(status) - client.finalize_generic_run(run_id=id, status=status, scylla_version=scylla_version) +def finish_run(api_key: str, base_url: str, use_tunnel: bool | None, id: str, status: str, + scylla_version: str = None, extra_headers: dict | None = None): + with ArgusGenericClient( + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + status = TestStatus(status) + client.finalize_generic_run(run_id=id, status=status, scylla_version=scylla_version) @click.command("trigger-jobs") @click.option("--api-key", help="Argus API key for authorization", required=True, envvar='ARGUS_AUTH_TOKEN') @click.option("--base-url", default="https://argus.scylladb.com", help="Base URL for argus instance") +@click.option("--use-tunnel/--no-use-tunnel", default=None, help="Route API calls through SSH tunnel") @click.option("--version", help="Scylla version to filter plans by", default=None, required=False) @click.option("--plan-id", help="Specific plan id for filtering", default=None, required=False) @click.option("--release", help="Release name to filter plans by", default=None, required=False) @click.option("--job-info-file", required=True, help="JSON file with trigger information (see detailed docs)") @click.option("--extra-headers", default={}, type=click.UNPROCESSED, callback=validate_extra_headers, help="extra headers to pass to argus, should be in json format", envvar='ARGUS_EXTRA_HEADERS') -def trigger_jobs(api_key: str, base_url: str, job_info_file: str, version: str, plan_id: str, release: str, extra_headers: dict | None = None): - client = ArgusGenericClient(auth_token=api_key, base_url=base_url, extra_headers=extra_headers) - path = Path(job_info_file) - if not path.exists(): +def trigger_jobs(api_key: str, base_url: str, use_tunnel: bool | None, job_info_file: str, version: str, + plan_id: str, release: str, extra_headers: dict | None = None): + if not os.path.exists(job_info_file): LOGGER.error("File not found: %s", job_info_file) exit(128) - payload = json.load(path.open("rt", encoding="utf-8")) - client.trigger_jobs({"release": release, "version": version, "plan_id": plan_id, **payload}) + with open(job_info_file, "rt", encoding="utf-8") as fh: + payload = json.load(fh) + with ArgusGenericClient( + auth_token=api_key, + base_url=base_url, + extra_headers=extra_headers, + use_tunnel=use_tunnel, + ) as client: + client.trigger_jobs({"release": release, "version": version, "plan_id": plan_id, **payload}) cli.add_command(submit_run) diff --git a/argus/client/generic/client.py b/argus/client/generic/client.py index 748a37c3..8733fb4a 100644 --- a/argus/client/generic/client.py +++ b/argus/client/generic/client.py @@ -13,9 +13,9 @@ class Routes(ArgusClient.Routes): TRIGGER_JOBS = "/planning/plan/trigger" def __init__(self, auth_token: str, base_url: str, api_version="v1", extra_headers: dict | None = None, - timeout: int = 180, max_retries: int = 3) -> None: + timeout: int = 180, max_retries: int = 3, use_tunnel: bool | None = None) -> None: super().__init__(auth_token, base_url, api_version, extra_headers=extra_headers, - timeout=timeout, max_retries=max_retries) + timeout=timeout, max_retries=max_retries, use_tunnel=use_tunnel) def submit_generic_run(self, build_id: str, run_id: str, started_by: str, build_url: str, sub_type: str = None, scylla_version: str | None = None): request_body = { diff --git a/argus/client/sct/client.py b/argus/client/sct/client.py index dc29f213..346ad804 100644 --- a/argus/client/sct/client.py +++ b/argus/client/sct/client.py @@ -37,9 +37,9 @@ class Routes(ArgusClient.Routes): SUBMIT_CONFIG = "/$id/config/submit" def __init__(self, run_id: UUID, auth_token: str, base_url: str, api_version="v1", extra_headers: dict | None = None, - timeout: int = 60, max_retries: int = 3) -> None: + timeout: int = 60, max_retries: int = 3, use_tunnel: bool | None = None) -> None: super().__init__(auth_token, base_url, api_version, extra_headers=extra_headers, - timeout=timeout, max_retries=max_retries) + timeout=timeout, max_retries=max_retries, use_tunnel=use_tunnel) self.run_id = run_id def submit_sct_run(self, job_name: str, job_url: str, started_by: str, commit_id: str, diff --git a/argus/client/session.py b/argus/client/session.py new file mode 100644 index 00000000..08714dbf --- /dev/null +++ b/argus/client/session.py @@ -0,0 +1,306 @@ +import atexit +import logging +import os +import threading +import time +import weakref +from datetime import UTC, datetime + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from argus.client.tunnel import ( + SSHTunnel, + TunnelConfig, + resolve_tunnel_config_with_reason, +) + +LOGGER = logging.getLogger(__name__) +TUNNEL_COOLDOWN_SECONDS = 30 +_DEFAULT_MONITOR_INTERVAL = 5.0 + + +def _resolve_use_tunnel(use_tunnel: bool | None) -> bool: + if use_tunnel is not None: + return use_tunnel + return os.environ.get("ARGUS_USE_TUNNEL", "").strip().lower() in ("1", "true", "yes", "on") + + +def _resolve_monitor_interval() -> float: + raw = os.environ.get("ARGUS_TUNNEL_MONITOR_INTERVAL") + if raw is None: + return _DEFAULT_MONITOR_INTERVAL + try: + value = float(raw) + if value <= 0: + raise ValueError("interval must be positive") + return value + except ValueError: + LOGGER.warning( + "Invalid ARGUS_TUNNEL_MONITOR_INTERVAL=%r, using default %.1fs", + raw, + _DEFAULT_MONITOR_INTERVAL, + ) + return _DEFAULT_MONITOR_INTERVAL + + +def _build_retry_adapter(max_retries: int) -> HTTPAdapter: + retry_strategy = Retry( + total=max_retries, + connect=max_retries, + read=max_retries, + status=0, + backoff_factor=1, + status_forcelist=(), + allowed_methods=["GET", "POST"], + ) + return HTTPAdapter(max_retries=retry_strategy) + + +def _build_retry_session(max_retries: int) -> requests.Session: + session = requests.Session() + adapter = _build_retry_adapter(max_retries) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +class TunneledSession(requests.Session): + """``requests.Session`` that transparently routes traffic through an SSH tunnel. + + All HTTP verbs work out of the box because we subclass ``requests.Session`` + and only override :meth:`request` to inject URL rewriting plus single-shot + reconnect-and-retry on connection errors. + """ + + def __init__(self, auth_token: str, original_base_url: str, max_retries: int = 3) -> None: + super().__init__() + adapter = _build_retry_adapter(max_retries) + self.mount("http://", adapter) + self.mount("https://", adapter) + + self._auth_token = auth_token + self._original_base_url = original_base_url + + self._tunnel: SSHTunnel | None = None + self._tunnel_config: TunnelConfig | None = None + self._tunnel_port: int | None = None + self._tunnel_established_at: str | None = None + self._tunnel_warning_emitted = False + self._tunnel_disabled_until = 0.0 + # RLock is held while reconnecting; this can stall a request thread for + # up to ~100s during the SSH retry budget. Acceptable: concurrent + # request callers should observe a single coherent tunnel state and not + # race multiple SSH spawns. + self._lock = threading.RLock() + + monitor_interval = _resolve_monitor_interval() + self._monitor_stop = threading.Event() + self._monitor_thread = threading.Thread( + target=self._monitor_loop, + args=(monitor_interval,), + name="argus-tunnel-monitor", + daemon=True, + ) + self._monitor_thread.start() + + # Ensure the monitor thread is stopped at interpreter exit even if a + # caller forgets to invoke ``close()``. The atexit registration uses a + # weak reference so the session can be garbage collected normally. + self._atexit_ref = weakref.ref(self) + atexit.register(self._atexit_close, self._atexit_ref) + + @staticmethod + def _atexit_close(session_ref: weakref.ref) -> None: + session = session_ref() + if session is not None: + try: + session.close() + except Exception: # noqa: BLE001 + LOGGER.debug("SSH tunnel atexit close failed", exc_info=True) + + def _active_tunnel_url(self) -> str | None: + if self._tunnel_port is not None: + return f"http://127.0.0.1:{self._tunnel_port}" + return None + + def _rewrite_url(self, url: str) -> str: + tunnel_url = self._active_tunnel_url() + if tunnel_url and url.startswith(self._original_base_url): + return tunnel_url + url[len(self._original_base_url):] + return url + + def _ensure_tunnel(self) -> None: + with self._lock: + if time.monotonic() < self._tunnel_disabled_until: + return + + if self._tunnel and self._tunnel.is_alive() and self._tunnel.local_port is not None: + self._tunnel_port = self._tunnel.local_port + return + + if self._tunnel and self._tunnel_config: + reconnected_port, reconnect_reason = self._tunnel.reconnect(self._tunnel_config) + if reconnected_port is not None: + self._tunnel_port = reconnected_port + self._tunnel_warning_emitted = False + return + if reconnect_reason: + LOGGER.warning("SSH tunnel reconnect failed: %s", reconnect_reason) + + force_refresh = self._tunnel is not None + # Use session=None (creates a plain requests.Session inside + # tunnel_api) to avoid infinite recursion: passing `self` would + # trigger _ensure_tunnel() again when the tunnel API call invokes + # session.post(). We forward session-level headers (e.g. + # Cloudflare Access tokens) via extra_headers instead. + extra_headers = dict(self.headers) if self.headers else None + config, config_reason = resolve_tunnel_config_with_reason( + auth_token=self._auth_token, + base_url=self._original_base_url, + force_refresh=force_refresh, + session=None, + extra_headers=extra_headers, + ) + if config is None: + self._backoff(config_reason or "failed to resolve tunnel configuration") + return + + tunnel = SSHTunnel() + local_port, establish_reason = tunnel.establish(config) + + if local_port is None and not force_refresh: + config, config_reason = resolve_tunnel_config_with_reason( + auth_token=self._auth_token, + base_url=self._original_base_url, + force_refresh=True, + session=None, + extra_headers=extra_headers, + ) + if config is not None: + local_port, establish_reason = tunnel.establish(config) + else: + establish_reason = config_reason + + if local_port is None: + self._backoff(establish_reason or "failed to establish tunnel") + return + + assert config is not None # guaranteed: local_port is set only when config is valid + self._tunnel = tunnel + self._tunnel_config = config + self._tunnel_port = local_port + self._tunnel_established_at = datetime.now(UTC).isoformat() + self._tunnel_warning_emitted = False + self._tunnel_disabled_until = 0.0 + LOGGER.info( + "SSH tunnel established: proxy=%s:%d, user=%s, key_id=%s, local_port=%d", + config.proxy_host, + config.proxy_port, + config.proxy_user, + config.key_id or "unknown", + local_port, + ) + + def _backoff(self, reason: str) -> None: + # We deliberately do NOT delete the cached keypair here. The keypair + # remains valid even when the proxy host is briefly unreachable, and + # regenerating it on every 30s cooldown would force an unnecessary + # re-registration round-trip. ``resolve_tunnel_config_with_reason`` + # already issues ``force_refresh=True`` after the first failure, which + # re-fetches the live config without dropping the keypair. + if not self._tunnel_warning_emitted: + LOGGER.warning( + "SSH tunnel unavailable (%s); falling back to direct connection: %s", + reason, + self._original_base_url, + ) + self._tunnel_warning_emitted = True + + if self._tunnel: + self._tunnel.shutdown() + + self._tunnel = None + self._tunnel_config = None + self._tunnel_port = None + self._tunnel_established_at = None + self._tunnel_disabled_until = time.monotonic() + TUNNEL_COOLDOWN_SECONDS + + def _monitor_loop(self, interval: float) -> None: + while not self._monitor_stop.wait(interval): + try: + self._check_tunnel_health() + except Exception: # noqa: BLE001 + LOGGER.debug("SSH tunnel monitor failed", exc_info=True) + + def _check_tunnel_health(self) -> None: + tunnel = self._tunnel + if tunnel is None or tunnel.is_alive(): + return + LOGGER.warning("SSH tunnel monitor detected dead tunnel; reconnecting") + self._ensure_tunnel() + + def _tunnel_headers(self) -> dict[str, str]: + """Return headers to attach when traffic flows through the SSH tunnel.""" + if self._tunnel_port is None or self._tunnel_config is None: + return {} + headers = { + "User-Agent": "argus-client-ssh-tunnel", + "X-SSH-Tunnel-Origin": self._tunnel_config.proxy_host, + "X-Tunnel-Established-At": self._tunnel_established_at or "", + } + if self._tunnel_config.key_id: + headers["X-Forwarded-Key-ID"] = self._tunnel_config.key_id + return headers + + def request(self, method: str, url: str, *args, **kwargs) -> requests.Response: + self._ensure_tunnel() + rewritten = self._rewrite_url(url) + if rewritten != url: + # Traffic is going through the tunnel — inject tunnel headers + headers = kwargs.get("headers") or {} + headers.update(self._tunnel_headers()) + kwargs["headers"] = headers + try: + return super().request(method, rewritten, *args, **kwargs) + except requests.ConnectionError: + if rewritten == url: + raise + LOGGER.warning("%s through SSH tunnel failed; reconnecting and retrying", method) + self._ensure_tunnel() + rewritten = self._rewrite_url(url) + if rewritten != url: + headers = kwargs.get("headers") or {} + headers.update(self._tunnel_headers()) + kwargs["headers"] = headers + try: + return super().request(method, rewritten, *args, **kwargs) + except requests.ConnectionError as exc: + self._backoff(f"request retry failed: {exc}") + return super().request(method, self._rewrite_url(url), *args, **kwargs) + + def close(self) -> None: + self._monitor_stop.set() + with self._lock: + if self._tunnel: + self._tunnel.shutdown() + self._tunnel = None + self._tunnel_config = None + self._tunnel_port = None + try: + atexit.unregister(self._atexit_close) + except Exception: # noqa: BLE001 + pass + super().close() + + +def create_session( + auth_token: str, + base_url: str, + use_tunnel: bool | None, + max_retries: int = 3, +) -> requests.Session: + if _resolve_use_tunnel(use_tunnel): + return TunneledSession(auth_token=auth_token, original_base_url=base_url, max_retries=max_retries) + return _build_retry_session(max_retries) diff --git a/argus/client/sirenada/client.py b/argus/client/sirenada/client.py index a2628ed9..88971c2f 100644 --- a/argus/client/sirenada/client.py +++ b/argus/client/sirenada/client.py @@ -46,10 +46,10 @@ class ArgusSirenadaClient(ArgusClient): } def __init__(self, auth_token: str, base_url: str, api_version="v1", extra_headers: dict | None = None, - timeout: int = 60, max_retries: int = 3) -> None: + timeout: int = 60, max_retries: int = 3, use_tunnel: bool | None = None) -> None: self.results_path: Path | None = None super().__init__(auth_token, base_url, api_version, extra_headers=extra_headers, - timeout=timeout, max_retries=max_retries) + timeout=timeout, max_retries=max_retries, use_tunnel=use_tunnel) def _verify_required_files_exist(self, results_path: Path): assert (results_path / self._junit_xml_filename).exists(), "Missing jUnit XML results file!" diff --git a/argus/client/tests/test_tunnel.py b/argus/client/tests/test_tunnel.py new file mode 100644 index 00000000..47101e57 --- /dev/null +++ b/argus/client/tests/test_tunnel.py @@ -0,0 +1,487 @@ +from datetime import UTC, datetime, timedelta +from io import StringIO +from unittest.mock import Mock +import os + +import pytest + +from argus.client.base import ArgusClient +from argus.client.tunnel import api as tunnel_api +from argus.client.tunnel import ssh as tunnel_ssh +from argus.client.tunnel import state as tunnel_state +from argus.client.session import TunneledSession +from argus.client.tunnel import TunnelConfig + + +def _write_text(path: str, text: str) -> None: + with open(path, "w", encoding="utf-8") as fh: + fh.write(text) + + +def _read_text(path: str) -> str: + with open(path, encoding="utf-8") as fh: + return fh.read() + + +class _DummyProcess: + def __init__(self, stderr_text: str = ""): + self._alive = True + self.stderr = StringIO(stderr_text) + + def poll(self): + return None if self._alive else 0 + + def terminate(self): + self._alive = False + + def wait(self, timeout=None): + return 0 + + def kill(self): + self._alive = False + + +@pytest.fixture +def tunnel_state_dir(tmp_path, monkeypatch): + monkeypatch.setenv("ARGUS_TUNNEL_STATE_DIR", str(tmp_path)) + return str(tmp_path) + + +def test_resolve_tunnel_config_registers_and_caches(tunnel_state_dir, monkeypatch): + expires_at = datetime.now(tz=UTC) + timedelta(hours=6) + config = TunnelConfig( + proxy_host="proxy.example.com", + proxy_port=22, + proxy_user="argus-proxy", + target_host="10.0.0.10", + target_port=8080, + host_key_fingerprint="SHA256:test", + expires_at=expires_at, + key_id="key-id", + tunnel_id="tunnel-id", + ) + + def _fake_generate(paths): + _write_text(paths.private_key, "private") + _write_text(paths.public_key, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAITestKey") + + monkeypatch.setattr(tunnel_api, "generate_keypair_if_needed", _fake_generate) + monkeypatch.setattr(tunnel_api, "_register_tunnel", lambda **kwargs: config) + + resolved = tunnel_api.resolve_tunnel_config(auth_token="token", base_url="https://argus.example.com") + assert resolved is not None + assert resolved.proxy_host == "proxy.example.com" + + paths = tunnel_state.get_tunnel_state_paths() + assert os.path.exists(paths.config_cache) + assert os.path.exists(paths.key_meta) + + monkeypatch.setattr( + tunnel_api, + "_get_tunnel_connection", + lambda **kwargs: (_ for _ in ()).throw(AssertionError("should not call GET when cache is valid")), + ) + cached = tunnel_api.resolve_tunnel_config(auth_token="token", base_url="https://argus.example.com") + assert cached is not None + assert cached.proxy_host == "proxy.example.com" + + +def test_tunnel_api_raises_on_connection_failure(): + mock_session = Mock(spec=tunnel_api.requests.Session) + mock_session.get.side_effect = tunnel_api.requests.RequestException("connection refused") + + with pytest.raises(tunnel_api.TunnelClientError, match="Tunnel API call failed"): + tunnel_api._call_tunnel_api( + method="GET", + url="https://argus.example.com/api/v1/client/ssh/tunnel", + auth_token="token", + payload=None, + session=mock_session, + ) + + +def test_tunnel_api_succeeds_with_valid_response(): + class _Response: + status_code = 200 + + @staticmethod + def json(): + return { + "status": "ok", + "response": { + "proxy_host": "proxy.example.com", + "proxy_port": 22, + "proxy_user": "argus-proxy", + "target_host": "10.0.0.10", + "target_port": 8080, + "host_key_fingerprint": "SHA256:test", + }, + } + + mock_session = Mock(spec=tunnel_api.requests.Session) + mock_session.get.return_value = _Response() + + data = tunnel_api._call_tunnel_api( + method="GET", + url="https://argus.example.com/api/v1/client/ssh/tunnel", + auth_token="token", + payload=None, + session=mock_session, + ) + assert data["proxy_host"] == "proxy.example.com" + + +def test_establish_uses_strict_host_options_and_temp_known_hosts(tunnel_state_dir, monkeypatch): + paths = tunnel_state.get_tunnel_state_paths() + _write_text(paths.private_key, "private") + + host_blob = "AQIDBA==" + expected_fingerprint = tunnel_ssh.derive_fingerprint(f"ssh-ed25519 {host_blob}") + config = TunnelConfig( + proxy_host="proxy.example.com", + proxy_port=22, + proxy_user="argus-proxy", + target_host="10.0.0.10", + target_port=8080, + host_key_fingerprint=expected_fingerprint, + ) + + monkeypatch.setattr(tunnel_ssh.shutil, "which", lambda cmd: f"/usr/bin/{cmd}") + monkeypatch.setattr( + tunnel_ssh, + "scan_host_keys", + lambda host, port: [f"{host} ssh-ed25519 {host_blob}"], + ) + + captured = {"commands": []} + + def _fake_popen(command, stdout, stderr, text): + captured["commands"].append(command) + return _DummyProcess() + + monkeypatch.setattr(tunnel_ssh.subprocess, "Popen", _fake_popen) + monkeypatch.setattr(tunnel_ssh.SSHTunnel, "_wait_for_port_ready", staticmethod(lambda process, local_port: (True, ""))) + + ssh_tunnel = tunnel_ssh.SSHTunnel(key_path=paths.private_key) + local_port, reason = ssh_tunnel.establish(config) + + assert reason is None + assert local_port is not None + assert ssh_tunnel.local_port == local_port + command = captured["commands"][0] + command_text = " ".join(command) + assert "StrictHostKeyChecking=yes" in command_text + assert "GlobalKnownHostsFile=/dev/null" in command_text + assert "HostKeyAlgorithms=ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521" in command_text + assert "ssh-rsa" not in command_text + + known_hosts_path = ssh_tunnel._known_hosts_path + assert known_hosts_path is not None + assert os.path.exists(known_hosts_path) + + ssh_tunnel.shutdown() + assert not os.path.exists(known_hosts_path) + + +def test_establish_retries_on_local_bind_conflict(tunnel_state_dir, monkeypatch): + paths = tunnel_state.get_tunnel_state_paths() + _write_text(paths.private_key, "private") + + config = TunnelConfig( + proxy_host="proxy.example.com", + proxy_port=22, + proxy_user="argus-proxy", + target_host="10.0.0.10", + target_port=8080, + host_key_fingerprint="SHA256:test", + ) + + monkeypatch.setattr(tunnel_ssh.shutil, "which", lambda cmd: f"/usr/bin/{cmd}") + monkeypatch.setattr( + tunnel_ssh.SSHTunnel, + "_prepare_known_hosts_file", + staticmethod(lambda cfg: tunnel_ssh.write_temp_known_hosts("proxy ssh-ed25519 AQIDBA==")), + ) + + call_state = {"calls": 0} + + def _fake_wait(process, local_port): + call_state["calls"] += 1 + if call_state["calls"] == 1: + return False, "Address already in use" + return True, "" + + monkeypatch.setattr(tunnel_ssh.SSHTunnel, "_wait_for_port_ready", staticmethod(_fake_wait)) + monkeypatch.setattr(tunnel_ssh.subprocess, "Popen", lambda *args, **kwargs: _DummyProcess()) + + ssh_tunnel = tunnel_ssh.SSHTunnel(key_path=paths.private_key) + local_port, reason = ssh_tunnel.establish(config) + + assert reason is None + assert local_port is not None + assert call_state["calls"] == 2 + + +def test_argus_client_warns_and_falls_back_when_tunnel_setup_fails(requests_mock, monkeypatch, caplog): + requests_mock.get( + "https://argus.scylladb.com/api/v1/client/testrun/test-type/test-id/get", + json={"status": "ok", "response": {}}, + status_code=200, + ) + + monkeypatch.setattr("argus.client.session.resolve_tunnel_config_with_reason", lambda **kwargs: (None, "api unreachable")) + + client = ArgusClient(auth_token="token", base_url="https://argus.scylladb.com", use_tunnel=True) + with caplog.at_level("WARNING"): + response = client.get( + endpoint=ArgusClient.Routes.GET, + location_params={"type": "test-type", "id": "test-id"}, + ) + + assert response.status_code == 200 + assert isinstance(client.session, TunneledSession) + assert "api unreachable" in caplog.text + assert "falling back to direct connection" in caplog.text + + +def test_argus_client_retries_tunnel_after_cooldown(requests_mock, monkeypatch): + requests_mock.get( + "https://argus.scylladb.com/api/v1/client/testrun/test-type/test-id/get", + json={"status": "ok", "response": {}}, + status_code=200, + ) + requests_mock.get( + "http://127.0.0.1:9191/api/v1/client/testrun/test-type/test-id/get", + json={"status": "ok", "response": {}}, + status_code=200, + ) + + config = TunnelConfig( + proxy_host="proxy.example.com", + proxy_port=22, + proxy_user="argus-proxy", + target_host="10.0.0.10", + target_port=8080, + host_key_fingerprint="SHA256:test", + ) + resolve_state = {"calls": 0} + + def _resolve(**kwargs): + resolve_state["calls"] += 1 + if resolve_state["calls"] == 1: + return None, "first failure" + return config, None + + class _FakeTunnel: + def __init__(self): + self.local_port = 9191 + + def establish(self, cfg): + return 9191, None + + def is_alive(self): + return True + + def reconnect(self, cfg): + return 9191, None + + def shutdown(self): + return None + + monotonic_values = iter([1000.0, 1001.0, 1032.0]) + + monkeypatch.setattr("argus.client.session.resolve_tunnel_config_with_reason", _resolve) + monkeypatch.setattr("argus.client.session.SSHTunnel", _FakeTunnel) + monkeypatch.setattr("argus.client.session.time.monotonic", lambda: next(monotonic_values)) + + client = ArgusClient(auth_token="token", base_url="https://argus.scylladb.com", use_tunnel=True) + + client.get(endpoint=ArgusClient.Routes.GET, location_params={"type": "test-type", "id": "test-id"}) + assert client.session._tunnel_port is None + assert client._base_url == "https://argus.scylladb.com" + + client.get(endpoint=ArgusClient.Routes.GET, location_params={"type": "test-type", "id": "test-id"}) + assert client.session._tunnel_port == 9191 + + +def test_request_level_recovery_reconnects_and_retries_once(requests_mock, monkeypatch): + old_tunnel_url = "http://127.0.0.1:9191/api/v1/client/testrun/test-type/test-id/get" + new_tunnel_url = "http://127.0.0.1:9292/api/v1/client/testrun/test-type/test-id/get" + + requests_mock.get(old_tunnel_url, exc=tunnel_api.requests.ConnectionError("old tunnel is down")) + requests_mock.get(new_tunnel_url, json={"status": "ok", "response": {}}, status_code=200) + + client = ArgusClient(auth_token="token", base_url="https://argus.scylladb.com", use_tunnel=True) + client.session._tunnel_port = 9191 + + ensure_state = {"calls": 0} + + def _fake_ensure_tunnel(): + ensure_state["calls"] += 1 + if ensure_state["calls"] >= 2: + client.session._tunnel_port = 9292 + + monkeypatch.setattr(client.session, "_ensure_tunnel", _fake_ensure_tunnel) + + response = client.get( + endpoint=ArgusClient.Routes.GET, + location_params={"type": "test-type", "id": "test-id"}, + ) + + assert response.status_code == 200 + assert ensure_state["calls"] == 2 + assert client.session._tunnel_port == 9292 + + +def test_request_level_recovery_falls_back_to_direct_when_retry_fails(requests_mock, monkeypatch): + direct_url = "https://argus.scylladb.com/api/v1/client/testrun/test-type/test-id/get" + tunnel_url = "http://127.0.0.1:9191/api/v1/client/testrun/test-type/test-id/get" + + requests_mock.get(tunnel_url, exc=tunnel_api.requests.ConnectionError("tunnel is dead")) + requests_mock.get(direct_url, json={"status": "ok", "response": {}}, status_code=200) + + client = ArgusClient(auth_token="token", base_url="https://argus.scylladb.com", use_tunnel=True) + client.session._tunnel_port = 9191 + + def _ensure_keeps_tunnel(): + client.session._tunnel_port = 9191 + + backoff_calls = {"count": 0} + + def _fake_backoff(reason): + backoff_calls["count"] += 1 + client.session._tunnel_port = None + + monkeypatch.setattr(client.session, "_ensure_tunnel", _ensure_keeps_tunnel) + monkeypatch.setattr(client.session, "_backoff", _fake_backoff) + + response = client.get( + endpoint=ArgusClient.Routes.GET, + location_params={"type": "test-type", "id": "test-id"}, + ) + + assert response.status_code == 200 + assert backoff_calls["count"] == 1 + assert client.session._tunnel_port is None + + +def test_tunneled_session_starts_and_stops_monitor_thread(): + session = TunneledSession(auth_token="token", original_base_url="https://argus.scylladb.com") + try: + assert session._monitor_thread.is_alive() + finally: + session.close() + + session._monitor_thread.join(timeout=5) + assert not session._monitor_thread.is_alive() + + +def test_tunneled_session_close_unregisters_atexit(): + import atexit as _atexit + + session = TunneledSession(auth_token="token", original_base_url="https://argus.scylladb.com") + callback = session._atexit_close + ref = session._atexit_ref + session.close() + # After close(), invoking the atexit callback must be a no-op (the session + # was unregistered, and even if called manually it should not blow up). + callback(ref) + + +def test_argus_client_works_as_context_manager(requests_mock, monkeypatch): + requests_mock.get( + "https://argus.scylladb.com/api/v1/client/testrun/test-type/test-id/get", + json={"status": "ok", "response": {}}, + status_code=200, + ) + monkeypatch.setattr( + "argus.client.session.resolve_tunnel_config_with_reason", + lambda **kwargs: (None, "api unreachable"), + ) + + with ArgusClient(auth_token="token", base_url="https://argus.scylladb.com", use_tunnel=True) as client: + client.get(endpoint=ArgusClient.Routes.GET, location_params={"type": "test-type", "id": "test-id"}) + session = client.session + assert session._monitor_thread.is_alive() + + session._monitor_thread.join(timeout=5) + assert not session._monitor_thread.is_alive() + + +def test_backoff_does_not_wipe_cached_tunnel_state(tunnel_state_dir, monkeypatch): + paths = tunnel_state.get_tunnel_state_paths() + _write_text(paths.private_key, "private-key") + _write_text(paths.public_key, "public-key") + _write_text(paths.config_cache, '{"placeholder": "value"}') + + monkeypatch.setattr( + "argus.client.session.resolve_tunnel_config_with_reason", + lambda **kwargs: (None, "transient failure"), + ) + + session = TunneledSession(auth_token="token", original_base_url="https://argus.scylladb.com") + try: + session._ensure_tunnel() + # Transient establish failure must NOT wipe the cached keypair — + # otherwise every cooldown forces a fresh registration round-trip. + assert os.path.exists(paths.private_key) + assert os.path.exists(paths.public_key) + assert os.path.exists(paths.config_cache) + finally: + session.close() + + +def test_call_tunnel_api_rejects_non_dict_payload(): + class _Response: + status_code = 200 + + @staticmethod + def json(): + return ["unexpected", "list"] + + mock_session = Mock(spec=tunnel_api.requests.Session) + mock_session.get.return_value = _Response() + + with pytest.raises(tunnel_api.TunnelClientError, match="invalid format"): + tunnel_api._call_tunnel_api( + method="GET", + url="https://argus.example.com/api/v1/client/ssh/tunnel", + auth_token="token", + payload=None, + session=mock_session, + ) + + +def test_prepare_known_hosts_file_accepts_full_known_hosts_entry(tunnel_state_dir): + config = TunnelConfig( + proxy_host="proxy.example.com", + proxy_port=2222, + proxy_user="argus-proxy", + target_host="10.0.0.10", + target_port=8080, + host_key_fingerprint="some-other-name ssh-ed25519 AAAAdummybase64", + ) + + path = tunnel_ssh.SSHTunnel._prepare_known_hosts_file(config) + try: + contents = _read_text(path).strip() + # Host token must be rewritten to the connection target with the + # non-default port, not whatever the backend stored. + assert contents.startswith("[proxy.example.com]:2222 ") + assert "ssh-ed25519 AAAAdummybase64" in contents + finally: + tunnel_ssh._unlink(path) + + +def test_prepare_known_hosts_file_rejects_unknown_format(tunnel_state_dir): + config = TunnelConfig( + proxy_host="proxy.example.com", + proxy_port=22, + proxy_user="argus-proxy", + target_host="10.0.0.10", + target_port=8080, + host_key_fingerprint="not-a-fingerprint", + ) + + with pytest.raises(tunnel_ssh.TunnelClientError, match="unrecognised format"): + tunnel_ssh.SSHTunnel._prepare_known_hosts_file(config) diff --git a/argus/client/tunnel/__init__.py b/argus/client/tunnel/__init__.py new file mode 100644 index 00000000..8900ca0c --- /dev/null +++ b/argus/client/tunnel/__init__.py @@ -0,0 +1,13 @@ +from .api import resolve_tunnel_config, resolve_tunnel_config_with_reason +from .models import TunnelClientError, TunnelConfig +from .ssh import SSHTunnel +from .state import delete_cached_tunnel_state + +__all__ = [ + "SSHTunnel", + "TunnelClientError", + "TunnelConfig", + "delete_cached_tunnel_state", + "resolve_tunnel_config", + "resolve_tunnel_config_with_reason", +] diff --git a/argus/client/tunnel/api.py b/argus/client/tunnel/api.py new file mode 100644 index 00000000..9a2b49c3 --- /dev/null +++ b/argus/client/tunnel/api.py @@ -0,0 +1,215 @@ +import logging +from typing import Any + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +from .models import DEFAULT_TUNNEL_TIMEOUT, TunnelClientError, TunnelConfig +from .state import ( + generate_keypair_if_needed, + get_tunnel_state_paths, + is_key_valid, + read_cached_tunnel_config, + write_key_meta, + write_tunnel_cache, +) + + +LOGGER = logging.getLogger(__name__) +TUNNEL_API_RETRIES = 3 + + +def _create_api_session() -> requests.Session: + session = requests.Session() + retry = Retry( + total=TUNNEL_API_RETRIES, + backoff_factor=0.5, + allowed_methods=["GET", "POST"], + status_forcelist=[500, 502, 503, 504], + ) + adapter = HTTPAdapter(max_retries=retry) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +def resolve_tunnel_config( + auth_token: str, + base_url: str, + force_refresh: bool = False, + ttl_seconds: int | None = None, + session: requests.Session | None = None, + extra_headers: dict[str, str] | None = None, +) -> TunnelConfig | None: + config, _reason = resolve_tunnel_config_with_reason( + auth_token=auth_token, + base_url=base_url, + force_refresh=force_refresh, + ttl_seconds=ttl_seconds, + session=session, + extra_headers=extra_headers, + ) + return config + + +def resolve_tunnel_config_with_reason( + auth_token: str, + base_url: str, + force_refresh: bool = False, + ttl_seconds: int | None = None, + session: requests.Session | None = None, + extra_headers: dict[str, str] | None = None, +) -> tuple[TunnelConfig | None, str | None]: + """ + Resolve tunnel configuration while keeping Cloudflare bootstrap calls minimal. + + Order: + 1. Use cached config when key/cache are still valid and refresh is not forced. + 2. Use GET /client/ssh/tunnel when key exists and remains valid. + 3. Register/re-register via POST /client/ssh/tunnel. + """ + paths = get_tunnel_state_paths() + + if not force_refresh: + cached = read_cached_tunnel_config(paths) + if cached is not None and is_key_valid(paths): + return cached, None + + if is_key_valid(paths): + try: + config = _get_tunnel_connection(auth_token=auth_token, base_url=base_url, session=session, + extra_headers=extra_headers) + write_tunnel_cache(paths, config) + return config, None + except TunnelClientError as exc: + LOGGER.warning("Unable to refresh tunnel connection details via API: %s", exc) + + try: + generate_keypair_if_needed(paths) + with open(paths.public_key, encoding="utf-8") as fh: + public_key = fh.read().strip() + config = _register_tunnel( + auth_token=auth_token, + base_url=base_url, + public_key=public_key, + ttl_seconds=ttl_seconds, + session=session, + extra_headers=extra_headers, + ) + write_key_meta(paths, config.expires_at) + write_tunnel_cache(paths, config) + return config, None + except (OSError, TunnelClientError) as exc: + LOGGER.warning("Unable to resolve SSH tunnel configuration: %s", exc) + return None, str(exc) + + +def _register_tunnel( + auth_token: str, + base_url: str, + public_key: str, + ttl_seconds: int | None = None, + session: requests.Session | None = None, + extra_headers: dict[str, str] | None = None, +) -> TunnelConfig: + payload: dict[str, Any] = {"public_key": public_key} + if ttl_seconds is not None: + payload["ttl_seconds"] = ttl_seconds + response = _call_tunnel_api( + method="POST", + url=f"{base_url}/api/v1/client/ssh/tunnel", + auth_token=auth_token, + payload=payload, + session=session, + extra_headers=extra_headers, + ) + return TunnelConfig.from_api_response(response) + + +def _get_tunnel_connection( + auth_token: str, + base_url: str, + session: requests.Session | None = None, + extra_headers: dict[str, str] | None = None, +) -> TunnelConfig: + response = _call_tunnel_api( + method="GET", + url=f"{base_url}/api/v1/client/ssh/tunnel", + auth_token=auth_token, + payload=None, + session=session, + extra_headers=extra_headers, + ) + return TunnelConfig.from_api_response(response) + + +def _call_tunnel_api( + method: str, + url: str, + auth_token: str, + payload: dict[str, Any] | None, + session: requests.Session | None = None, + extra_headers: dict[str, str] | None = None, +) -> dict[str, Any]: + headers = { + "Authorization": f"token {auth_token}", + "Accept": "application/json", + "Content-Type": "application/json", + } + if extra_headers: + headers.update(extra_headers) + + own_session = session is None + if own_session: + session = _create_api_session() + + try: + try: + if method == "POST": + response = session.post(url=url, json=payload, headers=headers, timeout=DEFAULT_TUNNEL_TIMEOUT) + elif method == "GET": + response = session.get(url=url, headers=headers, timeout=DEFAULT_TUNNEL_TIMEOUT) + else: + raise TunnelClientError(f"Unsupported tunnel API method: {method}") + except requests.RequestException as exc: + raise TunnelClientError(f"Tunnel API call failed ({method} {url}): {exc}") from exc + + if response.status_code != 200: + raise TunnelClientError( + f"Tunnel API call returned unexpected status code {response.status_code} ({method} {url})" + ) + + try: + response_payload = response.json() + except ValueError as exc: + content_type = response.headers.get("Content-Type", "") + body_preview = response.text[:500] if response.text else "" + LOGGER.debug( + "Tunnel API JSON parse failure: Content-Type=%s, body=%s", + content_type, + body_preview, + ) + raise TunnelClientError( + f"Tunnel API response is not JSON ({method} {url}): " + f"Content-Type={content_type}, body_preview={body_preview!r}" + ) from exc + + if not isinstance(response_payload, dict): + raise TunnelClientError( + f"Tunnel API response payload has invalid format ({method} {url})" + ) + + if response_payload.get("status") != "ok": + response_error = response_payload.get("response") + message = response_error.get("message") if isinstance(response_error, dict) else response_error + raise TunnelClientError(f"Tunnel API returned error: {message}") + + response_data = response_payload.get("response") + if not isinstance(response_data, dict): + raise TunnelClientError("Tunnel API response payload has invalid format") + + return response_data + finally: + if own_session: + session.close() diff --git a/argus/client/tunnel/models.py b/argus/client/tunnel/models.py new file mode 100644 index 00000000..4ae1b1a1 --- /dev/null +++ b/argus/client/tunnel/models.py @@ -0,0 +1,116 @@ +from dataclasses import asdict, dataclass +from datetime import UTC, datetime +from typing import Any, NotRequired, TypedDict + + +class TunnelClientError(Exception): + pass + + +DEFAULT_TUNNEL_TIMEOUT = 10 +DEFAULT_RECONNECT_RETRIES = 3 +MAX_PORT_BIND_ATTEMPTS = 10 +ALLOWED_HOST_KEY_TYPES = ( + "ssh-ed25519", + "ecdsa-sha2-nistp256", + "ecdsa-sha2-nistp384", + "ecdsa-sha2-nistp521", +) + + +class _TunnelApiResponse(TypedDict): + """Live response shape from ``/client/ssh/tunnel`` (POST register / GET fetch).""" + proxy_host: str + proxy_port: int + proxy_user: str + target_host: str + target_port: int + host_key_fingerprint: str + expires_at: NotRequired[str | None] + key_id: NotRequired[str | None] + tunnel_id: NotRequired[str | None] + + +class _TunnelCachePayload(TypedDict): + """On-disk cache shape written by :meth:`TunnelConfig.to_cache_payload`. + + Mirrors :class:`_TunnelApiResponse` but is independently typed so future + cache-only fields don't leak into the API contract. + """ + proxy_host: str + proxy_port: int + proxy_user: str + target_host: str + target_port: int + host_key_fingerprint: str + expires_at: NotRequired[str | None] + key_id: NotRequired[str | None] + tunnel_id: NotRequired[str | None] + + +# Required keys listed explicitly to avoid relying on TypedDict.__required_keys__ +# runtime behaviour, which changed in Python 3.14. +_TUNNEL_API_REQUIRED_KEYS: tuple[str, ...] = ( + "proxy_host", + "proxy_port", + "proxy_user", + "target_host", + "target_port", + "host_key_fingerprint", +) + + +@dataclass(frozen=True, slots=True) +class TunnelConfig: + proxy_host: str + proxy_port: int + proxy_user: str + target_host: str + target_port: int + host_key_fingerprint: str + expires_at: datetime | None = None + key_id: str | None = None + tunnel_id: str | None = None + + @classmethod + def from_api_response(cls, response: "_TunnelApiResponse | _TunnelCachePayload") -> "TunnelConfig": + missing = [k for k in _TUNNEL_API_REQUIRED_KEYS if not response.get(k)] + if missing: + raise TunnelClientError(f"Missing required tunnel response fields: {', '.join(missing)}") + + expires_at = response.get("expires_at") + return cls( + proxy_host=str(response["proxy_host"]), + proxy_port=int(response["proxy_port"]), + proxy_user=str(response["proxy_user"]), + target_host=str(response["target_host"]), + target_port=int(response["target_port"]), + host_key_fingerprint=str(response["host_key_fingerprint"]), + expires_at=parse_datetime(expires_at) if expires_at else None, + key_id=str(response["key_id"]) if response.get("key_id") else None, + tunnel_id=str(response["tunnel_id"]) if response.get("tunnel_id") else None, + ) + + def to_cache_payload(self) -> dict[str, Any]: + payload = asdict(self) + if self.expires_at is not None: + payload["expires_at"] = self.expires_at.astimezone(UTC).isoformat() + return payload + + +@dataclass(frozen=True, slots=True) +class TunnelStatePaths: + state_dir: str + private_key: str + public_key: str + key_meta: str + config_cache: str + + +def parse_datetime(value: str) -> datetime: + if not value: + raise ValueError("datetime value is required") + parsed = datetime.fromisoformat(value) + if parsed.tzinfo is None: + return parsed.replace(tzinfo=UTC) + return parsed.astimezone(UTC) diff --git a/argus/client/tunnel/ssh.py b/argus/client/tunnel/ssh.py new file mode 100644 index 00000000..cfeb8a2b --- /dev/null +++ b/argus/client/tunnel/ssh.py @@ -0,0 +1,367 @@ +import atexit +import base64 +import hashlib +import logging +import os +import shutil +import socket +import subprocess +import tempfile +import time + +from .models import ( + ALLOWED_HOST_KEY_TYPES, + DEFAULT_RECONNECT_RETRIES, + DEFAULT_TUNNEL_TIMEOUT, + MAX_PORT_BIND_ATTEMPTS, + TunnelClientError, + TunnelConfig, +) +from .state import get_tunnel_state_paths + + +LOGGER = logging.getLogger(__name__) + + +class SSHTunnel: + def __init__(self, key_path: str | None = None) -> None: + state_paths = get_tunnel_state_paths() + self._key_path = key_path or state_paths.private_key + self._process: subprocess.Popen[str] | None = None + self._local_port: int | None = None + self._known_hosts_path: str | None = None + self._atexit_registered = False + + @property + def local_port(self) -> int | None: + return self._local_port + + def establish(self, config: TunnelConfig) -> tuple[int | None, str | None]: + ssh_bin = shutil.which("ssh") + if ssh_bin is None: + reason = "ssh binary was not found on PATH" + LOGGER.warning(reason) + return None, reason + if shutil.which("ssh-keyscan") is None: + reason = "ssh-keyscan binary was not found on PATH" + LOGGER.warning(reason) + return None, reason + if not os.path.exists(self._key_path): + reason = f"SSH private key does not exist: {self._key_path}" + LOGGER.warning(reason) + return None, reason + + self.shutdown() + try: + known_hosts_path = self._prepare_known_hosts_file(config) + except TunnelClientError as exc: + reason = f"strict host verification failed: {exc}" + LOGGER.warning("SSH tunnel %s", reason) + return None, reason + + for attempt in range(1, MAX_PORT_BIND_ATTEMPTS + 1): + reserve_socket, local_port = self._reserve_local_port() + command = self._build_ssh_command( + config=config, + local_port=local_port, + known_hosts_path=known_hosts_path, + ssh_bin=ssh_bin, + ) + + try: + reserve_socket.close() + process = subprocess.Popen( # noqa: S603 + command, + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + text=True, + ) + except OSError as exc: + reason = f"failed to spawn ssh process: {exc}" + LOGGER.warning(reason) + _unlink(known_hosts_path) + return None, reason + + ready, error_text = self._wait_for_port_ready(process=process, local_port=local_port) + if ready: + self._process = process + self._local_port = local_port + self._known_hosts_path = known_hosts_path + self._register_atexit() + return local_port, None + + self._terminate_process(process) + + if "Address already in use" in error_text: + LOGGER.warning( + "SSH tunnel local bind conflict on attempt %s/%s, retrying with a new local port", + attempt, + MAX_PORT_BIND_ATTEMPTS, + ) + continue + + reason = f"establish attempt {attempt} failed: {error_text or 'unknown error'}" + LOGGER.warning("SSH tunnel %s", reason) + _unlink(known_hosts_path) + return None, reason + + reason = f"establish failed after {MAX_PORT_BIND_ATTEMPTS} attempts" + LOGGER.warning("SSH tunnel %s", reason) + _unlink(known_hosts_path) + return None, reason + + def is_alive(self) -> bool: + if not self._process or self._local_port is None: + return False + if self._process.poll() is not None: + return False + return is_local_port_open(self._local_port) + + def reconnect(self, config: TunnelConfig) -> tuple[int | None, str | None]: + self.shutdown() + last_reason: str | None = None + for attempt in range(1, DEFAULT_RECONNECT_RETRIES + 1): + local_port, reason = self.establish(config) + if local_port is not None: + return local_port, None + last_reason = reason + time.sleep(2 ** (attempt - 1)) + return None, last_reason or "reconnect exhausted" + + def shutdown(self) -> None: + if self._process is not None: + self._terminate_process(self._process) + self._process = None + + self._local_port = None + + if self._known_hosts_path is not None: + _unlink(self._known_hosts_path) + self._known_hosts_path = None + + def _register_atexit(self) -> None: + if self._atexit_registered: + return + atexit.register(self.shutdown) + self._atexit_registered = True + + def _build_ssh_command(self, config: TunnelConfig, local_port: int, known_hosts_path: str, ssh_bin: str = "ssh") -> list[str]: + return [ + ssh_bin, + "-N", + "-L", + f"127.0.0.1:{local_port}:{config.target_host}:{config.target_port}", + "-i", + str(self._key_path), + "-p", + str(config.proxy_port), + f"{config.proxy_user}@{config.proxy_host}", + "-o", + "ExitOnForwardFailure=yes", + "-o", + "BatchMode=yes", + "-o", + "IdentitiesOnly=yes", + "-o", + f"UserKnownHostsFile={known_hosts_path}", + "-o", + "GlobalKnownHostsFile=/dev/null", + "-o", + "StrictHostKeyChecking=yes", + "-o", + "HostKeyAlgorithms=ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521", + "-o", + "PubkeyAcceptedAlgorithms=ssh-ed25519", + "-o", + f"ConnectTimeout={DEFAULT_TUNNEL_TIMEOUT}", + "-o", + "ServerAliveInterval=30", + "-o", + "ServerAliveCountMax=3", + "-o", + "TCPKeepAlive=yes", + "-o", + "ControlMaster=no", + "-o", + "LogLevel=ERROR", + ] + + @staticmethod + def _reserve_local_port() -> tuple[socket.socket, int]: + reserve_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + reserve_socket.bind(("127.0.0.1", 0)) + local_port = reserve_socket.getsockname()[1] + return reserve_socket, local_port + + @staticmethod + def _wait_for_port_ready(process: subprocess.Popen[str], local_port: int) -> tuple[bool, str]: + deadline = time.monotonic() + DEFAULT_TUNNEL_TIMEOUT + while time.monotonic() < deadline: + if process.poll() is not None: + return False, read_process_stderr(process) + if is_local_port_open(local_port): + return True, "" + time.sleep(0.1) + return False, "Timed out waiting for local SSH tunnel port to become reachable" + + @staticmethod + def _prepare_known_hosts_file(config: TunnelConfig) -> str: + """Build a temporary known_hosts file covering ``config.proxy_host``. + + Two input shapes are supported, both validated strictly: + + 1. **Full known_hosts entry** (``"host keytype keydata"``) — the format + returned once the backend stores entries directly. We rewrite the + leading host token to match the connection target (using + ``[host]:port`` for non-default ports). + 2. **SHA256 fingerprint** (``"SHA256:..."``) — current backend format; + we run ``ssh-keyscan`` and match by fingerprint. + + Anything else raises :class:`TunnelClientError` so strict host + verification stays enforced. + """ + raw = (config.host_key_fingerprint or "").strip() + if not raw: + raise TunnelClientError("host_key_fingerprint is empty; refusing to skip host verification") + + if _looks_like_known_hosts_entry(raw): + normalised = _normalise_known_hosts_entry(raw, config.proxy_host, config.proxy_port) + return write_temp_known_hosts(normalised) + + if raw.startswith("SHA256:"): + host_lines = scan_host_keys(config.proxy_host, config.proxy_port) + matched_line = match_known_host_line( + scanned_lines=host_lines, + expected_fingerprint=raw, + ) + return write_temp_known_hosts(matched_line) + + raise TunnelClientError( + f"host_key_fingerprint has unrecognised format: {raw[:32]!r}" + ) + + @staticmethod + def _terminate_process(process: subprocess.Popen[str]) -> None: + if process.poll() is not None: + return + process.terminate() + try: + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.kill() + process.wait(timeout=5) + + +def scan_host_keys(host: str, port: int) -> list[str]: + try: + result = subprocess.run( # noqa: S603 + ["ssh-keyscan", "-T", "5", "-p", str(port), "-t", "ed25519,ecdsa", host], + check=False, + capture_output=True, + text=True, + timeout=DEFAULT_TUNNEL_TIMEOUT, + ) + except FileNotFoundError as exc: + raise TunnelClientError("ssh-keyscan binary is required for host verification") from exc + except subprocess.TimeoutExpired as exc: + raise TunnelClientError(f"ssh-keyscan timed out for {host}:{port}") from exc + + lines = [line.strip() for line in result.stdout.splitlines() if line.strip() and not line.startswith("#")] + if lines: + return lines + + stderr = (result.stderr or "").strip() + raise TunnelClientError(f"Failed to scan SSH host key for {host}:{port}: {stderr or 'no host keys returned'}") + + +def match_known_host_line(scanned_lines: list[str], expected_fingerprint: str) -> str: + for line in scanned_lines: + parts = line.split() + if len(parts) < 3: + continue + key_type = parts[1] + key_blob = parts[2] + if key_type not in ALLOWED_HOST_KEY_TYPES: + continue + + derived = derive_fingerprint(f"{key_type} {key_blob}") + if derived == expected_fingerprint: + return line + + raise TunnelClientError( + "Host key fingerprint mismatch during strict verification " + f"(expected {expected_fingerprint}, accepted key types: ed25519/ecdsa only)" + ) + + +def derive_fingerprint(key_text: str) -> str: + parts = key_text.strip().split() + if len(parts) < 2: + raise TunnelClientError("Invalid host key text format") + + try: + key_blob = base64.b64decode(parts[1].encode("ascii"), validate=True) + except ValueError as exc: + raise TunnelClientError("Invalid base64 in scanned host key") from exc + + digest = hashlib.sha256(key_blob).digest() + b64 = base64.b64encode(digest).rstrip(b"=").decode("ascii") + return f"SHA256:{b64}" + + +def _looks_like_known_hosts_entry(raw: str) -> bool: + parts = raw.split() + return len(parts) >= 3 and parts[1] in ALLOWED_HOST_KEY_TYPES + + +def _normalise_known_hosts_entry(raw: str, host: str, port: int) -> str: + """Rewrite a known_hosts entry so the host token matches the connect address. + + SSH matches the entry by the literal host string used to connect — so a + backend-provided entry of ``"some-other-name keytype keydata"`` would be + ignored. We canonicalise the leading host token to match what + :func:`SSHTunnel._build_ssh_command` connects to (and use ``[host]:port`` + for non-default ports). + """ + parts = raw.split() + if len(parts) < 3: + raise TunnelClientError("known_hosts entry must contain host, key type, and key data") + key_type = parts[1] + key_blob = parts[2] + if key_type not in ALLOWED_HOST_KEY_TYPES: + raise TunnelClientError(f"unsupported known_hosts key type: {key_type}") + + host_token = host if port == 22 else f"[{host}]:{port}" + return f"{host_token} {key_type} {key_blob}" + + +def write_temp_known_hosts(known_host_line: str) -> str: + fd, temp_path = tempfile.mkstemp(prefix="argus-known-hosts-") + os.close(fd) + with open(temp_path, "w", encoding="utf-8") as fh: + fh.write(f"{known_host_line}\n") + os.chmod(temp_path, 0o600) + return temp_path + + +def _unlink(path: str) -> None: + """Remove a file; silently ignore if it does not exist.""" + try: + os.unlink(path) + except FileNotFoundError: + pass + + +def read_process_stderr(process: subprocess.Popen[str]) -> str: + if process.stderr is None: + return "" + try: + return process.stderr.read().strip() + except Exception: # noqa: BLE001 + return "" + + +def is_local_port_open(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe: + probe.settimeout(0.5) + return probe.connect_ex(("127.0.0.1", port)) == 0 diff --git a/argus/client/tunnel/state.py b/argus/client/tunnel/state.py new file mode 100644 index 00000000..9461af20 --- /dev/null +++ b/argus/client/tunnel/state.py @@ -0,0 +1,228 @@ +import json +import logging +import os +import shutil +import subprocess +import tempfile +from datetime import UTC, datetime + +try: + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, + PublicFormat, + ) +except ImportError: + Ed25519PrivateKey = None + Encoding = None + NoEncryption = None + PrivateFormat = None + PublicFormat = None + +from .models import TunnelClientError, TunnelConfig, TunnelStatePaths, parse_datetime + + +LOGGER = logging.getLogger(__name__) + + +def get_tunnel_state_paths() -> TunnelStatePaths: + state_dir = _resolve_state_dir() + return TunnelStatePaths( + state_dir=state_dir, + private_key=os.path.join(state_dir, "id_argus_proxy"), + public_key=os.path.join(state_dir, "id_argus_proxy.pub"), + key_meta=os.path.join(state_dir, "id_argus_proxy.meta.json"), + config_cache=os.path.join(state_dir, "tunnel_config.json"), + ) + + +def delete_cached_tunnel_state() -> None: + """Delete cached tunnel key/config state used by the client.""" + try: + paths = get_tunnel_state_paths() + except OSError: + return + + for file_path in (paths.private_key, paths.public_key, paths.key_meta, paths.config_cache): + try: + _unlink(file_path) + except OSError: + LOGGER.debug("Failed removing cached tunnel state file: %s", file_path, exc_info=True) + + +def generate_keypair_if_needed(paths: TunnelStatePaths) -> None: + if is_key_valid(paths): + return + + if _generate_keypair_with_cryptography(paths): + return + + if shutil.which("ssh-keygen") is None: + raise TunnelClientError("ssh-keygen binary is required to generate SSH keypair") + + _unlink(paths.private_key) + _unlink(paths.public_key) + + result = subprocess.run( # noqa: S603 + [ + "ssh-keygen", + "-q", + "-t", + "ed25519", + "-N", + "", + "-f", + paths.private_key, + "-C", + "argus-proxy", + ], + check=False, + capture_output=True, + text=True, + ) + if result.returncode != 0: + stderr = (result.stderr or "").strip() + raise TunnelClientError(f"ssh-keygen failed: {stderr or 'unknown error'}") + + os.chmod(paths.private_key, 0o600) + os.chmod(paths.public_key, 0o644) + + +def _generate_keypair_with_cryptography(paths: TunnelStatePaths) -> bool: + if Ed25519PrivateKey is None: + return False + + try: + private_key = Ed25519PrivateKey.generate() + public_key = private_key.public_key() + + private_bytes = private_key.private_bytes( + encoding=Encoding.PEM, + format=PrivateFormat.OpenSSH, + encryption_algorithm=NoEncryption(), + ) + public_bytes = public_key.public_bytes( + encoding=Encoding.OpenSSH, + format=PublicFormat.OpenSSH, + ) + + _write_bytes(paths.private_key, private_bytes) + _write_bytes(paths.public_key, public_bytes + b"\n") + os.chmod(paths.private_key, 0o600) + os.chmod(paths.public_key, 0o644) + return True + except Exception as exc: # noqa: BLE001 + LOGGER.warning("Falling back to ssh-keygen due to cryptography key generation failure: %s", exc) + return False + + +def is_key_valid(paths: TunnelStatePaths) -> bool: + if not os.path.exists(paths.private_key) or not os.path.exists(paths.public_key) or not os.path.exists(paths.key_meta): + return False + + try: + key_meta = json.loads(_read_text(paths.key_meta)) + expires_at = parse_datetime(key_meta.get("expires_at")) + except (OSError, json.JSONDecodeError, TypeError, ValueError): + return False + + now = datetime.now(tz=UTC) + return now < expires_at + + +def write_key_meta(paths: TunnelStatePaths, expires_at: datetime | None) -> None: + if expires_at is None: + return + payload = {"expires_at": expires_at.astimezone(UTC).isoformat()} + _write_text(paths.key_meta, json.dumps(payload)) + os.chmod(paths.key_meta, 0o600) + + +def read_cached_tunnel_config(paths: TunnelStatePaths) -> TunnelConfig | None: + if not os.path.exists(paths.config_cache): + return None + + try: + payload = json.loads(_read_text(paths.config_cache)) + except (OSError, json.JSONDecodeError): + return None + + try: + config = TunnelConfig.from_api_response(payload) + except TunnelClientError: + return None + + if config.expires_at is not None and datetime.now(tz=UTC) >= config.expires_at: + return None + + return config + + +def write_tunnel_cache(paths: TunnelStatePaths, config: TunnelConfig) -> None: + _write_text(paths.config_cache, json.dumps(config.to_cache_payload())) + os.chmod(paths.config_cache, 0o600) + + +def _resolve_state_dir() -> str: + candidates: list[str] = [] + + if configured_dir := os.environ.get("ARGUS_TUNNEL_STATE_DIR"): + candidates.append(os.path.expanduser(configured_dir)) + + candidates.append(os.path.join(os.path.expanduser("~"), ".ssh")) + + if runtime_dir := os.environ.get("XDG_RUNTIME_DIR"): + candidates.append(os.path.join(runtime_dir, "argus-tunnel")) + + candidates.append(os.path.join(tempfile.gettempdir(), f"argus-tunnel-{os.getuid()}")) + + for candidate in candidates: + if _prepare_state_dir(candidate): + return candidate + + raise OSError("No writable directory available for SSH tunnel state") + + +def _prepare_state_dir(path: str) -> bool: + try: + if os.path.exists(path): + if not os.path.isdir(path): + return False + if not os.access(path, os.W_OK | os.X_OK): + return False + return True + + os.makedirs(path, mode=0o700, exist_ok=True) + os.chmod(path, 0o700) + return True + except OSError: + return False + + +# --------------------------------------------------------------------------- +# Small helpers to replace pathlib method calls with plain os / builtins +# --------------------------------------------------------------------------- + +def _unlink(path: str) -> None: + """Remove a file; silently ignore if it does not exist.""" + try: + os.unlink(path) + except FileNotFoundError: + pass + + +def _read_text(path: str, encoding: str = "utf-8") -> str: + with open(path, encoding=encoding) as fh: + return fh.read() + + +def _write_text(path: str, text: str, encoding: str = "utf-8") -> None: + with open(path, "w", encoding=encoding) as fh: + fh.write(text) + + +def _write_bytes(path: str, data: bytes) -> None: + with open(path, "wb") as fh: + fh.write(data) diff --git a/argus_backend.py b/argus_backend.py index 3a46331d..44adb20c 100644 --- a/argus_backend.py +++ b/argus_backend.py @@ -19,6 +19,22 @@ LOGGER = logging.getLogger(__name__) +def _categorize_user_agent(ua: str) -> str: # noqa: PLR0911 + if not ua: + return "unknown" + if "argus-client-ssh-tunnel" in ua: + return "argus-client-tunnel" + if ua.startswith(("python-requests", "python-urllib")): + return "argus-client" + if ua.startswith("Go-http-client"): + return "argus-cli-go" + if "Mozilla" in ua: + return "browser" + if "curl" in ua: + return "curl" + return "other" + + def register_metrics(): METRICS.export_defaults(group_by="endpoint", prefix=NO_PREFIX) METRICS.register_default( @@ -32,6 +48,37 @@ def register_metrics(): }, ) ) + METRICS.register_default( + METRICS.counter( + "http_request_by_ip_total", + "Total requests by source IP", + labels={ + "ip": lambda: request.remote_addr, + "endpoint": lambda: request.endpoint, + }, + ) + ) + METRICS.register_default( + METRICS.counter( + "http_request_ssh_tunnel_total", + "Total requests by SSH tunnel presence", + labels={ + "ssh_tunnel": lambda: "yes" if request.headers.get("X-SSH-Tunnel-Origin") else "no", + "tunnel_established": lambda: "yes" if request.headers.get("X-Tunnel-Established-At") else "no", + "endpoint": lambda: request.endpoint, + }, + ) + ) + METRICS.register_default( + METRICS.counter( + "http_request_by_user_agent_total", + "Total requests by user agent category", + labels={ + "user_agent_category": lambda: _categorize_user_agent(request.headers.get("User-Agent", "")), + "endpoint": lambda: request.endpoint, + }, + ) + ) def start_server(config=None) -> Flask: diff --git a/cli/.goreleaser.yml b/cli/.goreleaser.yml index e06875b6..5df76c39 100644 --- a/cli/.goreleaser.yml +++ b/cli/.goreleaser.yml @@ -21,7 +21,7 @@ builds: # Inject version information at link time ldflags: - -s -w - - -X main.version={{ .Version }} + - -X main.version={{ .Env.VERSION }} - -X main.commit={{ .Commit }} - -X main.date={{ .Date }} env: @@ -39,7 +39,7 @@ builds: archives: - id: argus-archive name_template: >- - argus_{{ .Version }}_ + argus_{{ .Env.VERSION }}_ {{- if eq .Os "darwin"}}macOS{{- else}}{{ .Os }}{{- end}}_ {{- .Arch }} formats: @@ -60,7 +60,7 @@ archives: strip_parent: true checksum: - name_template: "argus_{{ .Version }}_checksums.txt" + name_template: "argus_{{ .Env.VERSION }}_checksums.txt" algorithm: sha256 release: diff --git a/cli/README.md b/cli/README.md new file mode 100644 index 00000000..ade4f715 --- /dev/null +++ b/cli/README.md @@ -0,0 +1,190 @@ +# Argus CLI + +Command-line interface for [Argus](https://argus.scylladb.com) — a test tracking system for automated pipelines. Use it to inspect test runs, fetch logs, stream activity, submit comments, and manage SSH tunnels into Argus clusters. + +--- + +## Installation + +Download the latest release from the [releases page](https://github.com/scylladb/argus/releases) and place the binary somewhere on your `$PATH`. + +For a one-liner install, see [AGENTS.md](../AGENTS.md). + +### From source + +Requires Go 1.25+. + +```bash +git clone https://github.com/scylladb/argus +cd argus/cli +go build -o argus . +``` + +--- + +## Authentication + +**This is where most people get stuck.** Read this section before running any command. + +### How it works + +The CLI needs two things to talk to Argus: + +1. A **Personal Access Token (PAT)** — a long-lived token stored in your system keychain (macOS Keychain, Windows Credential Manager, Linux Secret Service / `pass`). +2. A **Cloudflare Access credential** — required only when Argus is behind Cloudflare Access (the default for production at `argus.scylladb.com`). + +The first time you run `argus auth`, it fetches both automatically and stores them. Subsequent commands pull them from the keychain silently. If a credential expires, the CLI re-authenticates transparently and retries. + +--- + +### Cloudflared: what it is and when you need it + +`cloudflared` is Cloudflare's tunnel client. The CLI uses it to obtain a short-lived JWT that proves to Cloudflare Access that you are allowed through the firewall before your request ever reaches Argus. + +**You need cloudflared when:** +- Connecting to `argus.scylladb.com` or any other Argus instance protected by Cloudflare Access. +- Running `argus auth` (browser-based login). + +**You do NOT need cloudflared when:** +- Connecting to `localhost` or any loopback address — the CLI detects this automatically and skips all Cloudflare logic. +- Using a service-account (headless) setup with `CF-Access-Client-Id` / `CF-Access-Client-Secret` credentials. +- Running in CI/CD where you set `ARGUS_AUTH_TOKEN` or `ARGUS_TOKEN` directly. +- You explicitly disable it (see [Bypassing Cloudflare](#bypassing-cloudflare) below). + +**The CLI manages cloudflared for you.** It checks `$PATH`, then its own cache at `$XDG_CACHE_HOME/argus-cli/cloudflared`, and downloads the latest release from GitHub if neither is found. You do not need to install it manually. + +--- + +### Auth modes + +#### Mode 1 — Browser login (default, interactive) + +For humans authenticating against a production Argus behind Cloudflare Access. + +``` +argus auth +``` + +Flow: +1. Checks keychain — exits early if credentials are still valid. +2. Invokes `cloudflared access login` — opens a browser window for Cloudflare Access SSO. +3. Exchanges the resulting JWT for an Argus session. +4. Converts the session into a durable PAT and stores it in the keychain. + +You only need to do this once. After that, every command works without re-authentication. + +#### Mode 2 — Headless / service-account (servers and CI) + +**If you are running on a server or in CI, this is the only supported mode.** + +The keychain-based modes (browser login and `argus auth-token`) require a system keychain daemon — macOS Keychain, Windows Credential Manager, or Linux Secret Service / `pass`. Most servers and CI runners do not have one. Without it, any command that tries to read from the keychain fails silently and the CLI has no credentials to send. + +The solution is to skip the keychain entirely and supply credentials through environment variables: + +```bash +export ARGUS_CF_ACCESS_CLIENT_ID=your-client-id +export ARGUS_CF_ACCESS_CLIENT_SECRET=your-client-secret +export ARGUS_AUTH_TOKEN=your-pat +``` + +Set these in your CI secret store or server environment and every `argus` command will pick them up automatically — no `argus auth` step, no keychain, no browser. + +To get a service-account client ID and secret, ask your Cloudflare Access administrator. To get an Argus PAT, run `argus auth` once on a developer machine and copy the token out of the keychain, or have an admin generate one via the Argus web UI. + +On a machine that **does** have a keychain, `argus auth headless` stores all three interactively: + +``` +argus auth headless +``` + +Prompts (masked) for the CF Access Client ID, CF Access Client Secret, and Argus PAT, then writes them to the keychain and sets `use_cloudflare: false` in the config file. After that, the CLI sends the CF Access service-account headers on every request instead of invoking `cloudflared`. + +#### Mode 3 — Direct token (local / dev) + +For local Argus instances or any deployment without Cloudflare Access. + +``` +argus auth-token +``` + +Stores the PAT directly. Cloudflare is never consulted. This is equivalent to setting `ARGUS_AUTH_TOKEN` but persists the token to the keychain. + +--- + +### Bypassing Cloudflare + +Several mechanisms disable Cloudflare integration. Use whichever fits your workflow: + +| Mechanism | Scope | When to use | +|---|---|---| +| Loopback URL (`localhost`, `127.*`, `::1`) | automatic | Local dev — no action needed | +| `ARGUS_DISABLE_CLOUDFLARE=true` | env var, process-wide | CI/CD, scripts, one-off commands (also settable via `argus config set use_cloudflare false`) | +| `--disable-cloudflare` flag | single command | Ad-hoc overrides | +| `argus config set use_cloudflare false` | config file, persistent | When you always connect without CF | + +--- + +### Credential priority + +Credentials are layered — later sources override earlier ones, so environment variables always win over the keychain. + +**Cloudflared mode** (default, `use_cloudflare: true`): + +1. PAT from keychain +2. Session cookie from keychain — fallback when no PAT is stored +3. CF Access JWT from `cloudflared` — fetched alongside #1 or #2; required by the CF firewall +4. `ARGUS_AUTH_TOKEN` env var — overrides keychain PAT / session +5. `ARGUS_TOKEN` env var — fallback if `ARGUS_AUTH_TOKEN` is unset +6. `ARGUS_CF_ACCESS_CLIENT_ID` + `ARGUS_CF_ACCESS_CLIENT_SECRET` — overrides the cloudflared JWT + +**Headless mode** (`use_cloudflare: false`): + +1. PAT from keychain +2. CF service-account bundle from keychain — fallback when no PAT is stored (holds both CF headers and an Argus PAT, stored by `argus auth headless`) +3. `ARGUS_AUTH_TOKEN` env var — overrides keychain PAT +4. `ARGUS_TOKEN` env var — fallback if `ARGUS_AUTH_TOKEN` is unset +5. `ARGUS_CF_ACCESS_CLIENT_ID` + `ARGUS_CF_ACCESS_CLIENT_SECRET` — overrides CF headers from keychain + +--- + +### Logging out + +``` +argus auth logout +``` + +Removes all stored credentials (PAT, session, CF service-account bundle) from the system keychain. + +--- + +## Configuration + +Config file lives at `$XDG_CONFIG_HOME/argus-cli/config.yaml` (on Linux: `~/.config/argus-cli/config.yaml`). It is created with defaults on first run. + +```yaml +url: https://argus.scylladb.com +use_cloudflare: true +``` + +Manage it with: + +```bash +argus config list +argus config get url +argus config set url https://my-argus.internal +argus config set use_cloudflare false +``` + +--- + +## Storage locations + +| Purpose | Path | +|---|---| +| Config file | `$XDG_CONFIG_HOME/argus-cli/config.yaml` | +| Cached responses | `$XDG_CACHE_HOME/argus-cli/cache/` | +| Cloudflared binary | `$XDG_CACHE_HOME/argus-cli/cloudflared` | +| Logs | `$XDG_CACHE_HOME/argus-cli/logs/` | +| Credentials | System keychain | + +On Linux `$XDG_CONFIG_HOME` defaults to `~/.config` and `$XDG_CACHE_HOME` to `~/.cache`. diff --git a/cli/cmd/pytest.go b/cli/cmd/pytest.go new file mode 100644 index 00000000..6f8c8120 --- /dev/null +++ b/cli/cmd/pytest.go @@ -0,0 +1,147 @@ +package cmd + +import ( + "fmt" + "net/url" + "strings" + + "github.com/scylladb/argus/cli/internal/api" + "github.com/scylladb/argus/cli/internal/cache" + "github.com/scylladb/argus/cli/internal/logging" + "github.com/scylladb/argus/cli/internal/models" + "github.com/spf13/cobra" +) + +// --------------------------------------------------------------------------- +// Parent command: pytest +// --------------------------------------------------------------------------- + +var pytestCmd = &cobra.Command{ + Use: "pytest", + Short: "Commands for pytest result operations", + Long: `Query and inspect pytest results tracked by Argus.`, +} + +// --------------------------------------------------------------------------- +// Subcommand: pytest results +// --------------------------------------------------------------------------- + +var pytestResultsCmd = &cobra.Command{ + Use: "results", + Short: "Query filtered pytest results", + Long: `Fetch pytest results with optional filtering by test name, status, +time range, markers, user-defined fields, and free-text search. + +All flags are optional. Without any flags the endpoint returns the most +recent 500 results across all tests.`, + RunE: func(cmd *cobra.Command, _ []string) error { + cmd.SilenceUsage = true + ctx := cmd.Context() + client := APIClientFrom(ctx) + out := OutputterFrom(ctx) + c := CacheFrom(ctx) + log := logging.For(LoggerFrom(ctx), "pytest-results") + + // Read flags. + test, _ := cmd.Flags().GetString("test") + limit, _ := cmd.Flags().GetInt("limit") + before, _ := cmd.Flags().GetInt64("before") + after, _ := cmd.Flags().GetInt64("after") + statuses, _ := cmd.Flags().GetStringSlice("status") + query, _ := cmd.Flags().GetString("query") + filters, _ := cmd.Flags().GetStringSlice("filter") + markers, _ := cmd.Flags().GetStringSlice("marker") + + // Build query string. + params := url.Values{} + if test != "" { + params.Set("test", test) + } + if limit > 0 { + params.Set("limit", fmt.Sprint(limit)) + } + if before > 0 { + params.Set("before", fmt.Sprint(before)) + } + if after > 0 { + params.Set("after", fmt.Sprint(after)) + } + for _, s := range statuses { + params.Add("status[]", s) + } + if query != "" { + params.Set("query", query) + } + for _, f := range filters { + params.Add("filters[]", f) + } + for _, m := range markers { + params.Add("markers[]", m) + } + + qs := params.Encode() + route := api.PytestFilterResults + if qs != "" { + route += "?" + qs + } + + log.Debug().Str("route", route).Msg("fetching filtered pytest results") + + // Check cache. + cacheKey := cache.PytestFilterKey(qs) + if cached, _, err := cache.Get[models.PytestFilterResponse](c, cacheKey); isCacheable(err) { + log.Debug().Msg("pytest filter results served from cache") + return out.Write(cached) + } + + req, err := client.NewRequest(ctx, "GET", route, nil) + if err != nil { + log.Error().Err(err).Str("route", route).Msg("failed to build request") + return err + } + + result, err := api.DoJSON[models.PytestFilterResponse](client, req) + if err != nil { + log.Error().Err(err).Msg("failed to fetch filtered pytest results") + return err + } + + if cacheErr := cache.Set(c, cacheKey, result, route, cache.TTLPytestFilter); cacheErr != nil { + log.Warn().Err(cacheErr).Msg("failed to cache pytest filter results") + } + + log.Info().Int("total", result.Total).Int("hits", len(result.Hits)).Msg("pytest filter results fetched successfully") + return out.Write(result) + }, +} + +// validPytestStatuses returns a comma-separated list of valid pytest status +// values for use in help text. +func validPytestStatuses() string { + return strings.Join([]string{ + string(models.PytestStatusPassed), + string(models.PytestStatusFailure), + string(models.PytestStatusError), + string(models.PytestStatusSkipped), + string(models.PytestStatusXFailed), + string(models.PytestStatusXPass), + string(models.PytestStatusPassedError), + string(models.PytestStatusFailureError), + string(models.PytestStatusSkippedError), + string(models.PytestStatusErrorError), + }, ", ") +} + +func init() { + pytestResultsCmd.Flags().String("test", "", "Filter test names by substring") + pytestResultsCmd.Flags().Int("limit", 500, "Maximum number of results to return") + pytestResultsCmd.Flags().Int64("before", 0, "Only results before this Unix timestamp") + pytestResultsCmd.Flags().Int64("after", 0, "Only results after this Unix timestamp") + pytestResultsCmd.Flags().StringSlice("status", nil, "Filter by status (repeatable): "+validPytestStatuses()) + pytestResultsCmd.Flags().String("query", "", "Regex search pattern across result name/message/markers") + pytestResultsCmd.Flags().StringSlice("filter", nil, "User-field filter (repeatable, format: [!]field=value)") + pytestResultsCmd.Flags().StringSlice("marker", nil, "Pytest marker filter (repeatable)") + + pytestCmd.AddCommand(pytestResultsCmd) + rootCmd.AddCommand(pytestCmd) +} diff --git a/cli/cmd/root.go b/cli/cmd/root.go index 650c9f5c..3185632c 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -50,6 +50,7 @@ var ( cacheTTL string nonInteractive bool verbosity int + noColor bool ) func init() { @@ -63,6 +64,7 @@ func init() { rootCmd.PersistentFlags().StringVar(&cacheTTL, "cache-ttl", "", "override the default cache TTL (e.g. 10m, 1h); ignored when --no-cache is set") rootCmd.PersistentFlags().BoolVar(&nonInteractive, "non-interactive", false, "disable interactive prompts; return an error instead of triggering re-authentication") rootCmd.PersistentFlags().CountVarP(&verbosity, "verbose", "v", "increase log verbosity: -v/-vv mirrors info logs to stderr, -vvv mirrors debug logs to stdout") + rootCmd.PersistentFlags().BoolVar(&noColor, "no-color", false, "disable colored output; use bracket status indicators instead (e.g. (OK), (FAIL))") } var rootCmd = &cobra.Command{ diff --git a/cli/cmd/run_issue.go b/cli/cmd/run_issue.go new file mode 100644 index 00000000..a5555dcd --- /dev/null +++ b/cli/cmd/run_issue.go @@ -0,0 +1,165 @@ +package cmd + +import ( + "encoding/json" + + "github.com/scylladb/argus/cli/internal/api" + "github.com/scylladb/argus/cli/internal/logging" + "github.com/scylladb/argus/cli/internal/models" + "github.com/scylladb/argus/cli/internal/services" + "github.com/spf13/cobra" + + "fmt" +) + +// --------------------------------------------------------------------------- +// Subcommand: run issue +// --------------------------------------------------------------------------- + +var issueCmd = &cobra.Command{ + Use: "issue", + Short: "Commands for run issue operations", + Long: `Manage issues linked to test runs in Argus.`, +} + +// --------------------------------------------------------------------------- +// Subcommand: run issue add +// --------------------------------------------------------------------------- + +var issueAddCmd = &cobra.Command{ + Use: "add", + Short: "Submit an issue for a test run", + Long: `Link an issue (GitHub or Jira) to a test run. + +If --test-id is omitted it will be resolved automatically from the run.`, + RunE: func(cmd *cobra.Command, _ []string) error { + cmd.SilenceUsage = true + ctx := cmd.Context() + client := APIClientFrom(ctx) + out := OutputterFrom(ctx) + c := CacheFrom(ctx) + log := logging.For(LoggerFrom(ctx), "run-issue-add") + + runID, _ := cmd.Flags().GetString("run-id") + issueURL, _ := cmd.Flags().GetString("issue-url") + flagTestID, _ := cmd.Flags().GetString("test-id") + + log.Debug().Str("run_id", runID).Str("issue_url", issueURL).Str("test_id", flagTestID).Msg("submitting issue") + + fetcher := newRunFetcher() + testID, err := services.ResolveTestID(ctx, client, c, fetcher, runID, flagTestID) + if err != nil { + log.Error().Err(err).Str("run_id", runID).Msg("failed to resolve test ID") + return err + } + log.Debug().Str("run_id", runID).Str("test_id", testID).Msg("test ID resolved") + + route := fmt.Sprintf(api.TestRunIssueSubmit, testID, runID) + body := map[string]string{"issue_url": issueURL} + req, err := client.NewRequest(ctx, "POST", route, body) + if err != nil { + log.Error().Err(err).Str("run_id", runID).Str("route", route).Msg("failed to build request") + return err + } + + result, err := api.DoJSON[json.RawMessage](client, req) + if err != nil { + log.Error().Err(err).Str("run_id", runID).Str("issue_url", issueURL).Msg("failed to submit issue") + return err + } + + log.Info().Str("run_id", runID).Str("test_id", testID).Str("issue_url", issueURL).Msg("issue submitted successfully") + return out.Write(result) + }, +} + +// --------------------------------------------------------------------------- +// Subcommand: run issue list +// --------------------------------------------------------------------------- + +var issueListCmd = &cobra.Command{ + Use: "list", + Short: "List issues linked to a test run or other entity", + Long: `Fetch issues (GitHub and Jira) linked to an Argus entity. + +Exactly one filter flag must be provided: + --run-id, --release-id, --group-id, --test-id, --user-id, --view-id, --event-id`, + RunE: func(cmd *cobra.Command, _ []string) error { + cmd.SilenceUsage = true + ctx := cmd.Context() + client := APIClientFrom(ctx) + out := OutputterFrom(ctx) + log := logging.For(LoggerFrom(ctx), "issue-list") + + filters := []struct { + flag string + key string + }{ + {"run-id", "run_id"}, + {"release-id", "release_id"}, + {"group-id", "group_id"}, + {"test-id", "test_id"}, + {"user-id", "user_id"}, + {"view-id", "view_id"}, + {"event-id", "event_id"}, + } + + var filterKey, entityID string + for _, f := range filters { + if v, _ := cmd.Flags().GetString(f.flag); v != "" { + if filterKey != "" { + return fmt.Errorf("only one filter flag may be specified") + } + filterKey = f.key + entityID = v + } + } + if filterKey == "" { + return fmt.Errorf("one of --run-id, --release-id, --group-id, --test-id, --user-id, --view-id, or --event-id is required") + } + + log.Debug().Str("entity_id", entityID).Str("filter_key", filterKey).Msg("listing issues") + + route := fmt.Sprintf("%s?filterKey=%s&id=%s", api.IssuesGet, filterKey, entityID) + log.Debug().Str("route", route).Msg("fetching issues from API") + + req, err := client.NewRequest(ctx, "GET", route, nil) + if err != nil { + log.Error().Err(err).Str("route", route).Msg("failed to build request") + return err + } + + result, err := api.DoJSON[[]json.RawMessage](client, req) + if err != nil { + log.Error().Err(err).Str("entity_id", entityID).Msg("failed to fetch issues") + return err + } + + issues := models.ParseIssues(result) + log.Info().Str("entity_id", entityID).Int("count", len(issues)).Msg("issues fetched successfully") + return out.Write(models.NewTabularSlice(issues)) + }, +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +func init() { + issueAddCmd.Flags().String("run-id", "", "Run UUID (required)") + issueAddCmd.Flags().String("issue-url", "", "Issue URL to link (required)") + issueAddCmd.Flags().String("test-id", "", "Test UUID (optional, resolved from the run if omitted)") + _ = issueAddCmd.MarkFlagRequired("run-id") + _ = issueAddCmd.MarkFlagRequired("issue-url") + + issueListCmd.Flags().String("run-id", "", "Filter by run UUID") + issueListCmd.Flags().String("release-id", "", "Filter by release UUID") + issueListCmd.Flags().String("group-id", "", "Filter by group UUID") + issueListCmd.Flags().String("test-id", "", "Filter by test UUID") + issueListCmd.Flags().String("user-id", "", "Filter by user UUID") + issueListCmd.Flags().String("view-id", "", "Filter by view UUID") + issueListCmd.Flags().String("event-id", "", "Filter by event UUID") + + issueCmd.AddCommand(issueAddCmd, issueListCmd) + rootCmd.AddCommand(issueCmd) +} diff --git a/cli/cmd/testrun.go b/cli/cmd/testrun.go index 6f62d19e..f62dfcf9 100644 --- a/cli/cmd/testrun.go +++ b/cli/cmd/testrun.go @@ -515,12 +515,29 @@ var resultsCmd = &cobra.Command{ testID, _ := cmd.Flags().GetString("test-id") runID, _ := cmd.Flags().GetString("run-id") + + // Auto-resolve test-id from the run when not provided. + if testID == "" { + log.Debug().Str("run_id", runID).Msg("resolving test-id from run") + fetcher := newRunFetcher() + resolved, err := services.ResolveTestID(ctx, client, c, fetcher, runID, "") + if err != nil { + log.Error().Err(err).Str("run_id", runID).Msg("failed to resolve test-id") + return fmt.Errorf("resolving test-id: %w", err) + } + testID = resolved + log.Debug().Str("run_id", runID).Str("test_id", testID).Msg("test-id resolved from run") + } + log.Debug().Str("test_id", testID).Str("run_id", runID).Msg("fetching run results") cacheKey := cache.ResultsKey(testID, runID) if cached, _, err := cache.Get[models.FetchResultsResponse](c, cacheKey); isCacheable(err) { log.Debug().Str("run_id", runID).Msg("results served from cache") + showURLs, _ := cmd.Flags().GetBool("show-urls") + cached.ShowURLs = showURLs + cached.NoColor = noColor return out.Write(cached) } @@ -558,11 +575,14 @@ var resultsCmd = &cobra.Command{ return err } - result := models.FetchResultsResponse{Tables: envelope.Tables} + result := models.FetchResultsResponse{ResultTables: envelope.Tables} + showURLs, _ := cmd.Flags().GetBool("show-urls") + result.ShowURLs = showURLs + result.NoColor = noColor if cacheErr := cache.Set(c, cacheKey, result, route, cache.TTLResults); cacheErr != nil { log.Warn().Err(cacheErr).Str("run_id", runID).Msg("failed to cache results") } - log.Info().Str("test_id", testID).Str("run_id", runID).Int("table_count", len(result.Tables)).Msg("results fetched successfully") + log.Info().Str("test_id", testID).Str("run_id", runID).Int("table_count", len(result.ResultTables)).Msg("results fetched successfully") return out.Write(result) }, } @@ -742,9 +762,9 @@ func init() { _ = activityCmd.MarkFlagRequired("run-id") // run results - resultsCmd.Flags().String("test-id", "", "Test UUID (required)") + resultsCmd.Flags().String("test-id", "", "Test UUID (optional, auto-resolved from run-id if omitted)") resultsCmd.Flags().String("run-id", "", "Run UUID (required)") - _ = resultsCmd.MarkFlagRequired("test-id") + resultsCmd.Flags().Bool("show-urls", false, "Display full URLs in cell values instead of 'link'") _ = resultsCmd.MarkFlagRequired("run-id") // run comments diff --git a/cli/internal/api/routes.go b/cli/internal/api/routes.go index e6f04c17..54b38ec5 100644 --- a/cli/internal/api/routes.go +++ b/cli/internal/api/routes.go @@ -25,8 +25,13 @@ const ( TestRunCommentUpdate = "/api/v1/test/%s/run/%s/comment/%s/update" // POST – update a comment (test_id, run_id, comment_id) TestRunCommentDelete = "/api/v1/test/%s/run/%s/comment/%s/delete" // POST – delete a comment (test_id, run_id, comment_id) + // Issue routes + TestRunIssueSubmit = "/api/v1/test/%s/run/%s/issues/submit" // POST – submit an issue (test_id, run_id) + IssuesGet = "/api/v1/issues/get" // GET – list issues (filterKey, id query params) + // Pytest result routes - TestRunPytestResults = "/api/v1/run/%s/pytest/results" // GET – pytest results for a run (run_id) + TestRunPytestResults = "/api/v1/run/%s/pytest/results" // GET – pytest results for a run (run_id) + PytestFilterResults = "/api/v1/views/widgets/pytest/results" // GET – filtered pytest results (query params: test, limit, before, after, status[], query, filters[], markers[]) // Log file routes TestRunLogDownload = "/api/v1/tests/%s/%s/log/%s/download" // GET – download log file, 302 to S3 (plugin_name, run_id, log_name) @@ -44,11 +49,11 @@ const ( SSHTunnel = "/api/v1/client/ssh/tunnel" // POST – register public key and receive proxy config // SSH tunnel routes – admin (requires Admin role) - AdminProxyTunnelConfig = "/admin/api/v1/proxy-tunnel/config" // GET – one active config (tunnel_id query param optional); POST – create - AdminProxyTunnelConfigs = "/admin/api/v1/proxy-tunnel/configs" // GET – all configs (active_only query param optional) + AdminProxyTunnelConfig = "/admin/api/v1/proxy-tunnel/config" // GET – one active config (tunnel_id query param optional); POST – create + AdminProxyTunnelConfigs = "/admin/api/v1/proxy-tunnel/configs" // GET – all configs (active_only query param optional) AdminProxyTunnelSetActive = "/admin/api/v1/proxy-tunnel/config/%s/active" // POST – enable/disable a config (tunnel_id) AdminProxyTunnelDelete = "/admin/api/v1/proxy-tunnel/config/%s" // DELETE – permanently remove a config (tunnel_id) - AdminSSHKeys = "/admin/api/v1/ssh/keys" // GET – list all registered keys with metadata + AdminSSHKeys = "/admin/api/v1/ssh/keys" // GET – list all registered keys with metadata AdminSSHKeyDelete = "/admin/api/v1/ssh/keys/%s" // DELETE – revoke a key (key_id) ) diff --git a/cli/internal/cache/keys.go b/cli/internal/cache/keys.go index 82898ba0..547e8f44 100644 --- a/cli/internal/cache/keys.go +++ b/cli/internal/cache/keys.go @@ -1,6 +1,8 @@ package cache import ( + "crypto/sha256" + "fmt" "path" "time" ) @@ -29,6 +31,10 @@ const ( // TTLPytestResults is the TTL for pytest result lists. TTLPytestResults = 5 * time.Minute + // TTLPytestFilter is the TTL for filtered pytest results. + // Shorter than per-run results since filters are dynamic. + TTLPytestFilter = 2 * time.Minute + // TTLVersion is the TTL for the API version response. // The version only changes on a new server deployment. TTLVersion = time.Hour @@ -158,3 +164,13 @@ func NemesesFilteredKey(runID, before, after string) string { } return path.Join("nemeses", runID, before, after) } + +// PytestFilterKey returns the cache key for a filtered pytest results query. +// The queryString is hashed to produce a fixed-length directory name that +// uniquely identifies the parameter combination. +// +// On disk: cache/pytest-filter/{hash}/ +func PytestFilterKey(queryString string) string { + h := sha256.Sum256([]byte(queryString)) + return path.Join("pytest-filter", fmt.Sprintf("%x", h[:8])) +} diff --git a/cli/internal/models/issues.go b/cli/internal/models/issues.go new file mode 100644 index 00000000..57705121 --- /dev/null +++ b/cli/internal/models/issues.go @@ -0,0 +1,61 @@ +package models + +import "encoding/json" + +// Issue is a unified display model for both GitHub and Jira issues returned by +// the /issues/get endpoint. The Key field is synthesized: owner/repo#number for +// GitHub issues, or the Jira issue key. +type Issue struct { + Key string `json:"key"` + Subtype string `json:"subtype"` + Title string `json:"title"` + State string `json:"state"` + URL string `json:"url"` +} + +// ParseIssues converts raw JSON issue objects (which may be GitHub or Jira +// flavored) into a unified slice of Issue for tabular display. +func ParseIssues(raw []json.RawMessage) []Issue { + issues := make([]Issue, 0, len(raw)) + for _, r := range raw { + var m map[string]json.RawMessage + if err := json.Unmarshal(r, &m); err != nil { + continue + } + + var issue Issue + issue.Subtype = unquote(m["subtype"]) + issue.State = unquote(m["state"]) + + switch issue.Subtype { + case "github": + owner := unquote(m["owner"]) + repo := unquote(m["repo"]) + var number json.Number + _ = json.Unmarshal(m["number"], &number) + issue.Key = owner + "/" + repo + "#" + number.String() + issue.Title = unquote(m["title"]) + issue.URL = unquote(m["url"]) + case "jira": + issue.Key = unquote(m["key"]) + issue.Title = unquote(m["summary"]) + issue.URL = unquote(m["permalink"]) + default: + issue.Title = unquote(m["title"]) + issue.URL = unquote(m["url"]) + } + + issues = append(issues, issue) + } + return issues +} + +// unquote removes surrounding quotes from a raw JSON string value. +func unquote(raw json.RawMessage) string { + var s string + if raw == nil { + return "" + } + _ = json.Unmarshal(raw, &s) + return s +} diff --git a/cli/internal/models/runs.go b/cli/internal/models/runs.go index 28213e31..5253a3b1 100644 --- a/cli/internal/models/runs.go +++ b/cli/internal/models/runs.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" "time" + + "github.com/scylladb/argus/cli/internal/output" ) // --------------------------------------------------------------------------- @@ -491,12 +493,16 @@ func (a ActivityResponse) Rows() [][]string { type ResultCell struct { Value any `json:"value"` Status string `json:"status"` + Type string `json:"type,omitempty"` } // ResultColumnMeta describes one column in a ResultTable. type ResultColumnMeta struct { - Name string `json:"name"` - Status string `json:"status"` + Name string `json:"name"` + Unit string `json:"unit,omitempty"` + Type string `json:"type,omitempty"` + HigherIsBetter *bool `json:"higher_is_better,omitempty"` + Visible *bool `json:"visible,omitempty"` } // ResultTable is one performance/result table returned by the fetch_results @@ -505,26 +511,170 @@ type ResultTable struct { Description string `json:"description"` TableData map[string]map[string]ResultCell `json:"table_data"` Columns []ResultColumnMeta `json:"columns"` - Rows []string `json:"rows"` + RowNames []string `json:"rows"` TableStatus string `json:"table_status"` + ShowURLs bool `json:"-"` + NoColor bool `json:"-"` } // Headers implements output.Tabular for ResultTable. The first column is the // row name; subsequent columns come from the table's Columns metadata. +// When a column has a unit defined, it is appended in brackets (e.g. "latency [ms]"). func (rt ResultTable) Headers() []string { h := make([]string, 0, 1+len(rt.Columns)) h = append(h, "Row") for _, c := range rt.Columns { - h = append(h, c.Name) + if c.Unit != "" { + h = append(h, fmt.Sprintf("%s [%s]", c.Name, strings.ReplaceAll(c.Unit, " ", ""))) + } else { + h = append(h, c.Name) + } } return h } +// ANSI escape sequences for cell status highlighting. +const ( + ansiReset = "\033[0m" + ansiGreen = "\033[32m" + ansiRed = "\033[31m" + ansiYellow = "\033[33m" +) + +// colorByStatus wraps s with ANSI color codes matching the cell status, +// mirroring the frontend ResultCellStatusStyleMap. When noColor is true, +// a bracket status indicator is appended instead (e.g. "(OK)", "(FAIL)"). +func colorByStatus(s, status string, noColor bool) string { + if noColor { + switch status { + case "PASS": + return s + " (OK)" + case "ERROR": + return s + " (FAIL)" + case "WARNING": + return s + " (WARN)" + default: + return s + } + } + switch status { + case "PASS": + return ansiGreen + s + ansiReset + case "ERROR": + return ansiRed + s + ansiReset + case "WARNING": + return ansiYellow + s + ansiReset + default: + return s + } +} + +// formatCellValue formats a cell's value according to its type, mirroring the +// frontend Cell.svelte formatting logic. +// +// - FLOAT → 2 decimal places (e.g. "3.14") +// - INTEGER → locale-style thousands separators (e.g. "1,234,567") +// - DURATION → HH:MM:SS +// - URLs → "link" (since terminals can't click) +// - nil → "N/A" +// - default → fmt.Sprint +func formatCellValue(cell ResultCell, showURLs bool) string { + if cell.Value == nil { + return "N/A" + } + + // String values: detect URLs / images like the frontend does. + if s, ok := cell.Value.(string); ok { + if isURL(s) && !showURLs { + return "link" + } + return s + } + + // Numeric values: format based on the column type stored in the cell. + num, ok := toFloat64(cell.Value) + if !ok { + return fmt.Sprint(cell.Value) + } + + switch cell.Type { + case "FLOAT": + return strconv.FormatFloat(num, 'f', 2, 64) + case "INTEGER": + return formatInteger(int64(num)) + case "DURATION": + return formatDuration(num) + default: + return fmt.Sprint(cell.Value) + } +} + +// isURL checks whether s looks like an HTTP(S) URL. +func isURL(s string) bool { + return strings.HasPrefix(s, "http://") || strings.HasPrefix(s, "https://") +} + +// toFloat64 converts a JSON-decoded numeric value (float64 or json.Number) to float64. +func toFloat64(v any) (float64, bool) { + switch n := v.(type) { + case float64: + return n, true + case json.Number: + f, err := n.Float64() + return f, err == nil + case int: + return float64(n), true + case int64: + return float64(n), true + default: + return 0, false + } +} + +// formatInteger formats an integer with comma-separated thousands groups +// (e.g. 1234567 → "1,234,567"), matching Number.toLocaleString() in JS. +func formatInteger(n int64) string { + s := strconv.FormatInt(n, 10) + if n < 0 { + return "-" + insertCommas(s[1:]) + } + return insertCommas(s) +} + +func insertCommas(s string) string { + if len(s) <= 3 { + return s + } + remainder := len(s) % 3 + var b strings.Builder + if remainder > 0 { + b.WriteString(s[:remainder]) + b.WriteByte(',') + } + for i := remainder; i < len(s); i += 3 { + if i > remainder { + b.WriteByte(',') + } + b.WriteString(s[i : i+3]) + } + return b.String() +} + +// formatDuration converts seconds to HH:MM:SS, matching the frontend +// durationToStr helper. +func formatDuration(seconds float64) string { + total := int(seconds) + h := total / 3600 + m := (total % 3600) / 60 + s := total % 60 + return fmt.Sprintf("%02d:%02d:%02d", h, m, s) +} + // Rows implements output.Tabular for ResultTable. Row order follows -// rt.Rows; column order follows rt.Columns. -func (rt ResultTable) TableRows() [][]string { - out := make([][]string, 0, len(rt.Rows)) - for _, rowName := range rt.Rows { +// rt.RowNames; column order follows rt.Columns. +func (rt ResultTable) Rows() [][]string { + out := make([][]string, 0, len(rt.RowNames)) + for _, rowName := range rt.RowNames { row := make([]string, 0, 1+len(rt.Columns)) row = append(row, rowName) colData := rt.TableData[rowName] @@ -533,7 +683,7 @@ func (rt ResultTable) TableRows() [][]string { if !ok { row = append(row, "") } else { - row = append(row, fmt.Sprint(cell.Value)) + row = append(row, colorByStatus(formatCellValue(cell, rt.ShowURLs), cell.Status, rt.NoColor)) } } out = append(out, row) @@ -543,43 +693,96 @@ func (rt ResultTable) TableRows() [][]string { // FetchResultsEnvelope is the non-standard envelope returned by the // fetch_results endpoint. Unlike other endpoints the payload key is "tables" -// rather than "response". +// rather than "response". Each element is a single-key map from the table +// name to its data. type FetchResultsEnvelope struct { - Status string `json:"status"` - Tables []ResultTable `json:"tables"` + Status string `json:"status"` + Tables []map[string]ResultTable `json:"tables"` } -// FetchResultsResponse wraps []ResultTable to implement output.Tabular by -// rendering all tables sequentially. +// FetchResultsResponse wraps the result tables for output. In JSON mode a +// compact representation is produced via MarshalJSON; in text mode each table +// is rendered separately via the MultiTabular interface. type FetchResultsResponse struct { - Tables []ResultTable + ResultTables []map[string]ResultTable `json:"tables"` + ShowURLs bool `json:"-"` + NoColor bool `json:"-"` +} + +// jsonCell is the compact JSON representation of a single result cell. +type jsonCell struct { + Value any `json:"value"` + Status string `json:"status"` +} + +// jsonRow is the compact JSON representation of a single result row. +type jsonRow struct { + Name string `json:"name"` + Cells map[string]jsonCell `json:"cells"` } -// Headers implements output.Tabular. Uses the first table's headers or -// returns a single "Table" column when empty. -func (f FetchResultsResponse) Headers() []string { - if len(f.Tables) == 0 { - return []string{"Table"} +// jsonTable is the compact JSON representation of a single result table. +type jsonTable struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Status string `json:"status"` + Rows []jsonRow `json:"rows"` +} + +// MarshalJSON produces a compact JSON array of tables with inlined row data, +// suitable for consumption by tools like jq. +func (f FetchResultsResponse) MarshalJSON() ([]byte, error) { + tables := make([]jsonTable, 0, len(f.ResultTables)) + for _, entry := range f.ResultTables { + for name, tbl := range entry { + jt := jsonTable{ + Name: name, + Description: tbl.Description, + Status: tbl.TableStatus, + Rows: make([]jsonRow, 0, len(tbl.RowNames)), + } + for _, rowName := range tbl.RowNames { + jr := jsonRow{ + Name: rowName, + Cells: make(map[string]jsonCell, len(tbl.Columns)), + } + colData := tbl.TableData[rowName] + for _, col := range tbl.Columns { + if cell, ok := colData[col.Name]; ok { + jr.Cells[col.Name] = jsonCell{ + Value: cell.Value, + Status: cell.Status, + } + } + } + jt.Rows = append(jt.Rows, jr) + } + tables = append(tables, jt) + } } - // Prefix with "Table" column to distinguish between tables. - h := []string{"Table"} - h = append(h, f.Tables[0].Headers()...) - return h + return json.Marshal(tables) } -// Rows implements output.Tabular. Each result table's rows are emitted with -// the table description prepended as the first column. -func (f FetchResultsResponse) Rows() [][]string { - var rows [][]string - for _, tbl := range f.Tables { - for _, r := range tbl.TableRows() { - row := make([]string, 0, 1+len(r)) - row = append(row, tbl.Description) - row = append(row, r...) - rows = append(rows, row) +// Tables implements output.MultiTabular. Each result table is returned as a +// NamedTable whose name is the table's description (falling back to the map +// key when no description is set). +func (f FetchResultsResponse) Tables() []output.NamedTable { + out := make([]output.NamedTable, 0, len(f.ResultTables)) + for _, entry := range f.ResultTables { + for name, tbl := range entry { + tbl.ShowURLs = f.ShowURLs + tbl.NoColor = f.NoColor + title := name + if tbl.Description != "" && tbl.Description != name { + title = fmt.Sprintf("%s\n%s", name, tbl.Description) + } + out = append(out, output.NamedTable{ + Name: title, + Tab: tbl, + }) } } - return rows + return out } // --------------------------------------------------------------------------- diff --git a/cli/internal/models/tests.go b/cli/internal/models/tests.go index 89135281..4d4c67a6 100644 --- a/cli/internal/models/tests.go +++ b/cli/internal/models/tests.go @@ -1,5 +1,11 @@ package models +import ( + "encoding/json" + "fmt" + "strings" +) + // PytestStatus mirrors argus.common.enums.PytestStatus. type PytestStatus string @@ -53,3 +59,63 @@ type PytestSubmitData struct { // PytestResultListResponse is the response payload for pytest result list // endpoints. type PytestResultListResponse = []PytestResult + +// PytestFilterHit is a single result from the pytest filter endpoint. +// It extends PytestResult with optional user-defined fields. +type PytestFilterHit struct { + Name string `json:"name" table:"Name"` + Status PytestStatus `json:"status" table:"Status"` + ID string `json:"id" table:"ID"` + TestType string `json:"test_type" table:"Test Type"` + RunID string `json:"run_id" table:"Run ID"` + TestID string `json:"test_id" table:"Test ID"` + Duration float64 `json:"duration" table:"Duration"` + Message string `json:"message" table:"Message"` + SessionTimestamp string `json:"session_timestamp" table:"Session Timestamp"` + Markers []string `json:"markers" table:"Markers"` + UserFields map[string]string `json:"user_fields" table:"User Fields"` +} + +// PytestFilterResponse is the response from the pytest filter endpoint. +// Only Total and Hits are used for tabular output; chart data is preserved +// for JSON output mode. +type PytestFilterResponse struct { + Total int `json:"total"` + Hits []PytestFilterHit `json:"hits"` + BarChart json.RawMessage `json:"barChart"` + PieChart json.RawMessage `json:"pieChart"` +} + +// Headers implements the Tabular interface for PytestFilterResponse. +func (PytestFilterResponse) Headers() []string { + return []string{"Name", "Status", "ID", "Test Type", "Run ID", "Duration", "Message", "Session Timestamp", "Markers", "User Fields"} +} + +// Rows implements the Tabular interface for PytestFilterResponse. +func (r PytestFilterResponse) Rows() [][]string { + rows := make([][]string, 0, len(r.Hits)) + for _, h := range r.Hits { + markers := strings.Join(h.Markers, ", ") + var fields string + if len(h.UserFields) > 0 { + parts := make([]string, 0, len(h.UserFields)) + for k, v := range h.UserFields { + parts = append(parts, k+"="+v) + } + fields = strings.Join(parts, ", ") + } + rows = append(rows, []string{ + h.Name, + string(h.Status), + h.ID, + h.TestType, + h.RunID, + fmt.Sprintf("%.2f", h.Duration), + h.Message, + h.SessionTimestamp, + markers, + fields, + }) + } + return rows +} diff --git a/cli/internal/output/output.go b/cli/internal/output/output.go index ca49211c..c1b2ee9b 100644 --- a/cli/internal/output/output.go +++ b/cli/internal/output/output.go @@ -18,6 +18,20 @@ type Tabular interface { Rows() [][]string } +// NamedTable pairs a human-readable name with a [Tabular] value so that +// multi-table outputs can label each table independently. +type NamedTable struct { + Name string + Tab Tabular +} + +// MultiTabular is implemented by values that contain several independent +// tables. The text renderer prints each table separately with a header line; +// the JSON renderer ignores this and marshals the value directly. +type MultiTabular interface { + Tables() []NamedTable +} + // Outputter writes a value to the configured destination in the // implementation-specific format. // diff --git a/cli/internal/output/text.go b/cli/internal/output/text.go index eb1bad69..732c5c79 100644 --- a/cli/internal/output/text.go +++ b/cli/internal/output/text.go @@ -46,6 +46,11 @@ func newText(w io.Writer) Outputter { // Write renders v as a text table. v should implement [Tabular] for // meaningful column output; non-Tabular values fall back to a raw JSON row. func (t *textOutputter) Write(v any) error { + // Multi-table values get one table per entry with a header line. + if mt, ok := v.(MultiTabular); ok { + return t.writeMulti(mt) + } + tab, ok := v.(Tabular) if !ok { switch v := v.(type) { @@ -107,6 +112,44 @@ func (t *textOutputter) writeRawJSON(v any) error { return nil } +// writeMulti renders a [MultiTabular] as a sequence of labelled tables. +func (t *textOutputter) writeMulti(mt MultiTabular) error { + for i, nt := range mt.Tables() { + if i > 0 { + if _, err := fmt.Fprintln(t.w); err != nil { + return fmt.Errorf("%w: %w", ErrTextOutputRow, err) + } + } + if _, err := fmt.Fprintf(t.w, "\n## %s\n\n", nt.Name); err != nil { + return fmt.Errorf("%w: %w", ErrTextOutputRow, err) + } + + table := tablewriter.NewTable(t.w) + configureTableWidths(table) + + if headers := nt.Tab.Headers(); len(headers) > 0 { + table.Header(headers) + } + + for _, row := range nt.Tab.Rows() { + if err := table.Append(row); err != nil { + _ = table.Close() + return fmt.Errorf("%w: %w", ErrTextOutputRow, err) + } + } + + if err := table.Render(); err != nil { + _ = table.Close() + return fmt.Errorf("%w: %w", ErrTextOutputRender, err) + } + + if err := table.Close(); err != nil { + return err + } + } + return nil +} + // colMaxWidth is the maximum character width for any single table column. const colMaxWidth = 80 diff --git a/cli/scripts/extract-semver.sh b/cli/scripts/extract-semver.sh new file mode 100755 index 00000000..f77fe8cf --- /dev/null +++ b/cli/scripts/extract-semver.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash +# extract-semver.sh — strip the "cli/" prefix from a release tag. +# +# Usage: +# GITHUB_REF_NAME=cli/v1.2.0 ./extract-semver.sh +# +# Prints the bare semver (e.g. "v1.2.0") to stdout. + +set -euo pipefail + +TAG="${GITHUB_REF_NAME}" +echo "${TAG#cli/}" diff --git a/cli/scripts/resolve-prev-tag.sh b/cli/scripts/resolve-prev-tag.sh new file mode 100755 index 00000000..ec70a69b --- /dev/null +++ b/cli/scripts/resolve-prev-tag.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +# resolve-prev-tag.sh — find the previous cli/v* tag relative to the current one. +# +# Usage: +# GITHUB_REF_NAME=cli/v1.2.0 ./resolve-prev-tag.sh +# +# Prints the previous tag name to stdout, or "FIRST" when no prior tag exists. + +set -euo pipefail + +CURRENT_TAG="${GITHUB_REF_NAME}" + +PREV=$(git tag --merged HEAD --list 'cli/v[0-9]*' --sort=version:refname \ + | (grep -v "^${CURRENT_TAG}$" || true) \ + | tail -1) + +if [[ -z "$PREV" ]]; then + PREV="FIRST" +fi + +echo "${PREV}" diff --git a/cli/scripts/verify-tag-on-main.sh b/cli/scripts/verify-tag-on-main.sh new file mode 100755 index 00000000..91811fcd --- /dev/null +++ b/cli/scripts/verify-tag-on-main.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash +# verify-tag-on-main.sh — abort if the current tag is not reachable from master/main. +# +# Usage: +# GITHUB_REF_NAME=cli/v1.2.0 ./verify-tag-on-main.sh +# +# Exits 0 when the tag is on master/main, 1 otherwise. + +set -euo pipefail + +TAG_SHA=$(git rev-list -n1 "${GITHUB_REF_NAME}") +for branch in master main; do + if git show-ref --verify --quiet "refs/remotes/origin/${branch}"; then + if git merge-base --is-ancestor "${TAG_SHA}" "origin/${branch}"; then + echo "Tag ${GITHUB_REF_NAME} (${TAG_SHA}) is on ${branch}. Proceeding." + exit 0 + fi + fi +done +echo "ERROR: Tag ${GITHUB_REF_NAME} is not reachable from master/main. Aborting release." >&2 +exit 1 diff --git a/docs/cli-install.md b/docs/cli-install.md new file mode 100644 index 00000000..64db8003 --- /dev/null +++ b/docs/cli-install.md @@ -0,0 +1,45 @@ +# Argus CLI — Installation Guide + +## Step 1: Check if already installed and up to date + +```bash +echo "installed: $(argus version 2>/dev/null || echo 'not installed')" +echo "latest: $(curl -sL https://api.github.com/repos/scylladb/argus/releases | grep -oP '"tag_name":\s*"cli/v\K[^"]+' | head -1)" +``` + +If the installed version matches the latest, stop here. Otherwise continue. + +## Step 2: Install or upgrade + +Set the version to install (use the latest from Step 1): + +```bash +VERSION=0.1.2 +``` + +### Linux amd64 + +```bash +curl -sL "https://github.com/scylladb/argus/releases/download/cli/v${VERSION}/argus_${VERSION}_linux_amd64.tar.gz" | tar xz -C ~/.local/bin argus +``` + +### macOS amd64 + +```bash +curl -sL "https://github.com/scylladb/argus/releases/download/cli/v${VERSION}/argus_${VERSION}_macOS_amd64.tar.gz" | tar xz -C ~/.local/bin argus +``` + +## Step 3: Verify + +```bash +export PATH="$HOME/.local/bin:$PATH" +argus version +``` + +## Configuration + +The CLI reads `~/.config/argus/config.yaml`. Run `argus configure` to set up interactively. + +## All releases + +https://github.com/scylladb/argus/releases?q=cli diff --git a/frontend/TestRun/TestRun.svelte b/frontend/TestRun/TestRun.svelte index be19dc86..2491ebb3 100644 --- a/frontend/TestRun/TestRun.svelte +++ b/frontend/TestRun/TestRun.svelte @@ -78,6 +78,7 @@ end_time: string, build_job_url: string, subtest_name: string, + events: any[], status: string, test_id: string, release_id: string, @@ -339,9 +340,14 @@ - + {#if testRun.events.length > 0} + + {/if} @@ -357,9 +363,6 @@ -