From c2ad576c43eac9d557f4b80d1f42b03e004ef3ae Mon Sep 17 00:00:00 2001 From: ReeceHoffmann Date: Thu, 11 Jun 2026 12:21:51 -0700 Subject: [PATCH 1/6] feat: add cached reference collapse step --- Dockerfile | 7 +- fixtures.py | 57 ++- python/workflow_pathoscope/utils.py | 233 ++++++++++ tests/assests/redundant_reference.json | 97 ++++ tests/test_workflow.py | 585 +++++++++++++++++-------- workflow.py | 75 +++- 6 files changed, 851 insertions(+), 203 deletions(-) create mode 100644 tests/assests/redundant_reference.json diff --git a/Dockerfile b/Dockerfile index 5cd9cd9..f08955e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,9 @@ FROM python:3.13-bookworm AS deps WORKDIR /app -COPY --from=ghcr.io/virtool/tools:1.1.0 /tools/bowtie2/2.5.4/bowtie* /usr/local/bin/ -COPY --from=ghcr.io/virtool/tools:1.1.0 /tools/pigz/2.8/pigz /usr/local/bin/ -COPY --from=ghcr.io/virtool/tools:1.1.0 /tools/samtools/1.22.1/bin/samtools /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/bowtie2/2.5.4/bowtie* /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/cd-hit/4.8.1/cd-hit-est /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/pigz/2.8/pigz /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/samtools/1.22.1/bin/samtools /usr/local/bin/ FROM python:3.13-bookworm AS uv WORKDIR /app diff --git a/fixtures.py b/fixtures.py index e77948d..9e2e9a5 100644 --- a/fixtures.py +++ b/fixtures.py @@ -4,6 +4,42 @@ from pyfixtures import fixture +def get_reference_index_path(work_path: Path) -> Path: + return work_path / "reference_index" / "reference" + + +def get_collapsed_reference_path(work_path: Path) -> Path: + return work_path / "collapsed_reference" / "reference.json" + + +def get_subtraction_indexes_path(work_path: Path) -> Path: + return work_path / "subtraction_indexes" + + +def get_isolate_path(work_path: Path) -> Path: + return work_path / "isolates" + + +def get_isolate_fasta_path(isolate_path: Path) -> Path: + return isolate_path / 
"isolate_index.fa" + + +def get_isolate_fastq_path(isolate_path: Path) -> Path: + return isolate_path / "isolate_mapped.fq" + + +def get_isolate_index_path(isolate_path: Path) -> Path: + return isolate_path / "isolates" + + +def get_isolate_bam_path(isolate_path: Path) -> Path: + return isolate_path / "to_isolates.bam" + + +def get_subtracted_bam_path(work_path: Path) -> Path: + return work_path / "subtracted.bam" + + @fixture def intermediate(): """A namespace for storing intermediate values.""" @@ -14,7 +50,7 @@ def intermediate(): @fixture def isolate_path(work_path: Path): - path = work_path / "isolates" + path = get_isolate_path(work_path) path.mkdir() return path @@ -22,32 +58,37 @@ def isolate_path(work_path: Path): @fixture def reference_index_path(work_path: Path): - return work_path / "reference_index" / "reference" + return get_reference_index_path(work_path) + + +@fixture +def collapsed_reference_path(work_path: Path): + return get_collapsed_reference_path(work_path) @fixture def subtraction_indexes_path(work_path: Path) -> Path: - return work_path / "subtraction_indexes" + return get_subtraction_indexes_path(work_path) @fixture def isolate_fasta_path(isolate_path: Path): - return isolate_path / "isolate_index.fa" + return get_isolate_fasta_path(isolate_path) @fixture def isolate_fastq_path(isolate_path: Path): - return isolate_path / "isolate_mapped.fq" + return get_isolate_fastq_path(isolate_path) @fixture def isolate_index_path(isolate_path: Path): - return isolate_path / "isolates" + return get_isolate_index_path(isolate_path) @fixture def isolate_bam_path(isolate_path: Path): - return isolate_path / "to_isolates.bam" + return get_isolate_bam_path(isolate_path) @fixture @@ -58,4 +99,4 @@ def p_score_cutoff(): @fixture def subtracted_bam_path(work_path: Path) -> Path: """The path to the BAM file after subtraction reads have been eliminated.""" - return work_path / "subtracted.bam" + return get_subtracted_bam_path(work_path) diff --git 
a/python/workflow_pathoscope/utils.py b/python/workflow_pathoscope/utils.py index fc0eea4..2ee8360 100755 --- a/python/workflow_pathoscope/utils.py +++ b/python/workflow_pathoscope/utils.py @@ -4,9 +4,11 @@ import json import re from pathlib import Path +from tempfile import TemporaryDirectory from virtool.caches.utils import derive_key from virtool.workflow.data.cache import CacheHit, WorkflowCache +from virtool.workflow.errors import SubprocessFailedError from virtool.workflow.runtime.run_subprocess import RunSubprocess from virtool.workflow.utils import get_workflow_version @@ -14,6 +16,8 @@ BOWTIE2_BUILD_TOOL = "bowtie2-build" +CD_HIT_EST_TOOL = "cd-hit-est" +CD_HIT_EST_IDENTITY = "0.99" WORKFLOW_NAME = "pathoscope" WORKFLOW_VERSION = get_workflow_version() @@ -37,6 +41,48 @@ async def collect_stdout(line: bytes) -> None: return match.group(1) +async def get_cd_hit_est_version(run_subprocess: RunSubprocess) -> str: + output = [] + + async def collect_output(line: bytes) -> None: + output.append(line.decode()) + + try: + await run_subprocess( + [CD_HIT_EST_TOOL, "-h"], + stderr_handler=collect_output, + stdout_handler=collect_output, + ) + except SubprocessFailedError: + # cd-hit-est -h prints help/version text and exits with code 1. 
+ pass + + match = re.search(r"\bCD-HIT\s+version\s+([^\s]+)", "".join(output)) + + if match is None: + raise ValueError("Could not parse cd-hit-est version") + + return match.group(1) + + +async def get_reference_collapse_cache_params( + parent_id: str, + run_subprocess: RunSubprocess, +) -> dict[str, str]: + tool_version = await get_cd_hit_est_version(run_subprocess) + + return { + "index_kind": "collapsed_reference", + "workflow": WORKFLOW_NAME, + "workflow_version": WORKFLOW_VERSION, + "parent_id": parent_id, + "source": "index_json", + "tool_name": CD_HIT_EST_TOOL, + "tool_version": tool_version, + "identity": CD_HIT_EST_IDENTITY, + } + + async def get_mapping_index_cache_params( index_kind: str, parent_id: str, @@ -146,6 +192,193 @@ def _get_reference_otus(reference_data): return reference_data +def _get_otu_schema_segment_names(otu: dict) -> set[str]: + return {str(segment["name"]) for segment in otu["schema"]} + + +def _get_schema_sequence_segment_key( + otu: dict, + sequence: dict, + valid_schema_segments: set[str], +) -> str: + segment_key = sequence["segment"] + + if segment_key not in valid_schema_segments: + raise ValueError( + f"Sequence {sequence['_id']} uses segment {segment_key!r}, which is not " + f"defined in OTU {otu['_id']} schema" + ) + + return segment_key + + +def _write_fasta(sequences: list[dict], path: Path) -> None: + with open(path, "w") as handle: + for sequence in sequences: + handle.write(f">{sequence['_id']}\n{sequence['sequence']}\n") + + +def _parse_cd_hit_clusters(cluster_path: Path) -> dict[str, str]: + representatives_by_sequence_id = {} + cluster_sequence_ids = [] + representative_id = None + + def flush_cluster() -> None: + if representative_id is None: + return + + for sequence_id in cluster_sequence_ids: + representatives_by_sequence_id[sequence_id] = representative_id + + with open(cluster_path) as handle: + for line in handle: + line = line.strip() + + if line.startswith(">Cluster"): + flush_cluster() + 
cluster_sequence_ids = [] + representative_id = None + continue + + match = re.search(r">(.+?)\.\.\.", line) + + if match is None: + continue + + sequence_id = match.group(1) + cluster_sequence_ids.append(sequence_id) + + if line.endswith("*"): + representative_id = sequence_id + + flush_cluster() + + return representatives_by_sequence_id + + +def _build_representative_set( + isolate: dict, + representatives_by_sequence_id: dict[str, str], +) -> frozenset[str]: + return frozenset( + representatives_by_sequence_id[sequence["_id"]] + for sequence in isolate["sequences"] + ) + + +async def collapse_reference_json( + json_path: Path, + target_path: Path, + proc: int, + run_subprocess: RunSubprocess, +) -> dict[str, int]: + """Collapse redundant isolates in a reference JSON using cd-hit-est clusters.""" + with _open_json(json_path) as handle: + reference_data = json.load(handle) + + otus = _get_reference_otus(reference_data) + before_count = 0 + after_count = 0 + + with TemporaryDirectory(prefix="pathoscope-collapse-") as temp_dir: + temp_path = Path(temp_dir) + + for otu in otus: + sequences_by_segment = {} + valid_schema_segments = _get_otu_schema_segment_names(otu) + before_count += len(otu["isolates"]) + for isolate in otu["isolates"]: + for sequence in isolate["sequences"]: + segment_key = _get_schema_sequence_segment_key( + otu, + sequence, + valid_schema_segments, + ) + + sequences_by_segment.setdefault(segment_key, []).append(sequence) + + representatives_by_sequence_id = {} + + sorted_segment_sequences = sorted( + sequences_by_segment.items(), + key=lambda item: item[0], + ) + + for segment_name, sequences in sorted_segment_sequences: + segment_input_path = ( + temp_path / f"otu-{otu['_id']}-segment-{segment_name}.fa" + ) + segment_output_path = ( + temp_path / f"otu-{otu['_id']}-segment-{segment_name}.cdhit" + ) + + await asyncio.to_thread(_write_fasta, sequences, segment_input_path) + + await run_subprocess( + [ + CD_HIT_EST_TOOL, + "-i", + 
str(segment_input_path), + "-o", + str(segment_output_path), + "-c", + CD_HIT_EST_IDENTITY, + "-T", + str(proc), + "-M", + "0", + "-d", + "0", + ], + ) + + representatives_by_sequence_id.update( + _parse_cd_hit_clusters( + segment_output_path.with_suffix(".cdhit.clstr") + ) + ) + + default_sequence_ids = { + sequence["_id"] + for isolate in otu["isolates"] + if isolate["default"] + for sequence in isolate["sequences"] + } + seen_representative_sets = set() + collapsed_isolates = [] + + for isolate in otu["isolates"]: + representative_set = _build_representative_set( + isolate, + representatives_by_sequence_id, + ) + + first_for_set = representative_set not in seen_representative_sets + seen_representative_sets.add(representative_set) + + contains_default_sequence = any( + sequence["_id"] in default_sequence_ids + for sequence in isolate["sequences"] + ) + + if isolate["default"] or contains_default_sequence or first_for_set: + collapsed_isolates.append(isolate) + + otu["isolates"] = collapsed_isolates + after_count += len(collapsed_isolates) + + await asyncio.to_thread(target_path.parent.mkdir, parents=True, exist_ok=True) + + with open(target_path, "w") as handle: + json.dump(reference_data, handle) + + return { + "isolate_count_before": before_count, + "isolate_count_after": after_count, + "isolate_count_removed": before_count - after_count, + } + + def write_default_isolate_fasta( json_path: Path, target_path: Path, diff --git a/tests/assests/redundant_reference.json b/tests/assests/redundant_reference.json new file mode 100644 index 0000000..df9b172 --- /dev/null +++ b/tests/assests/redundant_reference.json @@ -0,0 +1,97 @@ +{ + "otus": [ + { + "_id": "collapse-otu", + "schema": [ + { + "name": "a" + }, + { + "name": "b" + } + ], + "isolates": [ + { + "_id": "default", + "default": true, + "sequences": [ + { + "_id": "default-a", + "segment": "a", + "sequence": "ACGTACGTACGTACGTACGT" + }, + { + "_id": "default-b", + "segment": "b", + "sequence": 
"TGCATGCATGCATGCATGCA" + } + ] + }, + { + "_id": "representative-1", + "default": false, + "sequences": [ + { + "_id": "representative-1-a", + "segment": "a", + "sequence": "ACGTACGTACGTACGTACGT" + }, + { + "_id": "representative-1-b", + "segment": "b", + "sequence": "TGCAGGATCGTTTAACGTAG" + } + ] + }, + { + "_id": "representative-2", + "default": false, + "sequences": [ + { + "_id": "representative-2-a", + "segment": "a", + "sequence": "CCCCAAAAGGGGTTTTCCCC" + }, + { + "_id": "representative-2-b", + "segment": "b", + "sequence": "AAAACCCCGGGGTTTTAAAA" + } + ] + }, + { + "_id": "unique-combo", + "default": false, + "sequences": [ + { + "_id": "unique-combo-a", + "segment": "a", + "sequence": "CCCCAAAAGGGGTTTTCCCC" + }, + { + "_id": "unique-combo-b", + "segment": "b", + "sequence": "TGCAGGATCGTTTAACGTAG" + } + ] + }, + { + "_id": "duplicate", + "default": false, + "sequences": [ + { + "_id": "duplicate-a", + "segment": "a", + "sequence": "CCCCAAAAGGGGTTTTCCCC" + }, + { + "_id": "duplicate-b", + "segment": "b", + "sequence": "AAAACCCCGGGGTTTTAAAA" + } + ] + } + ] + } + ] +} diff --git a/tests/test_workflow.py b/tests/test_workflow.py index bb91e2d..2aaaf6a 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,5 +1,7 @@ +import asyncio import gzip import json +import re import shutil import pysam @@ -12,13 +14,25 @@ from virtool.caches.utils import derive_key from virtool.workflow import RunSubprocess from virtool.workflow.data.analyses import WFAnalysis -from virtool.workflow.data.cache import CacheHit, CacheMiss +from virtool.workflow.data.cache import WorkflowCache from virtool.workflow.data.indexes import WFIndex from virtool.workflow.data.samples import WFSample from virtool.workflow.data.subtractions import WFSubtraction +from virtool.workflow.errors import JobsAPINotFoundError from virtool.workflow.pytest_plugin import WorkflowData +from fixtures import ( + get_collapsed_reference_path, + get_isolate_bam_path, + get_isolate_fastq_path, + 
get_isolate_index_path, + get_isolate_path, + get_reference_index_path, + get_subtracted_bam_path, + get_subtraction_indexes_path, +) from workflow import ( + collapse_reference, create_reference_index, create_subtraction_index, eliminate_subtraction, @@ -28,8 +42,12 @@ reassignment, ) from workflow_pathoscope.utils import ( + CD_HIT_EST_IDENTITY, + collapse_reference_json, get_mapping_index_cache_params, + get_reference_collapse_cache_params, write_default_isolate_fasta, + write_isolate_fasta, ) @@ -41,6 +59,10 @@ "rev.1.bt2", "rev.2.bt2", ) +REDUNDANT_REFERENCE_JSON_PATH = ( + Path(__file__).parent / "assests" / "redundant_reference.json" +) +TOOL_VERSION_PATTERN = re.compile(r"\d+(?:\.\d+)+(?:[-+._A-Za-z0-9]*)?") @pytest.fixture() @@ -53,12 +75,17 @@ def work_path(tmpdir): @pytest.fixture() def reference_index_path(work_path: Path) -> Path: - return work_path / "reference_index" / "reference" + return get_reference_index_path(work_path) + + +@pytest.fixture() +def collapsed_reference_path(work_path: Path) -> Path: + return get_collapsed_reference_path(work_path) @pytest.fixture() def subtraction_indexes_path(work_path: Path) -> Path: - return work_path / "subtraction_indexes" + return get_subtraction_indexes_path(work_path) @pytest.fixture() @@ -69,71 +96,85 @@ def subtraction_index_path( return get_subtraction_index_path(subtraction_indexes_path, subtractions[0].id) -class FakeWorkflowCache: - def __init__( - self, - hit_source: Path | None = None, - put_exception: Exception | None = None, - put_created: bool = True, - ): - self.hit_source = hit_source - self.put_exception = put_exception - self.put_created = put_created - self.gets = [] - self.puts = [] +@pytest.fixture() +def isolate_path(work_path: Path) -> Path: + path = get_isolate_path(work_path) + path.mkdir() - async def get(self, key: str, target: Path): - self.gets.append((key, target)) + return path - if self.hit_source is None: - return CacheMiss(key) - target.mkdir(parents=True, exist_ok=True) 
+@pytest.fixture() +def isolate_fastq_path(isolate_path: Path) -> Path: + return get_isolate_fastq_path(isolate_path) - restored_path = target / self.hit_source.name - shutil.copytree(self.hit_source, restored_path) - return CacheHit(key, restored_path) +@pytest.fixture() +def isolate_index_path(isolate_path: Path) -> Path: + return get_isolate_index_path(isolate_path) - async def put(self, key: str, source: Path, params: dict | None = None): - self.puts.append((key, source, params)) - if self.put_exception is not None: - raise self.put_exception +@pytest.fixture() +def isolate_bam_path(isolate_path: Path) -> Path: + return get_isolate_bam_path(isolate_path) + - return self.put_created +@pytest.fixture() +def subtracted_bam_path(work_path: Path) -> Path: + return get_subtracted_bam_path(work_path) -class FakeRunSubprocess: - def __init__(self): - self.commands = [] +class _FakeWorkflowCacheAPI: + """Fake only the API calls used by the real workflow cache.""" - async def __call__( + def __init__( self, - command: list[str], - cwd: str | Path | None = None, - env: dict | None = None, - stderr_handler=None, - stdout_handler=None, - ): - self.commands.append(command) + work_dir: Path, + *, + put_exception: Exception | None = None, + put_created: bool | None = None, + ) -> None: + self.work_dir = work_dir + self.put_exception = put_exception + self.put_created = put_created + self.stored: dict[str, tuple[Path, dict | None]] = {} - if command == ["bowtie2-build", "--version"]: - await stdout_handler(b"/usr/bin/bowtie2-build-s version 2.5.4\n") - return SimpleNamespace(returncode=0) + async def get_cache(self, key: str, dest: Path) -> None: + try: + source, _ = self.stored[key] + except KeyError: + raise JobsAPINotFoundError from None - if command[0] == "bowtie2-build": - prefix = Path(command[-1]) - prefix.parent.mkdir(parents=True, exist_ok=True) + dest.parent.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(shutil.copyfile, source, dest) - for suffix in 
BOWTIE2_INDEX_SUFFIXES: - (prefix.parent / f"{prefix.name}.{suffix}").write_bytes( - f"{prefix.name}.{suffix}".encode(), - ) + async def put_cache( + self, + key: str, + path: Path, + params: dict | None = None, + ) -> bool: + if self.put_exception: + raise self.put_exception + + if key in self.stored: + return False + + stored_path = self.work_dir / "stored" / key / "cache.tar" + stored_path.parent.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(shutil.copyfile, path, stored_path) + self.stored[key] = (stored_path, params) + + if self.put_created is not None: + return self.put_created + + return True - return SimpleNamespace(returncode=0) - raise AssertionError(f"Unexpected subprocess command: {command}") +@pytest.fixture() +def workflow_cache(tmp_path: Path) -> WorkflowCache: + api = _FakeWorkflowCacheAPI(tmp_path / "fake_workflow_cache_api") + return WorkflowCache(api, tmp_path / "workflow_cache") def read_directory_bytes(path: Path) -> dict[str, bytes]: @@ -162,6 +203,26 @@ def write_bowtie2_bundle( (path / name).write_bytes(file_content) +def assert_bowtie2_index_exists(prefix: Path): + assert read_directory_bytes(prefix.parent).keys() == { + f"{prefix.name}.{suffix}" for suffix in BOWTIE2_INDEX_SUFFIXES + } + + for suffix in BOWTIE2_INDEX_SUFFIXES: + assert (prefix.parent / f"{prefix.name}.{suffix}").stat().st_size > 0 + + +def assert_cache_params( + params: dict[str, str], + expected: dict[str, str], +) -> None: + assert params.keys() == {*expected.keys(), "tool_version"} + assert { + key: value for key, value in params.items() if key != "tool_version" + } == expected + assert TOOL_VERSION_PATTERN.fullmatch(params["tool_version"]) + + def write_reference_json(path: Path): path.parent.mkdir(parents=True, exist_ok=True) @@ -206,6 +267,12 @@ def write_reference_json(path: Path): ) +def write_redundant_reference_json(path: Path): + path.parent.mkdir(parents=True, exist_ok=True) + + shutil.copyfile(REDUNDANT_REFERENCE_JSON_PATH, path) + + 
@pytest.fixture() def analysis(workflow_data: WorkflowData, mocker): analysis_ = mocker.Mock(WFAnalysis) @@ -262,6 +329,172 @@ def index(workflow_data: WorkflowData, example_path: Path, work_path: Path): ) +async def test_collapse_reference_hit( + collapsed_reference_path: Path, + index: WFIndex, + run_subprocess: RunSubprocess, + tmp_path: Path, + workflow_cache: WorkflowCache, +): + source = tmp_path / collapsed_reference_path.parent.name + source.mkdir() + write_reference_json(source / collapsed_reference_path.name) + (source / "collapse-manifest.json").write_text( + json.dumps( + { + "isolate_count_before": 4, + "isolate_count_after": 3, + "isolate_count_removed": 1, + } + ) + ) + params = await get_reference_collapse_cache_params(index.id, run_subprocess) + key = derive_key(params) + assert await workflow_cache.put(key, source, params) + + logger = get_logger("test") + + result_path = await collapse_reference( + workflow_cache, + collapsed_reference_path, + index, + logger, + 4, + run_subprocess, + ) + + assert result_path == collapsed_reference_path + assert read_directory_bytes( + collapsed_reference_path.parent + ) == read_directory_bytes(source) + + +async def test_collapse_reference_miss_retains_required_isolates( + collapsed_reference_path: Path, + index: WFIndex, + run_subprocess: RunSubprocess, + workflow_cache: WorkflowCache, +): + write_redundant_reference_json(index.json_path) + logger = get_logger("test") + + result_path = await collapse_reference( + workflow_cache, + collapsed_reference_path, + index, + logger, + 4, + run_subprocess, + ) + + params = await get_reference_collapse_cache_params(index.id, run_subprocess) + + assert result_path == collapsed_reference_path + assert_cache_params( + params, + { + "index_kind": "collapsed_reference", + "workflow": "pathoscope", + "workflow_version": "UNKNOWN", + "parent_id": index.id, + "source": "index_json", + "tool_name": "cd-hit-est", + "identity": "0.99", + }, + ) + + with 
open(collapsed_reference_path) as handle: + collapsed_reference = json.load(handle) + + assert [ + (isolate["sequences"][0]["_id"], isolate["sequences"][1]["_id"]) + for isolate in collapsed_reference["otus"][0]["isolates"] + ] == [("default-a", "default-b"), ("representative-1-a", "representative-1-b"), ("representative-2-a", "representative-2-b"), ("unique-combo-a", "unique-combo-b")] + assert json.loads( + (collapsed_reference_path.parent / "collapse-manifest.json").read_text() + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + +async def test_collapse_reference_json_outputs_collapsed_reference_fasta( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + default_fasta_path = tmp_path / "default.fa" + isolate_fasta_path = tmp_path / "isolates.fa" + + write_redundant_reference_json(source_path) + + assert await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + assert write_default_isolate_fasta(collapsed_path, default_fasta_path) == { + "default-a": 20, + "default-b": 20, + } + assert write_isolate_fasta( + {"collapse-otu"}, + collapsed_path, + isolate_fasta_path, + ) == { + "default-a": 20, + "default-b": 20, + "representative-1-a": 20, + "representative-1-b": 20, + "representative-2-a": 20, + "representative-2-b": 20, + "unique-combo-a": 20, + "unique-combo-b": 20, + } + assert "duplicate-a" not in isolate_fasta_path.read_text() + assert "duplicate-b" not in isolate_fasta_path.read_text() + + +async def test_collapse_reference_json_rejects_sequences_outside_otu_schema( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + 
write_redundant_reference_json(source_path) + + with open(source_path) as handle: + reference_data = json.load(handle) + + reference_data["otus"][0]["isolates"][0]["sequences"][0]["segment"] = "c" + + with open(source_path, "w") as handle: + json.dump(reference_data, handle) + + with pytest.raises( + ValueError, + match=( + "Sequence default-a uses segment 'c', which is not defined in " + "OTU collapse-otu schema" + ), + ): + await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) + + @pytest.fixture() def sample(workflow_data: WorkflowData, example_path: Path, work_path: Path): workflow_data.sample.library_type = "normal" @@ -303,11 +536,14 @@ def subtractions(workflow_data: WorkflowData, example_path: Path, work_path: Pat async def test_create_reference_index_hit( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, tmp_path: Path, + workflow_cache: WorkflowCache, ): - write_reference_json(index.json_path) + write_reference_json(collapsed_reference_path) source = tmp_path / reference_index_path.parent.name write_bowtie2_bundle( source, @@ -315,12 +551,24 @@ async def test_create_reference_index_hit( b"cached-reference", {"cache-manifest.json": b"cached-manifest"}, ) - cache = FakeWorkflowCache(source) - run_subprocess = FakeRunSubprocess() + params = await get_mapping_index_cache_params( + "reference_mapping_index", + index.id, + run_subprocess, + { + "collapse_identity": CD_HIT_EST_IDENTITY, + "source": "collapsed_reference", + "selection": "default_isolates", + }, + ) + key = derive_key(params) + assert await workflow_cache.put(key, source, params) + logger = get_logger("test") result_index_path = await create_reference_index( - cache, + workflow_cache, + collapsed_reference_path, index, logger, 4, @@ -328,41 +576,25 @@ async def test_create_reference_index_hit( reference_index_path, ) - params = await get_mapping_index_cache_params( - "reference_mapping_index", - 
index.id, - FakeRunSubprocess(), - { - "source": "index_json", - "selection": "default_isolates", - }, - ) - - assert cache.gets == [ - ( - derive_key(params), - reference_index_path.parent.parent, - ), - ] - assert cache.puts == [] assert result_index_path == reference_index_path - assert run_subprocess.commands == [["bowtie2-build", "--version"]] assert read_directory_bytes(reference_index_path.parent) == read_directory_bytes( source ) async def test_create_reference_index_miss( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, + workflow_cache: WorkflowCache, ): - write_reference_json(index.json_path) - cache = FakeWorkflowCache() - run_subprocess = FakeRunSubprocess() + write_reference_json(collapsed_reference_path) logger = get_logger("test") result_index_path = await create_reference_index( - cache, + workflow_cache, + collapsed_reference_path, index, logger, 4, @@ -373,51 +605,50 @@ async def test_create_reference_index_miss( params = await get_mapping_index_cache_params( "reference_mapping_index", index.id, - FakeRunSubprocess(), + run_subprocess, { - "source": "index_json", + "collapse_identity": CD_HIT_EST_IDENTITY, + "source": "collapsed_reference", "selection": "default_isolates", }, ) - key = derive_key(params) - assert cache.gets == [(key, reference_index_path.parent.parent)] - assert cache.puts == [(key, reference_index_path.parent, params)] assert result_index_path == reference_index_path - assert params == { - "index_kind": "reference_mapping_index", - "workflow": "pathoscope", - "workflow_version": "UNKNOWN", - "parent_id": index.id, - "source": "index_json", - "selection": "default_isolates", - "tool_name": "bowtie2-build", - "tool_version": "2.5.4", - } - assert len(run_subprocess.commands) == 2 - assert run_subprocess.commands[0] == ["bowtie2-build", "--version"] - assert run_subprocess.commands[1][:3] == [ - "bowtie2-build", - "--threads", - "4", - ] - assert 
Path(run_subprocess.commands[1][3]).name == "reference.fa" - assert run_subprocess.commands[1][4] == str(reference_index_path) - assert read_directory_bytes(reference_index_path.parent) == { - **bowtie2_bundle_bytes("reference"), - } + assert_cache_params( + params, + { + "index_kind": "reference_mapping_index", + "workflow": "pathoscope", + "workflow_version": "UNKNOWN", + "parent_id": index.id, + "collapse_identity": "0.99", + "source": "collapsed_reference", + "selection": "default_isolates", + "tool_name": "bowtie2-build", + }, + ) + assert_bowtie2_index_exists(reference_index_path) async def test_create_reference_index_continues_when_cache_put_is_skipped( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, + tmp_path: Path, ): - write_reference_json(index.json_path) - cache = FakeWorkflowCache(put_created=False) - run_subprocess = FakeRunSubprocess() + write_reference_json(collapsed_reference_path) + cache = WorkflowCache( + _FakeWorkflowCacheAPI( + tmp_path / "fake_workflow_cache_api", + put_created=False, + ), + tmp_path / "workflow_cache", + ) logger = get_logger("test") result_index_path = await create_reference_index( cache, + collapsed_reference_path, index, logger, 4, @@ -425,36 +656,31 @@ async def test_create_reference_index_continues_when_cache_put_is_skipped( reference_index_path, ) - params = await get_mapping_index_cache_params( - "reference_mapping_index", - index.id, - FakeRunSubprocess(), - { - "source": "index_json", - "selection": "default_isolates", - }, - ) - key = derive_key(params) - assert cache.gets == [(key, reference_index_path.parent.parent)] - assert cache.puts == [(key, reference_index_path.parent, params)] assert result_index_path == reference_index_path - assert read_directory_bytes(reference_index_path.parent) == { - **bowtie2_bundle_bytes("reference"), - } + assert_bowtie2_index_exists(reference_index_path) async def 
test_create_reference_index_raises_unexpected_cache_put_failure( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, + tmp_path: Path, ): - write_reference_json(index.json_path) - cache = FakeWorkflowCache(put_exception=RuntimeError("cache upload failed")) - run_subprocess = FakeRunSubprocess() + write_reference_json(collapsed_reference_path) + cache = WorkflowCache( + _FakeWorkflowCacheAPI( + tmp_path / "fake_workflow_cache_api", + put_exception=RuntimeError("cache upload failed"), + ), + tmp_path / "workflow_cache", + ) logger = get_logger("test") with pytest.raises(RuntimeError, match="cache upload failed"): await create_reference_index( cache, + collapsed_reference_path, index, logger, 4, @@ -477,10 +703,12 @@ def test_write_default_isolate_fasta(tmp_path: Path): async def test_create_subtraction_index_hit( + run_subprocess: RunSubprocess, subtractions: list[WFSubtraction], subtraction_index_path: Path, subtraction_indexes_path: Path, tmp_path: Path, + workflow_cache: WorkflowCache, ): source = tmp_path / subtraction_index_path.parent.name write_bowtie2_bundle( @@ -489,13 +717,19 @@ async def test_create_subtraction_index_hit( b"cached-subtraction", {"cache-manifest.json": b"cached-manifest"}, ) - cache = FakeWorkflowCache(source) - run_subprocess = FakeRunSubprocess() - logger = get_logger("test") subtraction = subtractions[0] + params = await get_mapping_index_cache_params( + "subtraction_mapping_index", + subtraction.id, + run_subprocess, + ) + key = derive_key(params) + assert await workflow_cache.put(key, source, params) + + logger = get_logger("test") result_indexes_path = await create_subtraction_index( - cache, + workflow_cache, logger, 4, run_subprocess, @@ -503,38 +737,24 @@ async def test_create_subtraction_index_hit( subtraction_indexes_path, ) - params = await get_mapping_index_cache_params( - "subtraction_mapping_index", - subtraction.id, - FakeRunSubprocess(), - ) - - assert cache.gets == 
[ - ( - derive_key(params), - subtraction_index_path.parent.parent, - ), - ] - assert cache.puts == [] assert result_indexes_path == subtraction_indexes_path - assert run_subprocess.commands == [["bowtie2-build", "--version"]] assert read_directory_bytes(subtraction_index_path.parent) == read_directory_bytes( source ) async def test_create_subtraction_index_miss( + run_subprocess: RunSubprocess, subtractions: list[WFSubtraction], subtraction_index_path: Path, subtraction_indexes_path: Path, + workflow_cache: WorkflowCache, ): - cache = FakeWorkflowCache() - run_subprocess = FakeRunSubprocess() logger = get_logger("test") subtraction = subtractions[0] result_indexes_path = await create_subtraction_index( - cache, + workflow_cache, logger, 4, run_subprocess, @@ -545,33 +765,20 @@ async def test_create_subtraction_index_miss( params = await get_mapping_index_cache_params( "subtraction_mapping_index", subtraction.id, - FakeRunSubprocess(), + run_subprocess, ) - key = derive_key(params) - assert cache.gets == [(key, subtraction_index_path.parent.parent)] - assert cache.puts == [(key, subtraction_index_path.parent, params)] assert result_indexes_path == subtraction_indexes_path - assert params == { - "index_kind": "subtraction_mapping_index", - "workflow": "pathoscope", - "workflow_version": "UNKNOWN", - "parent_id": subtraction.id, - "tool_name": "bowtie2-build", - "tool_version": "2.5.4", - } - assert run_subprocess.commands == [ - ["bowtie2-build", "--version"], - [ - "bowtie2-build", - "--threads", - "4", - str(subtraction.fasta_path), - str(subtraction_index_path), - ], - ] - assert read_directory_bytes(subtraction_index_path.parent) == bowtie2_bundle_bytes( - "subtraction" + assert_cache_params( + params, + { + "index_kind": "subtraction_mapping_index", + "workflow": "pathoscope", + "workflow_version": "UNKNOWN", + "parent_id": subtraction.id, + "tool_name": "bowtie2-build", + }, ) + assert_bowtie2_index_exists(subtraction_index_path) async def 
test_map_default_isolates( @@ -608,22 +815,20 @@ async def test_map_default_isolates( async def test_map_isolates( example_path: Path, index: WFIndex, + isolate_bam_path: Path, + isolate_fastq_path: Path, + isolate_index_path: Path, sample: WFSample, run_subprocess: RunSubprocess, snapshot: SnapshotAssertion, - work_path: Path, ): for path in (example_path / "index").iterdir(): if "reference" in path.name: shutil.copyfile( path, - work_path / path.name.replace("reference", "isolates"), + isolate_index_path.parent / path.name.replace("reference", "isolates"), ) - isolate_fastq_path = work_path / "mapped.fq" - isolate_index_path = work_path / "isolates" - isolate_bam_path = work_path / "to_isolates.bam" - proc = 1 await map_isolates( @@ -654,18 +859,18 @@ async def test_map_isolates( ) async def test_eliminate_subtraction( example_path: Path, + isolate_bam_path: Path, + isolate_fastq_path: Path, no_subtractions: bool, + subtracted_bam_path: Path, subtractions: list[WFSubtraction], subtraction_indexes_path: Path, run_subprocess: RunSubprocess, snapshot: SnapshotAssertion, tmp_path: Path, + workflow_cache: WorkflowCache, work_path: Path, ): - isolate_fastq_path = work_path / "to_isolates.fq" - isolate_bam_path = work_path / "to_isolates.bam" - subtracted_path = work_path / "subtracted.bam" - shutil.copyfile(example_path / "to_isolates.bam", isolate_bam_path) shutil.copyfile(example_path / "to_isolates.fq", isolate_fastq_path) @@ -681,12 +886,19 @@ async def test_eliminate_subtraction( if subtractions: cached_subtraction_path = tmp_path / subtractions[0].id shutil.copytree(subtractions[0].path, cached_subtraction_path) + params = await get_mapping_index_cache_params( + "subtraction_mapping_index", + subtractions[0].id, + run_subprocess, + ) + key = derive_key(params) + assert await workflow_cache.put(key, cached_subtraction_path, params) await create_subtraction_index( - FakeWorkflowCache(cached_subtraction_path), + workflow_cache, logger, proc, - FakeRunSubprocess(), + 
run_subprocess, subtractions, subtraction_indexes_path, ) @@ -702,7 +914,7 @@ async def test_eliminate_subtraction( run_subprocess, subtraction_indexes_path, subtractions, - subtracted_path, + subtracted_bam_path, work_path, ) @@ -752,8 +964,8 @@ def parse_alignments(path: Path) -> set[tuple]: for read in alignment_file } - assert parse_alignments(work_path / "subtracted.bam") == snapshot(name="alignments") - assert parse_headers(work_path / "subtracted.bam") == snapshot(name="headers") + assert parse_alignments(subtracted_bam_path) == snapshot(name="alignments") + assert parse_headers(subtracted_bam_path) == snapshot(name="headers") async def test_pathoscope( @@ -763,10 +975,9 @@ async def test_pathoscope( mocker, ref_lengths, snapshot: SnapshotAssertion, + subtracted_bam_path: Path, work_path: Path, ): - subtracted_bam_path = work_path / "to_isolates.bam" - shutil.copyfile(example_path / "to_isolates.bam", subtracted_bam_path) intermediate = SimpleNamespace(lengths=ref_lengths) diff --git a/workflow.py b/workflow.py index 349d87a..21931ca 100644 --- a/workflow.py +++ b/workflow.py @@ -1,4 +1,5 @@ import asyncio +import json import os import shlex import shutil @@ -7,16 +8,20 @@ from types import SimpleNamespace from typing import Any +from virtool.caches.utils import derive_key from virtool.workflow import hooks, step from virtool.workflow.data.analyses import WFAnalysis -from virtool.workflow.data.cache import WorkflowCache +from virtool.workflow.data.cache import CacheHit, WorkflowCache from virtool.workflow.data.indexes import WFIndex from virtool.workflow.data.samples import WFSample from virtool.workflow.data.subtractions import WFSubtraction from virtool.workflow.runtime.run_subprocess import RunSubprocess from workflow_pathoscope.utils import ( + CD_HIT_EST_IDENTITY, build_bowtie2_index, + collapse_reference_json, create_mapping_index, + get_reference_collapse_cache_params, run_pathoscope, write_default_isolate_fasta, write_isolate_fasta, @@ -46,9 +51,67 
@@ async def delete_analysis_document(analysis: WFAnalysis): await analysis.delete() +@step +async def collapse_reference( + cache: WorkflowCache, + collapsed_reference_path: Path, + index: WFIndex, + logger, + proc: int, + run_subprocess: RunSubprocess, +) -> Path: + """Ensure a cd-hit-est collapsed reference JSON exists locally.""" + params = await get_reference_collapse_cache_params(index.id, run_subprocess) + key = derive_key(params) + collapsed_reference_dir = collapsed_reference_path.parent + log = logger.bind( + identity=CD_HIT_EST_IDENTITY, + index_kind="collapsed_reference", + key=key, + parent_id=index.id, + ) + + log.info("checking workflow cache") + + result = await cache.get(key, collapsed_reference_dir.parent) + + if isinstance(result, CacheHit): + log.info("restored cached collapsed reference", outcome="hit") + manifest_path = collapsed_reference_dir / "collapse-manifest.json" + + with open(manifest_path) as handle: + log.info("reference collapse restored", **json.load(handle)) + + return collapsed_reference_path + + log.info("collapsing reference", outcome="miss", source=str(index.json_path)) + + stats = await collapse_reference_json( + index.json_path, + collapsed_reference_path, + proc, + run_subprocess, + ) + + with open(collapsed_reference_dir / "collapse-manifest.json", "w") as handle: + json.dump(stats, handle) + + log.info("reference collapse complete", **stats) + + created = await cache.put(key, collapsed_reference_dir, params=params) + + if created: + log.info("cached collapsed reference", outcome="put") + else: + log.info("collapsed reference cache already exists", outcome="put_skipped") + + return collapsed_reference_path + + @step async def create_reference_index( cache: WorkflowCache, + collapsed_reference_path: Path, index: WFIndex, logger, proc: int, @@ -61,13 +124,13 @@ async def create_reference_index( await asyncio.to_thread( write_default_isolate_fasta, - index.json_path, + collapsed_reference_path, reference_fasta_path, ) 
logger.info( "assembled default reference fasta", - source=str(index.json_path), + source=str(collapsed_reference_path), ) await create_mapping_index( @@ -80,7 +143,8 @@ async def create_reference_index( index_prefix=reference_index_path, parent_id=index.id, extra_params={ - "source": "index_json", + "collapse_identity": CD_HIT_EST_IDENTITY, + "source": "collapsed_reference", "selection": "default_isolates", }, ) @@ -149,6 +213,7 @@ async def map_default_isolates( @step async def build_isolate_index( + collapsed_reference_path: Path, index: WFIndex, intermediate: SimpleNamespace, isolate_fasta_path: Path, @@ -161,7 +226,7 @@ async def build_isolate_index( intermediate.lengths = await asyncio.to_thread( write_isolate_fasta, {index.get_otu_id_by_sequence_id(id_) for id_ in intermediate.to_otus}, - index.json_path, + collapsed_reference_path, isolate_fasta_path, ) From 93859aad1bd1b6638353d17a074c8a9305d814ef Mon Sep 17 00:00:00 2001 From: ReeceHoffmann Date: Thu, 11 Jun 2026 12:28:46 -0700 Subject: [PATCH 2/6] test: fix reference asset fixture path --- tests/{assests => assets}/redundant_reference.json | 0 tests/test_workflow.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename tests/{assests => assets}/redundant_reference.json (100%) diff --git a/tests/assests/redundant_reference.json b/tests/assets/redundant_reference.json similarity index 100% rename from tests/assests/redundant_reference.json rename to tests/assets/redundant_reference.json diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 2aaaf6a..7fa7973 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -60,7 +60,7 @@ "rev.2.bt2", ) REDUNDANT_REFERENCE_JSON_PATH = ( - Path(__file__).parent / "assests" / "redundant_reference.json" + Path(__file__).parent / "assets" / "redundant_reference.json" ) TOOL_VERSION_PATTERN = re.compile(r"\d+(?:\.\d+)+(?:[-+._A-Za-z0-9]*)?") @@ -451,7 +451,7 @@ async def test_collapse_reference_json_outputs_collapsed_reference_fasta( 
isolate_fasta_path, ) == { "default-a": 20, - "default-b ": 20, + "default-b": 20, "representative-1-a": 20, "representative-1-b": 20, "representative-2-a": 20, From fce0c5a62f01e525735d0dcf94aec3e68d10ee33 Mon Sep 17 00:00:00 2001 From: ReeceHoffmann Date: Fri, 12 Jun 2026 08:43:19 -0700 Subject: [PATCH 3/6] perf: parallelize reference segment collapse - run cd-hit-est per segment with bounded concurrency - keep each cd-hit-est process single-threaded so proc controls segment-level parallelism --- python/workflow_pathoscope/utils.py | 75 ++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 23 deletions(-) diff --git a/python/workflow_pathoscope/utils.py b/python/workflow_pathoscope/utils.py index 2ee8360..e34e766 100755 --- a/python/workflow_pathoscope/utils.py +++ b/python/workflow_pathoscope/utils.py @@ -266,6 +266,35 @@ def _build_representative_set( ) +async def _collapse_reference_segment( + segment_input_path: Path, + segment_output_path: Path, + sequences: list[dict], + run_subprocess: RunSubprocess, +) -> dict[str, str]: + await asyncio.to_thread(_write_fasta, sequences, segment_input_path) + + await run_subprocess( + [ + CD_HIT_EST_TOOL, + "-i", + str(segment_input_path), + "-o", + str(segment_output_path), + "-c", + CD_HIT_EST_IDENTITY, + "-T", + "1", + "-M", + "0", + "-d", + "0", + ], + ) + + return _parse_cd_hit_clusters(segment_output_path.with_suffix(".cdhit.clstr")) + + async def collapse_reference_json( json_path: Path, target_path: Path, @@ -282,6 +311,20 @@ async def collapse_reference_json( with TemporaryDirectory(prefix="pathoscope-collapse-") as temp_dir: temp_path = Path(temp_dir) + semaphore = asyncio.Semaphore(proc) + + async def collapse_segment( + segment_input_path: Path, + segment_output_path: Path, + sequences: list[dict], + ) -> dict[str, str]: + async with semaphore: + return await _collapse_reference_segment( + segment_input_path, + segment_output_path, + sequences, + run_subprocess, + ) for otu in otus: 
sequences_by_segment = {} @@ -303,6 +346,7 @@ async def collapse_reference_json( sequences_by_segment.items(), key=lambda item: item[0], ) + segment_tasks = [] for segment_name, sequences in sorted_segment_sequences: segment_input_path = ( @@ -312,32 +356,17 @@ async def collapse_reference_json( temp_path / f"otu-{otu['_id']}-segment-{segment_name}.cdhit" ) - await asyncio.to_thread(_write_fasta, sequences, segment_input_path) - - await run_subprocess( - [ - CD_HIT_EST_TOOL, - "-i", - str(segment_input_path), - "-o", - str(segment_output_path), - "-c", - CD_HIT_EST_IDENTITY, - "-T", - str(proc), - "-M", - "0", - "-d", - "0", - ], - ) - - representatives_by_sequence_id.update( - _parse_cd_hit_clusters( - segment_output_path.with_suffix(".cdhit.clstr") + segment_tasks.append( + collapse_segment( + segment_input_path, + segment_output_path, + sequences, ) ) + for representatives in await asyncio.gather(*segment_tasks): + representatives_by_sequence_id.update(representatives) + default_sequence_ids = { sequence["_id"] for isolate in otu["isolates"] From 0f01a53075b1988c829daaf06fddcc50d7881376 Mon Sep 17 00:00:00 2001 From: ReeceHoffmann Date: Fri, 12 Jun 2026 13:01:11 -0700 Subject: [PATCH 4/6] feat: cache collapsed reference segments --- python/workflow_pathoscope/utils.py | 194 ++++++++++++++++------------ tests/test_workflow.py | 17 ++- 2 files changed, 124 insertions(+), 87 deletions(-) diff --git a/python/workflow_pathoscope/utils.py b/python/workflow_pathoscope/utils.py index e34e766..cdbc939 100755 --- a/python/workflow_pathoscope/utils.py +++ b/python/workflow_pathoscope/utils.py @@ -3,6 +3,7 @@ import gzip import json import re +from collections.abc import Awaitable, Callable from pathlib import Path from tempfile import TemporaryDirectory @@ -192,25 +193,20 @@ def _get_reference_otus(reference_data): return reference_data -def _get_otu_schema_segment_names(otu: dict) -> set[str]: - return {str(segment["name"]) for segment in otu["schema"]} - - -def 
_get_schema_sequence_segment_key( +def _validate_isolate_sequence_segments_match_schema( otu: dict, - sequence: dict, - valid_schema_segments: set[str], -) -> str: - segment_key = sequence["segment"] + isolate: dict, + schema_segments: set[str], +) -> None: + sequence_segments = {str(sequence["segment"]) for sequence in isolate["sequences"]} - if segment_key not in valid_schema_segments: + if sequence_segments != schema_segments: raise ValueError( - f"Sequence {sequence['_id']} uses segment {segment_key!r}, which is not " - f"defined in OTU {otu['_id']} schema" + f"Isolate {isolate['_id']} sequence segments " + f"{sorted(sequence_segments)!r} do not match OTU {otu['_id']} schema " + f"segments {sorted(schema_segments)!r}" ) - return segment_key - def _write_fasta(sequences: list[dict], path: Path) -> None: with open(path, "w") as handle: @@ -266,6 +262,22 @@ def _build_representative_set( ) +def _validate_and_group_otu_sequences_by_segment(otu: dict) -> dict[str, list[dict]]: + sequences_by_segment = {} + schema_segments = {str(segment["name"]) for segment in otu["schema"]} + + for isolate in otu["isolates"]: + _validate_isolate_sequence_segments_match_schema( + otu, + isolate, + schema_segments, + ) + for sequence in isolate["sequences"]: + sequences_by_segment.setdefault(sequence["segment"], []).append(sequence) + + return sequences_by_segment + + async def _collapse_reference_segment( segment_input_path: Path, segment_output_path: Path, @@ -295,6 +307,79 @@ async def _collapse_reference_segment( return _parse_cd_hit_clusters(segment_output_path.with_suffix(".cdhit.clstr")) +async def _collapse_otu_segments( + otu: dict, + temp_path: Path, + collapse_segment: Callable[[Path, Path, list[dict]], Awaitable[dict[str, str]]], +) -> dict[str, str]: + representatives_by_sequence_id = {} + sequences_by_segment = _validate_and_group_otu_sequences_by_segment(otu) + + segment_tasks = [] + sorted_segment_sequences = sorted( + sequences_by_segment.items(), + key=lambda item: 
item[0], + ) + + for segment_name, sequences in sorted_segment_sequences: + segment_input_path = temp_path / f"otu-{otu['_id']}-segment-{segment_name}.fa" + segment_output_path = ( + temp_path / f"otu-{otu['_id']}-segment-{segment_name}.cdhit" + ) + + segment_tasks.append( + collapse_segment( + segment_input_path, + segment_output_path, + sequences, + ) + ) + + for representatives in await asyncio.gather(*segment_tasks): + representatives_by_sequence_id.update(representatives) + + return representatives_by_sequence_id + + +async def _collapse_otu_reference( + otu: dict, + temp_path: Path, + collapse_segment: Callable[[Path, Path, list[dict]], Awaitable[dict[str, str]]], +) -> dict: + representatives_by_sequence_id = await _collapse_otu_segments( + otu, + temp_path, + collapse_segment, + ) + + default_sequence_ids = { + sequence["_id"] + for isolate in otu["isolates"] + if isolate["default"] + for sequence in isolate["sequences"] + } + seen_representative_sets = set() + collapsed_isolates = [] + + for isolate in otu["isolates"]: + representative_set = _build_representative_set( + isolate, + representatives_by_sequence_id, + ) + + first_for_set = representative_set not in seen_representative_sets + seen_representative_sets.add(representative_set) + + contains_default_sequence = any( + sequence["_id"] in default_sequence_ids for sequence in isolate["sequences"] + ) + + if isolate["default"] or contains_default_sequence or first_for_set: + collapsed_isolates.append(isolate) + + return {**otu, "isolates": collapsed_isolates} + + async def collapse_reference_json( json_path: Path, target_path: Path, @@ -326,80 +411,27 @@ async def collapse_segment( run_subprocess, ) - for otu in otus: - sequences_by_segment = {} - valid_schema_segments = _get_otu_schema_segment_names(otu) - before_count += len(otu["isolates"]) - for isolate in otu["isolates"]: - for sequence in isolate["sequences"]: - segment_key = _get_schema_sequence_segment_key( - otu, - sequence, - 
valid_schema_segments, - ) - - sequences_by_segment.setdefault(segment_key, []).append(sequence) + collapsed_otus = [] - representatives_by_sequence_id = {} - - sorted_segment_sequences = sorted( - sequences_by_segment.items(), - key=lambda item: item[0], + for otu in otus: + collapsed_otu = await _collapse_otu_reference( + otu, + temp_path, + collapse_segment, ) - segment_tasks = [] - - for segment_name, sequences in sorted_segment_sequences: - segment_input_path = ( - temp_path / f"otu-{otu['_id']}-segment-{segment_name}.fa" - ) - segment_output_path = ( - temp_path / f"otu-{otu['_id']}-segment-{segment_name}.cdhit" - ) - - segment_tasks.append( - collapse_segment( - segment_input_path, - segment_output_path, - sequences, - ) - ) - - for representatives in await asyncio.gather(*segment_tasks): - representatives_by_sequence_id.update(representatives) - - default_sequence_ids = { - sequence["_id"] - for isolate in otu["isolates"] - if isolate["default"] - for sequence in isolate["sequences"] - } - seen_representative_sets = set() - collapsed_isolates = [] - - for isolate in otu["isolates"]: - representative_set = _build_representative_set( - isolate, - representatives_by_sequence_id, - ) - - first_for_set = representative_set not in seen_representative_sets - seen_representative_sets.add(representative_set) - - contains_default_sequence = any( - sequence["_id"] in default_sequence_ids - for sequence in isolate["sequences"] - ) - - if isolate["default"] or contains_default_sequence or first_for_set: - collapsed_isolates.append(isolate) + collapsed_otus.append(collapsed_otu) + before_count += len(otu["isolates"]) + after_count += len(collapsed_otu["isolates"]) - otu["isolates"] = collapsed_isolates - after_count += len(collapsed_isolates) + if isinstance(reference_data, dict): + collapsed_reference_data = {**reference_data, "otus": collapsed_otus} + else: + collapsed_reference_data = collapsed_otus await asyncio.to_thread(target_path.parent.mkdir, parents=True, 
exist_ok=True) with open(target_path, "w") as handle: - json.dump(reference_data, handle) + json.dump(collapsed_reference_data, handle) return { "isolate_count_before": before_count, diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 7fa7973..1299430 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -409,7 +409,12 @@ async def test_collapse_reference_miss_retains_required_isolates( assert [ (isolate["sequences"][0]["_id"], isolate["sequences"][1]["_id"]) for isolate in collapsed_reference["otus"][0]["isolates"] - ] == [("default-a", "default-b"), ("representative-1-a", "representative-1-b"), ("representative-2-a", "representative-2-b"), ("unique-combo-a", "unique-combo-b")] + ] == [ + ("default-a", "default-b"), + ("representative-1-a", "representative-1-b"), + ("representative-2-a", "representative-2-b"), + ("unique-combo-a", "unique-combo-b"), + ] assert json.loads( (collapsed_reference_path.parent / "collapse-manifest.json").read_text() ) == { @@ -463,7 +468,7 @@ async def test_collapse_reference_json_outputs_collapsed_reference_fasta( assert "duplicate-b" not in isolate_fasta_path.read_text() -async def test_collapse_reference_json_rejects_sequences_outside_otu_schema( +async def test_collapse_reference_json_rejects_isolate_segments_that_do_not_match_schema( run_subprocess: RunSubprocess, tmp_path: Path, ): @@ -475,16 +480,16 @@ async def test_collapse_reference_json_rejects_sequences_outside_otu_schema( with open(source_path) as handle: reference_data = json.load(handle) - reference_data["otus"][0]["isolates"][0]["sequences"][0]["segment"] = "c" + reference_data["otus"][0]["isolates"][0]["sequences"][0]["segment"] = "b" with open(source_path, "w") as handle: json.dump(reference_data, handle) with pytest.raises( ValueError, - match=( - "Sequence default-a uses segment 'c', which is not defined in " - "OTU collapse-otu schema" + match=re.escape( + "Isolate default sequence segments ['b'] do not match OTU collapse-otu " + "schema 
segments ['a', 'b']" ), ): await collapse_reference_json( From 1de0d5de2acbce094626da4083a3a3aa51b83823 Mon Sep 17 00:00:00 2001 From: ReeceHoffmann Date: Fri, 12 Jun 2026 13:17:11 -0700 Subject: [PATCH 5/6] fix: use isolate id field in reference collapse --- python/workflow_pathoscope/utils.py | 2 +- tests/assets/redundant_reference.json | 10 +++++----- tests/test_workflow.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/python/workflow_pathoscope/utils.py b/python/workflow_pathoscope/utils.py index cdbc939..a618900 100755 --- a/python/workflow_pathoscope/utils.py +++ b/python/workflow_pathoscope/utils.py @@ -202,7 +202,7 @@ def _validate_isolate_sequence_segments_match_schema( if sequence_segments != schema_segments: raise ValueError( - f"Isolate {isolate['_id']} sequence segments " + f"Isolate {isolate['id']} sequence segments " f"{sorted(sequence_segments)!r} do not match OTU {otu['_id']} schema " f"segments {sorted(schema_segments)!r}" ) diff --git a/tests/assets/redundant_reference.json b/tests/assets/redundant_reference.json index df9b172..71d22ac 100644 --- a/tests/assets/redundant_reference.json +++ b/tests/assets/redundant_reference.json @@ -12,7 +12,7 @@ ], "isolates": [ { - "_id": "default", + "id": "default", "default": true, "sequences": [ { @@ -28,7 +28,7 @@ ] }, { - "_id": "representative-1", + "id": "representative-1", "default": false, "sequences": [ { @@ -44,7 +44,7 @@ ] }, { - "_id": "representative-2", + "id": "representative-2", "default": false, "sequences": [ { @@ -60,7 +60,7 @@ ] }, { - "_id": "unique-combo", + "id": "unique-combo", "default": false, "sequences": [ { @@ -76,7 +76,7 @@ ] }, { - "_id": "duplicate", + "id": "duplicate", "default": false, "sequences": [ { diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 1299430..eb99058 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -236,6 +236,7 @@ def write_reference_json(path: Path): "_id": "default-otu", "isolates": [ 
{ + "id": "default", "default": True, "sequences": [ {"_id": "default-a", "sequence": "ACGT"}, @@ -243,6 +244,7 @@ def write_reference_json(path: Path): ], }, { + "id": "non-default", "default": False, "sequences": [ {"_id": "non-default", "sequence": "GGGG"}, @@ -254,6 +256,7 @@ def write_reference_json(path: Path): "_id": "non-default-otu", "isolates": [ { + "id": "non-default-only", "default": False, "sequences": [ {"_id": "non-default-only", "sequence": "CCCC"}, @@ -406,6 +409,14 @@ async def test_collapse_reference_miss_retains_required_isolates( with open(collapsed_reference_path) as handle: collapsed_reference = json.load(handle) + assert [ + isolate["id"] for isolate in collapsed_reference["otus"][0]["isolates"] + ] == [ + "default", + "representative-1", + "representative-2", + "unique-combo", + ] assert [ (isolate["sequences"][0]["_id"], isolate["sequences"][1]["_id"]) for isolate in collapsed_reference["otus"][0]["isolates"] From d8214dca0e0293a97acb9f389ca2d845730806c5 Mon Sep 17 00:00:00 2001 From: ReeceHoffmann Date: Fri, 12 Jun 2026 13:41:07 -0700 Subject: [PATCH 6/6] fix: relax isolate segment validation - allow isolates to omit schema segments during reference collapse - permit empty segment sentinels for unsegmented OTUs --- python/workflow_pathoscope/utils.py | 9 ++- tests/test_workflow.py | 89 ++++++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/python/workflow_pathoscope/utils.py b/python/workflow_pathoscope/utils.py index a618900..518a66b 100755 --- a/python/workflow_pathoscope/utils.py +++ b/python/workflow_pathoscope/utils.py @@ -199,11 +199,16 @@ def _validate_isolate_sequence_segments_match_schema( schema_segments: set[str], ) -> None: sequence_segments = {str(sequence["segment"]) for sequence in isolate["sequences"]} + unknown_segments = ( + sequence_segments - schema_segments + if schema_segments + else {segment for segment in sequence_segments if segment} + ) - if sequence_segments != 
schema_segments: + if unknown_segments: raise ValueError( f"Isolate {isolate['id']} sequence segments " - f"{sorted(sequence_segments)!r} do not match OTU {otu['_id']} schema " + f"{sorted(unknown_segments)!r} are not defined in OTU {otu['_id']} schema " f"segments {sorted(schema_segments)!r}" ) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index eb99058..795830b 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -479,6 +479,91 @@ async def test_collapse_reference_json_outputs_collapsed_reference_fasta( assert "duplicate-b" not in isolate_fasta_path.read_text() +async def test_collapse_reference_json_allows_isolates_missing_schema_segments( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + + write_redundant_reference_json(source_path) + + with open(source_path) as handle: + reference_data = json.load(handle) + + reference_data["otus"][0]["isolates"][0]["sequences"] = [ + reference_data["otus"][0]["isolates"][0]["sequences"][1], + ] + + with open(source_path, "w") as handle: + json.dump(reference_data, handle) + + assert await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + with open(collapsed_path) as handle: + collapsed_reference = json.load(handle) + + assert collapsed_reference["otus"][0]["isolates"][0]["sequences"] == [ + { + "_id": "default-b", + "segment": "b", + "sequence": "TGCATGCATGCATGCATGCA", + }, + ] + + +async def test_collapse_reference_json_allows_unsegmented_isolates( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + + write_redundant_reference_json(source_path) + + with open(source_path) as handle: + reference_data = json.load(handle) + + 
reference_data["otus"][0]["schema"] = [] + + for isolate in reference_data["otus"][0]["isolates"]: + for sequence in isolate["sequences"]: + sequence["segment"] = "" + + with open(source_path, "w") as handle: + json.dump(reference_data, handle) + + assert await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + with open(collapsed_path) as handle: + collapsed_reference = json.load(handle) + + assert collapsed_reference["otus"][0]["schema"] == [] + assert { + sequence["segment"] + for isolate in collapsed_reference["otus"][0]["isolates"] + for sequence in isolate["sequences"] + } == {""} + + async def test_collapse_reference_json_rejects_isolate_segments_that_do_not_match_schema( run_subprocess: RunSubprocess, tmp_path: Path, @@ -491,7 +576,7 @@ async def test_collapse_reference_json_rejects_isolate_segments_that_do_not_matc with open(source_path) as handle: reference_data = json.load(handle) - reference_data["otus"][0]["isolates"][0]["sequences"][0]["segment"] = "b" + reference_data["otus"][0]["isolates"][0]["sequences"][0]["segment"] = "c" with open(source_path, "w") as handle: json.dump(reference_data, handle) @@ -499,7 +584,7 @@ async def test_collapse_reference_json_rejects_isolate_segments_that_do_not_matc with pytest.raises( ValueError, match=re.escape( - "Isolate default sequence segments ['b'] do not match OTU collapse-otu " + "Isolate default sequence segments ['c'] are not defined in OTU collapse-otu " "schema segments ['a', 'b']" ), ):