diff --git a/Dockerfile b/Dockerfile index 5cd9cd9..f08955e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,8 +1,9 @@ FROM python:3.13-bookworm AS deps WORKDIR /app -COPY --from=ghcr.io/virtool/tools:1.1.0 /tools/bowtie2/2.5.4/bowtie* /usr/local/bin/ -COPY --from=ghcr.io/virtool/tools:1.1.0 /tools/pigz/2.8/pigz /usr/local/bin/ -COPY --from=ghcr.io/virtool/tools:1.1.0 /tools/samtools/1.22.1/bin/samtools /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/bowtie2/2.5.4/bowtie* /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/cd-hit/4.8.1/cd-hit-est /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/pigz/2.8/pigz /usr/local/bin/ +COPY --from=ghcr.io/virtool/tools:1.2.0 /tools/samtools/1.22.1/bin/samtools /usr/local/bin/ FROM python:3.13-bookworm AS uv WORKDIR /app diff --git a/fixtures.py b/fixtures.py index e77948d..9e2e9a5 100644 --- a/fixtures.py +++ b/fixtures.py @@ -4,6 +4,42 @@ from pyfixtures import fixture +def get_reference_index_path(work_path: Path) -> Path: + return work_path / "reference_index" / "reference" + + +def get_collapsed_reference_path(work_path: Path) -> Path: + return work_path / "collapsed_reference" / "reference.json" + + +def get_subtraction_indexes_path(work_path: Path) -> Path: + return work_path / "subtraction_indexes" + + +def get_isolate_path(work_path: Path) -> Path: + return work_path / "isolates" + + +def get_isolate_fasta_path(isolate_path: Path) -> Path: + return isolate_path / "isolate_index.fa" + + +def get_isolate_fastq_path(isolate_path: Path) -> Path: + return isolate_path / "isolate_mapped.fq" + + +def get_isolate_index_path(isolate_path: Path) -> Path: + return isolate_path / "isolates" + + +def get_isolate_bam_path(isolate_path: Path) -> Path: + return isolate_path / "to_isolates.bam" + + +def get_subtracted_bam_path(work_path: Path) -> Path: + return work_path / "subtracted.bam" + + @fixture def intermediate(): """A namespace for storing intermediate values.""" @@ -14,7 +50,7 @@ def intermediate(): @fixture def isolate_path(work_path: Path): - path = work_path / "isolates" + path = get_isolate_path(work_path) path.mkdir() return path @@ -22,32 +58,37 @@ def isolate_path(work_path: Path): @fixture def reference_index_path(work_path: Path): - return work_path / "reference_index" / "reference" + return get_reference_index_path(work_path) + + +@fixture +def collapsed_reference_path(work_path: Path): + return get_collapsed_reference_path(work_path) @fixture def subtraction_indexes_path(work_path: Path) -> Path: - return work_path / "subtraction_indexes" + return get_subtraction_indexes_path(work_path) @fixture def isolate_fasta_path(isolate_path: Path): - return isolate_path / "isolate_index.fa" + return get_isolate_fasta_path(isolate_path) @fixture def isolate_fastq_path(isolate_path: Path): - return isolate_path / "isolate_mapped.fq" + return get_isolate_fastq_path(isolate_path) @fixture def isolate_index_path(isolate_path: Path): - return isolate_path / "isolates" + return get_isolate_index_path(isolate_path) @fixture def isolate_bam_path(isolate_path: Path): - return isolate_path / "to_isolates.bam" + return get_isolate_bam_path(isolate_path) @fixture @@ -58,4 +99,4 @@ def p_score_cutoff(): @fixture def subtracted_bam_path(work_path: Path) -> Path: """The path to the BAM file after subtraction reads have been eliminated.""" - return work_path / "subtracted.bam" + return get_subtracted_bam_path(work_path) diff --git a/python/workflow_pathoscope/utils.py b/python/workflow_pathoscope/utils.py index fc0eea4..518a66b 100755 --- a/python/workflow_pathoscope/utils.py +++ b/python/workflow_pathoscope/utils.py @@ -3,10 +3,13 @@ import gzip import json import re +from collections.abc import Awaitable, Callable from pathlib import Path +from tempfile import TemporaryDirectory from virtool.caches.utils import derive_key from virtool.workflow.data.cache import CacheHit, WorkflowCache +from virtool.workflow.errors import SubprocessFailedError from virtool.workflow.runtime.run_subprocess import RunSubprocess from virtool.workflow.utils import get_workflow_version @@ -14,6 +17,8 @@ BOWTIE2_BUILD_TOOL = "bowtie2-build" +CD_HIT_EST_TOOL = "cd-hit-est" +CD_HIT_EST_IDENTITY = "0.99" WORKFLOW_NAME = "pathoscope" WORKFLOW_VERSION = get_workflow_version() @@ -37,6 +42,48 @@ async def collect_stdout(line: bytes) -> None: return match.group(1) +async def get_cd_hit_est_version(run_subprocess: RunSubprocess) -> str: + output = [] + + async def collect_output(line: bytes) -> None: + output.append(line.decode()) + + try: + await run_subprocess( + [CD_HIT_EST_TOOL, "-h"], + stderr_handler=collect_output, + stdout_handler=collect_output, + ) + except SubprocessFailedError: + # cd-hit-est -h prints help/version text and exits with code 1. + pass + + match = re.search(r"\bCD-HIT\s+version\s+([^\s]+)", "".join(output)) + + if match is None: + raise ValueError("Could not parse cd-hit-est version") + + return match.group(1) + + +async def get_reference_collapse_cache_params( + parent_id: str, + run_subprocess: RunSubprocess, +) -> dict[str, str]: + tool_version = await get_cd_hit_est_version(run_subprocess) + + return { + "index_kind": "collapsed_reference", + "workflow": WORKFLOW_NAME, + "workflow_version": WORKFLOW_VERSION, + "parent_id": parent_id, + "source": "index_json", + "tool_name": CD_HIT_EST_TOOL, + "tool_version": tool_version, + "identity": CD_HIT_EST_IDENTITY, + } + + async def get_mapping_index_cache_params( index_kind: str, parent_id: str, @@ -146,6 +193,258 @@ def _get_reference_otus(reference_data): return reference_data +def _validate_isolate_sequence_segments_match_schema( + otu: dict, + isolate: dict, + schema_segments: set[str], +) -> None: + sequence_segments = {str(sequence["segment"]) for sequence in isolate["sequences"]} + unknown_segments = ( + sequence_segments - schema_segments + if schema_segments + else {segment for segment in sequence_segments if segment} + ) + + if unknown_segments: + raise ValueError( + f"Isolate {isolate['id']} sequence segments " + f"{sorted(unknown_segments)!r} are not defined in OTU {otu['_id']} schema " + f"segments {sorted(schema_segments)!r}" + ) + + +def _write_fasta(sequences: list[dict], path: Path) -> None: + with open(path, "w") as handle: + for sequence in sequences: + handle.write(f">{sequence['_id']}\n{sequence['sequence']}\n") + + +def _parse_cd_hit_clusters(cluster_path: Path) -> dict[str, str]: + representatives_by_sequence_id = {} + cluster_sequence_ids = [] + representative_id = None + + def flush_cluster() -> None: + if representative_id is None: + return + + for sequence_id in cluster_sequence_ids: + representatives_by_sequence_id[sequence_id] = representative_id + + with open(cluster_path) as handle: + for line in handle: + line = line.strip() + + if line.startswith(">Cluster"): + flush_cluster() + cluster_sequence_ids = [] + representative_id = None + continue + + match = re.search(r">(.+?)\.\.\.", line) + + if match is None: + continue + + sequence_id = match.group(1) + cluster_sequence_ids.append(sequence_id) + + if line.endswith("*"): + representative_id = sequence_id + + flush_cluster() + + return representatives_by_sequence_id + + +def _build_representative_set( + isolate: dict, + representatives_by_sequence_id: dict[str, str], +) -> frozenset[str]: + return frozenset( + representatives_by_sequence_id[sequence["_id"]] + for sequence in isolate["sequences"] + ) + + +def _validate_and_group_otu_sequences_by_segment(otu: dict) -> dict[str, list[dict]]: + sequences_by_segment = {} + schema_segments = {str(segment["name"]) for segment in otu["schema"]} + + for isolate in otu["isolates"]: + _validate_isolate_sequence_segments_match_schema( + otu, + isolate, + schema_segments, + ) + for sequence in isolate["sequences"]: + sequences_by_segment.setdefault(sequence["segment"], []).append(sequence) + + return sequences_by_segment + + +async def _collapse_reference_segment( + segment_input_path: Path, + segment_output_path: Path, + sequences: list[dict], + run_subprocess: RunSubprocess, +) -> dict[str, str]: + await asyncio.to_thread(_write_fasta, sequences, segment_input_path) + + await run_subprocess( + [ + CD_HIT_EST_TOOL, + "-i", + str(segment_input_path), + "-o", + str(segment_output_path), + "-c", + CD_HIT_EST_IDENTITY, + "-T", + "1", + "-M", + "0", + "-d", + "0", + ], + ) + + return _parse_cd_hit_clusters(segment_output_path.with_suffix(".cdhit.clstr")) + + +async def _collapse_otu_segments( + otu: dict, + temp_path: Path, + collapse_segment: Callable[[Path, Path, list[dict]], Awaitable[dict[str, str]]], +) -> dict[str, str]: + representatives_by_sequence_id = {} + sequences_by_segment = _validate_and_group_otu_sequences_by_segment(otu) + + segment_tasks = [] + sorted_segment_sequences = sorted( + sequences_by_segment.items(), + key=lambda item: item[0], + ) + + for segment_name, sequences in sorted_segment_sequences: + segment_input_path = temp_path / f"otu-{otu['_id']}-segment-{segment_name}.fa" + segment_output_path = ( + temp_path / f"otu-{otu['_id']}-segment-{segment_name}.cdhit" + ) + + segment_tasks.append( + collapse_segment( + segment_input_path, + segment_output_path, + sequences, + ) + ) + + for representatives in await asyncio.gather(*segment_tasks): + representatives_by_sequence_id.update(representatives) + + return representatives_by_sequence_id + + +async def _collapse_otu_reference( + otu: dict, + temp_path: Path, + collapse_segment: Callable[[Path, Path, list[dict]], Awaitable[dict[str, str]]], +) -> dict: + representatives_by_sequence_id = await _collapse_otu_segments( + otu, + temp_path, + collapse_segment, + ) + + default_sequence_ids = { + sequence["_id"] + for isolate in otu["isolates"] + if isolate["default"] + for sequence in isolate["sequences"] + } + seen_representative_sets = set() + collapsed_isolates = [] + + for isolate in otu["isolates"]: + representative_set = _build_representative_set( + isolate, + representatives_by_sequence_id, + ) + + first_for_set = representative_set not in seen_representative_sets + seen_representative_sets.add(representative_set) + + contains_default_sequence = any( + sequence["_id"] in default_sequence_ids for sequence in isolate["sequences"] + ) + + if isolate["default"] or contains_default_sequence or first_for_set: + collapsed_isolates.append(isolate) + + return {**otu, "isolates": collapsed_isolates} + + +async def collapse_reference_json( + json_path: Path, + target_path: Path, + proc: int, + run_subprocess: RunSubprocess, +) -> dict[str, int]: + """Collapse redundant isolates in a reference JSON using cd-hit-est clusters.""" + with _open_json(json_path) as handle: + reference_data = json.load(handle) + + otus = _get_reference_otus(reference_data) + before_count = 0 + after_count = 0 + + with TemporaryDirectory(prefix="pathoscope-collapse-") as temp_dir: + temp_path = Path(temp_dir) + semaphore = asyncio.Semaphore(proc) + + async def collapse_segment( + segment_input_path: Path, + segment_output_path: Path, + sequences: list[dict], + ) -> dict[str, str]: + async with semaphore: + return await _collapse_reference_segment( + segment_input_path, + segment_output_path, + sequences, + run_subprocess, + ) + + collapsed_otus = [] + + for otu in otus: + collapsed_otu = await _collapse_otu_reference( + otu, + temp_path, + collapse_segment, + ) + collapsed_otus.append(collapsed_otu) + before_count += len(otu["isolates"]) + after_count += len(collapsed_otu["isolates"]) + + if isinstance(reference_data, dict): + collapsed_reference_data = {**reference_data, "otus": collapsed_otus} + else: + collapsed_reference_data = collapsed_otus + + await asyncio.to_thread(target_path.parent.mkdir, parents=True, exist_ok=True) + + with open(target_path, "w") as handle: + json.dump(collapsed_reference_data, handle) + + return { + "isolate_count_before": before_count, + "isolate_count_after": after_count, + "isolate_count_removed": before_count - after_count, + } + + def write_default_isolate_fasta( json_path: Path, target_path: Path, diff --git a/tests/assets/redundant_reference.json b/tests/assets/redundant_reference.json new file mode 100644 index 0000000..71d22ac --- /dev/null +++ b/tests/assets/redundant_reference.json @@ -0,0 +1,97 @@ +{ + "otus": [ + { + "_id": "collapse-otu", + "schema": [ + { + "name": "a" + }, + { + "name": "b" + } + ], + "isolates": [ + { + "id": "default", + "default": true, + "sequences": [ + { + "_id": "default-a", + "segment": "a", + "sequence": "ACGTACGTACGTACGTACGT" + }, + { + "_id": "default-b", + "segment": "b", + "sequence": "TGCATGCATGCATGCATGCA" + } + ] + }, + { + "id": "representative-1", + "default": false, + "sequences": [ + { + "_id": "representative-1-a", + "segment": "a", + "sequence": "ACGTACGTACGTACGTACGT" + }, + { + "_id": "representative-1-b", + "segment": "b", + "sequence": "TGCAGGATCGTTTAACGTAG" + } + ] + }, + { + "id": "representative-2", + "default": false, + "sequences": [ + { + "_id": "representative-2-a", + "segment": "a", + "sequence": "CCCCAAAAGGGGTTTTCCCC" + }, + { + "_id": "representative-2-b", + "segment": "b", + "sequence": "AAAACCCCGGGGTTTTAAAA" + } + ] + }, + { + "id": "unique-combo", + "default": false, + "sequences": [ + { + "_id": "unique-combo-a", + "segment": "a", + "sequence": "CCCCAAAAGGGGTTTTCCCC" + }, + { + "_id": "unique-combo-b", + "segment": "b", + "sequence": "TGCAGGATCGTTTAACGTAG" + } + ] + }, + { + "id": "duplicate", + "default": false, + "sequences": [ + { + "_id": "duplicate-a", + "segment": "a", + "sequence": "CCCCAAAAGGGGTTTTCCCC" + }, + { + "_id": "duplicate-b", + "segment": "b", + "sequence": "AAAACCCCGGGGTTTTAAAA" + } + ] + } + ] + } + ] +} diff --git a/tests/test_workflow.py b/tests/test_workflow.py index bb91e2d..795830b 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -1,5 +1,7 @@ +import asyncio import gzip import json +import re import shutil import pysam @@ -12,13 +14,25 @@ from virtool.caches.utils import derive_key from virtool.workflow import RunSubprocess from virtool.workflow.data.analyses import WFAnalysis -from virtool.workflow.data.cache import CacheHit, CacheMiss +from virtool.workflow.data.cache import WorkflowCache from virtool.workflow.data.indexes import WFIndex from virtool.workflow.data.samples import WFSample from virtool.workflow.data.subtractions import WFSubtraction +from virtool.workflow.errors import JobsAPINotFoundError from virtool.workflow.pytest_plugin import WorkflowData +from fixtures import ( + get_collapsed_reference_path, + get_isolate_bam_path, + get_isolate_fastq_path, + get_isolate_index_path, + get_isolate_path, + get_reference_index_path, + get_subtracted_bam_path, + get_subtraction_indexes_path, +) from workflow import ( + collapse_reference, create_reference_index, create_subtraction_index, eliminate_subtraction, @@ -28,8 +42,12 @@ reassignment, ) from workflow_pathoscope.utils import ( + CD_HIT_EST_IDENTITY, + collapse_reference_json, get_mapping_index_cache_params, + get_reference_collapse_cache_params, write_default_isolate_fasta, + write_isolate_fasta, ) @@ -41,6 +59,10 @@ "rev.1.bt2", "rev.2.bt2", ) +REDUNDANT_REFERENCE_JSON_PATH = ( + Path(__file__).parent / "assets" / "redundant_reference.json" +) +TOOL_VERSION_PATTERN = re.compile(r"\d+(?:\.\d+)+(?:[-+._A-Za-z0-9]*)?") @pytest.fixture() @@ -53,12 +75,17 @@ def work_path(tmpdir): @pytest.fixture() def reference_index_path(work_path: Path) -> Path: - return work_path / "reference_index" / "reference" + return get_reference_index_path(work_path) + + +@pytest.fixture() +def collapsed_reference_path(work_path: Path) -> Path: + return get_collapsed_reference_path(work_path) @pytest.fixture() def subtraction_indexes_path(work_path: Path) -> Path: - return work_path / "subtraction_indexes" + return get_subtraction_indexes_path(work_path) @pytest.fixture() @@ -69,71 +96,85 @@ def subtraction_index_path( return get_subtraction_index_path(subtraction_indexes_path, subtractions[0].id) -class FakeWorkflowCache: - def __init__( - self, - hit_source: Path | None = None, - put_exception: Exception | None = None, - put_created: bool = True, - ): - self.hit_source = hit_source - self.put_exception = put_exception - self.put_created = put_created - self.gets = [] - self.puts = [] +@pytest.fixture() +def isolate_path(work_path: Path) -> Path: + path = get_isolate_path(work_path) + path.mkdir() + + return path - async def get(self, key: str, target: Path): - self.gets.append((key, target)) - if self.hit_source is None: - return CacheMiss(key) +@pytest.fixture() +def isolate_fastq_path(isolate_path: Path) -> Path: + return get_isolate_fastq_path(isolate_path) - target.mkdir(parents=True, exist_ok=True) - restored_path = target / self.hit_source.name - shutil.copytree(self.hit_source, restored_path) +@pytest.fixture() +def isolate_index_path(isolate_path: Path) -> Path: + return get_isolate_index_path(isolate_path) - return CacheHit(key, restored_path) - async def put(self, key: str, source: Path, params: dict | None = None): - self.puts.append((key, source, params)) +@pytest.fixture() +def isolate_bam_path(isolate_path: Path) -> Path: + return get_isolate_bam_path(isolate_path) - if self.put_exception is not None: - raise self.put_exception - return self.put_created +@pytest.fixture() +def subtracted_bam_path(work_path: Path) -> Path: + return get_subtracted_bam_path(work_path) -class FakeRunSubprocess: - def __init__(self): - self.commands = [] +class _FakeWorkflowCacheAPI: + """Fake only the API calls used by the real workflow cache.""" - async def __call__( + def __init__( self, - command: list[str], - cwd: str | Path | None = None, - env: dict | None = None, - stderr_handler=None, - stdout_handler=None, - ): - self.commands.append(command) + work_dir: Path, + *, + put_exception: Exception | None = None, + put_created: bool | None = None, + ) -> None: + self.work_dir = work_dir + self.put_exception = put_exception + self.put_created = put_created + self.stored: dict[str, tuple[Path, dict | None]] = {} - if command == ["bowtie2-build", "--version"]: - await stdout_handler(b"/usr/bin/bowtie2-build-s version 2.5.4\n") - return SimpleNamespace(returncode=0) + async def get_cache(self, key: str, dest: Path) -> None: + try: + source, _ = self.stored[key] + except KeyError: + raise JobsAPINotFoundError from None - if command[0] == "bowtie2-build": - prefix = Path(command[-1]) - prefix.parent.mkdir(parents=True, exist_ok=True) + dest.parent.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(shutil.copyfile, source, dest) - for suffix in BOWTIE2_INDEX_SUFFIXES: - (prefix.parent / f"{prefix.name}.{suffix}").write_bytes( - f"{prefix.name}.{suffix}".encode(), - ) + async def put_cache( + self, + key: str, + path: Path, + params: dict | None = None, + ) -> bool: + if self.put_exception: + raise self.put_exception - return SimpleNamespace(returncode=0) + if key in self.stored: + return False - raise AssertionError(f"Unexpected subprocess command: {command}") + stored_path = self.work_dir / "stored" / key / "cache.tar" + stored_path.parent.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread(shutil.copyfile, path, stored_path) + self.stored[key] = (stored_path, params) + + if self.put_created is not None: + return self.put_created + + return True + + +@pytest.fixture() +def workflow_cache(tmp_path: Path) -> WorkflowCache: + api = _FakeWorkflowCacheAPI(tmp_path / "fake_workflow_cache_api") + return WorkflowCache(api, tmp_path / "workflow_cache") def read_directory_bytes(path: Path) -> dict[str, bytes]: @@ -162,6 +203,26 @@ def write_bowtie2_bundle( (path / name).write_bytes(file_content) +def assert_bowtie2_index_exists(prefix: Path): + assert read_directory_bytes(prefix.parent).keys() == { + f"{prefix.name}.{suffix}" for suffix in BOWTIE2_INDEX_SUFFIXES + } + + for suffix in BOWTIE2_INDEX_SUFFIXES: + assert (prefix.parent / f"{prefix.name}.{suffix}").stat().st_size > 0 + + +def assert_cache_params( + params: dict[str, str], + expected: dict[str, str], +) -> None: + assert params.keys() == {*expected.keys(), "tool_version"} + assert { + key: value for key, value in params.items() if key != "tool_version" + } == expected + assert TOOL_VERSION_PATTERN.fullmatch(params["tool_version"]) + + def write_reference_json(path: Path): path.parent.mkdir(parents=True, exist_ok=True) @@ -175,6 +236,7 @@ def write_reference_json(path: Path): "_id": "default-otu", "isolates": [ { + "id": "default", "default": True, "sequences": [ {"_id": "default-a", "sequence": "ACGT"}, @@ -182,6 +244,7 @@ def write_reference_json(path: Path): ], }, { + "id": "non-default", "default": False, "sequences": [ {"_id": "non-default", "sequence": "GGGG"}, @@ -193,6 +256,7 @@ def write_reference_json(path: Path): "_id": "non-default-otu", "isolates": [ { + "id": "non-default-only", "default": False, "sequences": [ {"_id": "non-default-only", "sequence": "CCCC"}, @@ -206,6 +270,12 @@ def write_reference_json(path: Path): ) +def write_redundant_reference_json(path: Path): + path.parent.mkdir(parents=True, exist_ok=True) + + shutil.copyfile(REDUNDANT_REFERENCE_JSON_PATH, path) + + @pytest.fixture() def analysis(workflow_data: WorkflowData, mocker): analysis_ = mocker.Mock(WFAnalysis) @@ -262,6 +332,270 @@ def index(workflow_data: WorkflowData, example_path: Path, work_path: Path): ) +async def test_collapse_reference_hit( + collapsed_reference_path: Path, + index: WFIndex, + run_subprocess: RunSubprocess, + tmp_path: Path, + workflow_cache: WorkflowCache, +): + source = tmp_path / collapsed_reference_path.parent.name + source.mkdir() + write_reference_json(source / collapsed_reference_path.name) + (source / "collapse-manifest.json").write_text( + json.dumps( + { + "isolate_count_before": 4, + "isolate_count_after": 3, + "isolate_count_removed": 1, + } + ) + ) + params = await get_reference_collapse_cache_params(index.id, run_subprocess) + key = derive_key(params) + assert await workflow_cache.put(key, source, params) + + logger = get_logger("test") + + result_path = await collapse_reference( + workflow_cache, + collapsed_reference_path, + index, + logger, + 4, + run_subprocess, + ) + + assert result_path == collapsed_reference_path + assert read_directory_bytes( + collapsed_reference_path.parent + ) == read_directory_bytes(source) + + +async def test_collapse_reference_miss_retains_required_isolates( + collapsed_reference_path: Path, + index: WFIndex, + run_subprocess: RunSubprocess, + workflow_cache: WorkflowCache, +): + write_redundant_reference_json(index.json_path) + logger = get_logger("test") + + result_path = await collapse_reference( + workflow_cache, + collapsed_reference_path, + index, + logger, + 4, + run_subprocess, + ) + + params = await get_reference_collapse_cache_params(index.id, run_subprocess) + + assert result_path == collapsed_reference_path + assert_cache_params( + params, + { + "index_kind": "collapsed_reference", + "workflow": "pathoscope", + "workflow_version": "UNKNOWN", + "parent_id": index.id, + "source": "index_json", + "tool_name": "cd-hit-est", + "identity": "0.99", + }, + ) + + with open(collapsed_reference_path) as handle: + collapsed_reference = json.load(handle) + + assert [ + isolate["id"] for isolate in collapsed_reference["otus"][0]["isolates"] + ] == [ + "default", + "representative-1", + "representative-2", + "unique-combo", + ] + assert [ + (isolate["sequences"][0]["_id"], isolate["sequences"][1]["_id"]) + for isolate in collapsed_reference["otus"][0]["isolates"] + ] == [ + ("default-a", "default-b"), + ("representative-1-a", "representative-1-b"), + ("representative-2-a", "representative-2-b"), + ("unique-combo-a", "unique-combo-b"), + ] + assert json.loads( + (collapsed_reference_path.parent / "collapse-manifest.json").read_text() + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + +async def test_collapse_reference_json_outputs_collapsed_reference_fasta( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + default_fasta_path = tmp_path / "default.fa" + isolate_fasta_path = tmp_path / "isolates.fa" + + write_redundant_reference_json(source_path) + + assert await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + assert write_default_isolate_fasta(collapsed_path, default_fasta_path) == { + "default-a": 20, + "default-b": 20, + } + assert write_isolate_fasta( + {"collapse-otu"}, + collapsed_path, + isolate_fasta_path, + ) == { + "default-a": 20, + "default-b": 20, + "representative-1-a": 20, + "representative-1-b": 20, + "representative-2-a": 20, + "representative-2-b": 20, + "unique-combo-a": 20, + "unique-combo-b": 20, + } + assert "duplicate-a" not in isolate_fasta_path.read_text() + assert "duplicate-b" not in isolate_fasta_path.read_text() + + +async def test_collapse_reference_json_allows_isolates_missing_schema_segments( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + + write_redundant_reference_json(source_path) + + with open(source_path) as handle: + reference_data = json.load(handle) + + reference_data["otus"][0]["isolates"][0]["sequences"] = [ + reference_data["otus"][0]["isolates"][0]["sequences"][1], + ] + + with open(source_path, "w") as handle: + json.dump(reference_data, handle) + + assert await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + with open(collapsed_path) as handle: + collapsed_reference = json.load(handle) + + assert collapsed_reference["otus"][0]["isolates"][0]["sequences"] == [ + { + "_id": "default-b", + "segment": "b", + "sequence": "TGCATGCATGCATGCATGCA", + }, + ] + + +async def test_collapse_reference_json_allows_unsegmented_isolates( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + + write_redundant_reference_json(source_path) + + with open(source_path) as handle: + reference_data = json.load(handle) + + reference_data["otus"][0]["schema"] = [] + + for isolate in reference_data["otus"][0]["isolates"]: + for sequence in isolate["sequences"]: + sequence["segment"] = "" + + with open(source_path, "w") as handle: + json.dump(reference_data, handle) + + assert await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) == { + "isolate_count_before": 5, + "isolate_count_after": 4, + "isolate_count_removed": 1, + } + + with open(collapsed_path) as handle: + collapsed_reference = json.load(handle) + + assert collapsed_reference["otus"][0]["schema"] == [] + assert { + sequence["segment"] + for isolate in collapsed_reference["otus"][0]["isolates"] + for sequence in isolate["sequences"] + } == {""} + + +async def test_collapse_reference_json_rejects_isolate_segments_that_do_not_match_schema( + run_subprocess: RunSubprocess, + tmp_path: Path, +): + source_path = tmp_path / "reference.json" + collapsed_path = tmp_path / "collapsed" / "reference.json" + + write_redundant_reference_json(source_path) + + with open(source_path) as handle: + reference_data = json.load(handle) + + reference_data["otus"][0]["isolates"][0]["sequences"][0]["segment"] = "c" + + with open(source_path, "w") as handle: + json.dump(reference_data, handle) + + with pytest.raises( + ValueError, + match=re.escape( + "Isolate default sequence segments ['c'] are not defined in OTU collapse-otu " + "schema segments ['a', 'b']" + ), + ): + await collapse_reference_json( + source_path, + collapsed_path, + 2, + run_subprocess, + ) + + @pytest.fixture() def sample(workflow_data: WorkflowData, example_path: Path, work_path: Path): workflow_data.sample.library_type = "normal" @@ -303,11 +637,14 @@ def subtractions(workflow_data: WorkflowData, example_path: Path, work_path: Pat async def test_create_reference_index_hit( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, tmp_path: Path, + workflow_cache: WorkflowCache, ): - write_reference_json(index.json_path) + write_reference_json(collapsed_reference_path) source = tmp_path / reference_index_path.parent.name write_bowtie2_bundle( source, @@ -315,12 +652,24 @@ async def test_create_reference_index_hit( b"cached-reference", {"cache-manifest.json": b"cached-manifest"}, ) - cache = FakeWorkflowCache(source) - run_subprocess = FakeRunSubprocess() + params = await get_mapping_index_cache_params( + "reference_mapping_index", + index.id, + run_subprocess, + { + "collapse_identity": CD_HIT_EST_IDENTITY, + "source": "collapsed_reference", + "selection": "default_isolates", + }, + ) + key = derive_key(params) + assert await workflow_cache.put(key, source, params) + logger = get_logger("test") result_index_path = await create_reference_index( - cache, + workflow_cache, + collapsed_reference_path, index, logger, 4, @@ -328,41 +677,25 @@ async def test_create_reference_index_hit( reference_index_path, ) - params = await get_mapping_index_cache_params( - "reference_mapping_index", - index.id, - FakeRunSubprocess(), - { - "source": "index_json", - "selection": "default_isolates", - }, - ) - - assert cache.gets == [ - ( - derive_key(params), - reference_index_path.parent.parent, - ), - ] - assert cache.puts == [] assert result_index_path == reference_index_path - assert run_subprocess.commands == [["bowtie2-build", "--version"]] assert read_directory_bytes(reference_index_path.parent) == read_directory_bytes( source ) async def test_create_reference_index_miss( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, + workflow_cache: WorkflowCache, ): - write_reference_json(index.json_path) - cache = FakeWorkflowCache() - run_subprocess = FakeRunSubprocess() + write_reference_json(collapsed_reference_path) logger = get_logger("test") result_index_path = await create_reference_index( - cache, + workflow_cache, + collapsed_reference_path, index, logger, 4, @@ -373,51 +706,50 @@ async def test_create_reference_index_miss( params = await get_mapping_index_cache_params( "reference_mapping_index", index.id, - FakeRunSubprocess(), + run_subprocess, { - "source": "index_json", + "collapse_identity": CD_HIT_EST_IDENTITY, + "source": "collapsed_reference", "selection": "default_isolates", }, ) - key = derive_key(params) - assert cache.gets == [(key, reference_index_path.parent.parent)] - assert cache.puts == [(key, reference_index_path.parent, params)] assert result_index_path == reference_index_path - assert params == { - "index_kind": "reference_mapping_index", - "workflow": "pathoscope", - "workflow_version": "UNKNOWN", - "parent_id": index.id, - "source": "index_json", - "selection": "default_isolates", - "tool_name": "bowtie2-build", - "tool_version": "2.5.4", - } - assert len(run_subprocess.commands) == 2 - assert run_subprocess.commands[0] == ["bowtie2-build", "--version"] - assert run_subprocess.commands[1][:3] == [ - "bowtie2-build", - "--threads", - "4", - ] - assert Path(run_subprocess.commands[1][3]).name == "reference.fa" - assert run_subprocess.commands[1][4] == str(reference_index_path) - assert read_directory_bytes(reference_index_path.parent) == { - **bowtie2_bundle_bytes("reference"), - } + assert_cache_params( + params, + { + "index_kind": "reference_mapping_index", + "workflow": "pathoscope", + "workflow_version": "UNKNOWN", + "parent_id": index.id, + "collapse_identity": "0.99", + "source": "collapsed_reference", + "selection": "default_isolates", + "tool_name": "bowtie2-build", + }, + ) + assert_bowtie2_index_exists(reference_index_path) async def test_create_reference_index_continues_when_cache_put_is_skipped( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, + tmp_path: Path, ): - write_reference_json(index.json_path) - cache = FakeWorkflowCache(put_created=False) - run_subprocess = FakeRunSubprocess() + write_reference_json(collapsed_reference_path) + cache = WorkflowCache( + _FakeWorkflowCacheAPI( + tmp_path / "fake_workflow_cache_api", + put_created=False, + ), + tmp_path / "workflow_cache", + ) logger = get_logger("test") result_index_path = await create_reference_index( cache, + collapsed_reference_path, index, logger, 4, @@ -425,36 +757,31 @@ async def test_create_reference_index_continues_when_cache_put_is_skipped( reference_index_path, ) - params = await get_mapping_index_cache_params( - "reference_mapping_index", - index.id, - FakeRunSubprocess(), - { - "source": "index_json", - "selection": "default_isolates", - }, - ) - key = derive_key(params) - assert cache.gets == [(key, reference_index_path.parent.parent)] - assert cache.puts == [(key, reference_index_path.parent, params)] assert result_index_path == reference_index_path - assert read_directory_bytes(reference_index_path.parent) == { - **bowtie2_bundle_bytes("reference"), - } + assert_bowtie2_index_exists(reference_index_path) async def test_create_reference_index_raises_unexpected_cache_put_failure( + collapsed_reference_path: Path, index: WFIndex, reference_index_path: Path, + run_subprocess: RunSubprocess, + tmp_path: Path, ): - write_reference_json(index.json_path) - cache = FakeWorkflowCache(put_exception=RuntimeError("cache upload failed")) - run_subprocess = FakeRunSubprocess() + write_reference_json(collapsed_reference_path) + cache = WorkflowCache( + _FakeWorkflowCacheAPI( + tmp_path / "fake_workflow_cache_api", + put_exception=RuntimeError("cache upload failed"), + ), + tmp_path / "workflow_cache", + ) logger = get_logger("test") with pytest.raises(RuntimeError, match="cache upload failed"): await create_reference_index( cache, + collapsed_reference_path, index, logger, 4, @@ -477,10 +804,12 @@ def test_write_default_isolate_fasta(tmp_path: Path): async def test_create_subtraction_index_hit( + run_subprocess: RunSubprocess, subtractions: list[WFSubtraction], subtraction_index_path: Path, subtraction_indexes_path: Path, tmp_path: Path, + workflow_cache: WorkflowCache, ): source = tmp_path / subtraction_index_path.parent.name write_bowtie2_bundle( @@ -489,13 +818,19 @@ async def test_create_subtraction_index_hit( b"cached-subtraction", {"cache-manifest.json": b"cached-manifest"}, ) - cache = FakeWorkflowCache(source) - run_subprocess = FakeRunSubprocess() - logger = get_logger("test") subtraction = subtractions[0] + params = await get_mapping_index_cache_params( + "subtraction_mapping_index", + subtraction.id, + run_subprocess, + ) + key = derive_key(params) + assert await workflow_cache.put(key, source, params) + + logger = get_logger("test") result_indexes_path = await create_subtraction_index( - cache, + workflow_cache, logger, 4, run_subprocess, @@ -503,38 +838,24 @@ async def test_create_subtraction_index_hit( subtraction_indexes_path, ) - params = await get_mapping_index_cache_params( - "subtraction_mapping_index", - subtraction.id, - FakeRunSubprocess(), - ) - - assert cache.gets == [ - ( - derive_key(params), - subtraction_index_path.parent.parent, - ), - ] - assert cache.puts == [] assert result_indexes_path == subtraction_indexes_path - assert run_subprocess.commands == [["bowtie2-build", "--version"]] assert read_directory_bytes(subtraction_index_path.parent) == read_directory_bytes( source ) async def test_create_subtraction_index_miss( + run_subprocess: RunSubprocess, subtractions: list[WFSubtraction], subtraction_index_path: Path, subtraction_indexes_path: Path, + workflow_cache: WorkflowCache, ): - cache = FakeWorkflowCache() - run_subprocess = FakeRunSubprocess() logger = get_logger("test") subtraction = subtractions[0] result_indexes_path = await create_subtraction_index( - cache, + workflow_cache, logger, 4, run_subprocess, @@ -545,33 +866,20 @@ async def test_create_subtraction_index_miss( params = await get_mapping_index_cache_params( "subtraction_mapping_index", subtraction.id, - FakeRunSubprocess(), + run_subprocess, ) - key = derive_key(params) - assert cache.gets == [(key, subtraction_index_path.parent.parent)] - assert cache.puts == [(key, subtraction_index_path.parent, params)] assert result_indexes_path == subtraction_indexes_path - assert params == { - "index_kind": "subtraction_mapping_index", - "workflow": "pathoscope", - "workflow_version": "UNKNOWN", - "parent_id": subtraction.id, - "tool_name": "bowtie2-build", - "tool_version": "2.5.4", - } - assert run_subprocess.commands == [ - ["bowtie2-build", "--version"], - [ - "bowtie2-build", - "--threads", - "4", - str(subtraction.fasta_path), - str(subtraction_index_path), - ], - ] - assert read_directory_bytes(subtraction_index_path.parent) == bowtie2_bundle_bytes( - "subtraction" + assert_cache_params( + params, + { + "index_kind": "subtraction_mapping_index", + "workflow": "pathoscope", + "workflow_version": "UNKNOWN", + "parent_id": subtraction.id, + "tool_name": "bowtie2-build", + }, ) + assert_bowtie2_index_exists(subtraction_index_path) async def test_map_default_isolates( @@ -608,22 +916,20 @@ async def test_map_default_isolates( async def test_map_isolates( example_path: Path, index: WFIndex, + isolate_bam_path: Path, + isolate_fastq_path: Path, + isolate_index_path: Path, sample: WFSample, run_subprocess: RunSubprocess, snapshot: SnapshotAssertion, - work_path: Path, ): for path in (example_path / "index").iterdir(): if "reference" in path.name: shutil.copyfile( path, - work_path / path.name.replace("reference", "isolates"), + isolate_index_path.parent / path.name.replace("reference", "isolates"), ) - isolate_fastq_path = work_path / "mapped.fq" - isolate_index_path = work_path / "isolates" - isolate_bam_path = work_path / "to_isolates.bam" - proc = 1 await map_isolates( @@ -654,18 +960,18 @@ async def test_map_isolates( ) async def test_eliminate_subtraction( example_path: Path, + isolate_bam_path: Path, + isolate_fastq_path: Path, no_subtractions: bool, + subtracted_bam_path: Path, subtractions: list[WFSubtraction], subtraction_indexes_path: Path, run_subprocess: RunSubprocess, snapshot: SnapshotAssertion, tmp_path: Path, + workflow_cache: WorkflowCache, work_path: Path, ): - isolate_fastq_path = work_path / "to_isolates.fq" - isolate_bam_path = work_path / "to_isolates.bam" - subtracted_path = work_path / "subtracted.bam" - shutil.copyfile(example_path / "to_isolates.bam", isolate_bam_path) shutil.copyfile(example_path / "to_isolates.fq", isolate_fastq_path) @@ -681,12 +987,19 @@ async def test_eliminate_subtraction( if subtractions: cached_subtraction_path = tmp_path / subtractions[0].id shutil.copytree(subtractions[0].path, cached_subtraction_path) + params = await get_mapping_index_cache_params( + "subtraction_mapping_index", + subtractions[0].id, + run_subprocess, + ) + key = derive_key(params) + assert await workflow_cache.put(key, cached_subtraction_path, params) await create_subtraction_index( - FakeWorkflowCache(cached_subtraction_path), + workflow_cache, logger, proc, - FakeRunSubprocess(), + run_subprocess, subtractions, subtraction_indexes_path, ) @@ -702,7 +1015,7 @@ async def test_eliminate_subtraction( run_subprocess, subtraction_indexes_path, subtractions, - subtracted_path, + subtracted_bam_path, work_path, ) @@ -752,8 +1065,8 @@ def parse_alignments(path: Path) -> set[tuple]: for read in alignment_file } - assert parse_alignments(work_path / "subtracted.bam") == snapshot(name="alignments") - assert parse_headers(work_path / "subtracted.bam") == snapshot(name="headers") + assert parse_alignments(subtracted_bam_path) == snapshot(name="alignments") + assert parse_headers(subtracted_bam_path) == snapshot(name="headers") async def test_pathoscope( @@ -763,10 +1076,9 @@ async def test_pathoscope( mocker, ref_lengths, snapshot: SnapshotAssertion, + subtracted_bam_path: Path, work_path: Path, ): - subtracted_bam_path = work_path / "to_isolates.bam" - shutil.copyfile(example_path / "to_isolates.bam", subtracted_bam_path) intermediate = SimpleNamespace(lengths=ref_lengths) diff --git a/workflow.py b/workflow.py index 349d87a..21931ca 100644 --- a/workflow.py +++ b/workflow.py @@ -1,4 +1,5 @@ import asyncio +import json import os import shlex import shutil @@ -7,16 +8,20 @@ from types import SimpleNamespace from typing import Any +from virtool.caches.utils import derive_key from virtool.workflow import hooks, step from virtool.workflow.data.analyses import WFAnalysis -from virtool.workflow.data.cache import WorkflowCache +from virtool.workflow.data.cache import CacheHit, WorkflowCache from virtool.workflow.data.indexes import WFIndex from virtool.workflow.data.samples import WFSample from virtool.workflow.data.subtractions import WFSubtraction from virtool.workflow.runtime.run_subprocess import RunSubprocess from workflow_pathoscope.utils import ( + CD_HIT_EST_IDENTITY, build_bowtie2_index, + collapse_reference_json, create_mapping_index, + get_reference_collapse_cache_params, run_pathoscope, write_default_isolate_fasta, write_isolate_fasta, @@ -46,9 +51,67 @@ async def delete_analysis_document(analysis: WFAnalysis): await analysis.delete() +@step +async def collapse_reference( + cache: WorkflowCache, + collapsed_reference_path: Path, + index: WFIndex, + logger, + proc: int, + run_subprocess: RunSubprocess, +) -> Path: + """Ensure a cd-hit-est collapsed reference JSON exists locally.""" + params = await get_reference_collapse_cache_params(index.id, run_subprocess) + key = derive_key(params) + collapsed_reference_dir = collapsed_reference_path.parent + log = logger.bind( + identity=CD_HIT_EST_IDENTITY, + index_kind="collapsed_reference", + key=key, + parent_id=index.id, + ) + + log.info("checking workflow cache") + + result = await cache.get(key, collapsed_reference_dir.parent) + + if isinstance(result, CacheHit): + log.info("restored cached collapsed reference", outcome="hit") + manifest_path = collapsed_reference_dir / "collapse-manifest.json" + + with open(manifest_path) as handle: + log.info("reference collapse restored", **json.load(handle)) + + return collapsed_reference_path + + log.info("collapsing reference", outcome="miss", source=str(index.json_path)) + + stats = await collapse_reference_json( + index.json_path, + collapsed_reference_path, + proc, + run_subprocess, + ) + + with open(collapsed_reference_dir / "collapse-manifest.json", "w") as handle: + json.dump(stats, handle) + + log.info("reference collapse complete", **stats) + + created = await cache.put(key, collapsed_reference_dir, params=params) + + if created: + log.info("cached collapsed reference", outcome="put") + else: + log.info("collapsed reference cache already exists", outcome="put_skipped") + + return collapsed_reference_path + + @step async def create_reference_index( cache: WorkflowCache, + collapsed_reference_path: Path, index: WFIndex, logger, proc: int, @@ -61,13 +124,13 @@ async def create_reference_index( await asyncio.to_thread( write_default_isolate_fasta, - index.json_path, + collapsed_reference_path, reference_fasta_path, ) logger.info( "assembled default reference fasta", - source=str(index.json_path), + source=str(collapsed_reference_path), ) await create_mapping_index( @@ -80,7 +143,8 @@ async def create_reference_index( index_prefix=reference_index_path, parent_id=index.id, extra_params={ - "source": "index_json", + "collapse_identity": CD_HIT_EST_IDENTITY, + "source": "collapsed_reference", "selection": "default_isolates", }, ) @@ -149,6 +213,7 @@ async def map_default_isolates( @step async def build_isolate_index( + collapsed_reference_path: Path, index: WFIndex, intermediate: SimpleNamespace, isolate_fasta_path: Path, @@ -161,7 +226,7 @@ async def build_isolate_index( intermediate.lengths = await asyncio.to_thread( write_isolate_fasta, {index.get_otu_id_by_sequence_id(id_) for id_ in intermediate.to_otus}, - index.json_path, + collapsed_reference_path, isolate_fasta_path, )