Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 57 additions & 42 deletions xrspatial/geotiff/_backends/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,56 +1352,71 @@ def _read_geotiff_gpu_chunked(source, *, dtype, chunks, overview_level,

src_path = _coerce_path(source)

# Per-tile compressed-byte cap, mirroring the eager GPU path and
# Local files: read the bytes once and parse the header/IFDs once,
# then reuse that single parse for both the per-tile cap check and
# the GDS qualification probe below. Parsing the IFDs is
# O(tile_count), so a 10k+ tile COG paid that cost twice (plus a
# second full-file read) at graph-build time before issue #3373.
#
# The per-tile compressed-byte cap mirrors the eager GPU path and
# the CPU readers. The chunked dask + GPU path either qualifies for
# the GDS fast
# path (handled in ``_read_geotiff_gpu_chunked_gds`` which runs
# the same cap on its own metadata parse) or falls through to
# ``_read_geotiff_dask`` whose per-chunk ``read_to_array`` calls
# apply the cap inside the CPU reader. The check here closes the
# window between "qualification probe parses the IFDs" and "the
# dispatch decides which path to take" so a forged tile is
# rejected at graph-build time rather than at first ``.compute()``.
# Sparse tiles (``byte_count == 0``) pass under any positive cap
# by design.
# the GDS fast path (handled in ``_read_geotiff_gpu_chunked_gds``
# which runs the same cap on its own metadata parse) or falls
# through to ``_read_geotiff_dask`` whose per-chunk
# ``read_to_array`` calls apply the cap inside the CPU reader. The
# check here closes the window between "qualification probe parses
# the IFDs" and "the dispatch decides which path to take" so a
# forged tile is rejected at graph-build time rather than at first
# ``.compute()``. Sparse tiles (``byte_count == 0``) pass under any
# positive cap by design.
raw = header = ifds = None
if isinstance(src_path, str) and not src_path.startswith(
('http://', 'https://')):
try:
with _FileSource(src_path) as _cap_fs:
_cap_raw = _cap_fs.read_all()
_cap_header = parse_header(_cap_raw)
_cap_ifds = parse_all_ifds(_cap_raw, _cap_header)
_cap_ifd = select_overview_ifd(_cap_ifds, overview_level)
_cap_byte_counts = _cap_ifd.tile_byte_counts
except Exception:
# If metadata parse fails here, the downstream path will
# surface a clear error; do not double-report.
_cap_byte_counts = None
if _cap_byte_counts is not None:
_cap = _max_tile_bytes_from_env()
for _tile_idx, _bc in enumerate(_cap_byte_counts):
if _bc > _cap:
raise ValueError(
f"TIFF tile {_tile_idx} declares "
f"TileByteCount={_bc:,} bytes, which exceeds "
f"the per-tile safety cap of {_cap:,} bytes. "
f"The file is malformed or attempting "
f"denial-of-service. Override via "
f"XRSPATIAL_COG_MAX_TILE_BYTES if this file "
f"is legitimate."
)

# Try the disk->GPU path. Parse metadata once; if the file does not
# qualify, fall through to the CPU-decode path. Any unexpected
# exception during the qualification probe also falls through so we
# never lose the ability to return a result.
try:
if isinstance(src_path, str) and not src_path.startswith(
('http://', 'https://')):
with _FileSource(src_path) as fs:
raw = fs.read_all()
header = parse_header(raw)
ifds = parse_all_ifds(raw, header)
except Exception:
# If metadata parse fails here, the downstream CPU path will
# surface a clear error; do not double-report. Leave
# raw/header/ifds as None so the qualification probe below
# is skipped and the read falls through to the CPU path.
raw = header = ifds = None

if ifds:
try:
_cap_byte_counts = select_overview_ifd(
ifds, overview_level).tile_byte_counts
except Exception:
# A bad overview level or malformed IFD here is surfaced
# by the downstream CPU path; skip the cap rather than
# raise an error unrelated to a denial-of-service tile.
_cap_byte_counts = None
if _cap_byte_counts is not None:
_cap = _max_tile_bytes_from_env()
for _tile_idx, _bc in enumerate(_cap_byte_counts):
if _bc > _cap:
raise ValueError(
f"TIFF tile {_tile_idx} declares "
f"TileByteCount={_bc:,} bytes, which exceeds "
f"the per-tile safety cap of {_cap:,} bytes. "
f"The file is malformed or attempting "
f"denial-of-service. Override via "
f"XRSPATIAL_COG_MAX_TILE_BYTES if this file "
f"is legitimate."
)

# Try the disk->GPU path, reusing the single parse above. If the
# file does not qualify, fall through to the CPU-decode path. Any
# unexpected exception during the qualification probe also falls
# through so we never lose the ability to return a result.
try:
if ifds is not None:
# An empty IFD list (non-None, falsy) raises here and the
# broad ``except`` below routes it to the CPU path, the same
# as a parse failure -- the chunked GPU read never qualifies
# a file with no IFDs.
if not ifds:
raise ValueError("No IFDs found in TIFF file")
ifd = select_overview_ifd(ifds, overview_level)
Expand Down
119 changes: 119 additions & 0 deletions xrspatial/geotiff/tests/gpu/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,3 +1547,122 @@ def test_chunked_gpu_eager_paths_keep_source_dtype_1909(

np_da = open_geotiff(uint16_no_sentinel_path_1909)
assert np_da.dtype == np.uint16


# ---------------------------------------------------------------------------
# Section: _read_geotiff_gpu_chunked reads + parses a local file once (#3373)
# ---------------------------------------------------------------------------


@pytest.fixture
def tiled_cog_path_3373(tmp_path):
    """Write a small deflate-compressed, tiled GeoTIFF for chunked GPU reads.

    The raster is a 32x32 ``float32`` ramp written with 16-pixel tiles.

    Returns
    -------
    tuple
        ``(path, array)`` -- the file location as a ``str`` and the
        NumPy array that was written to it.
    """
    from xrspatial.geotiff import to_geotiff

    side = 32
    pixels = np.arange(side * side, dtype=np.float32).reshape(side, side)
    raster = xr.DataArray(
        pixels,
        dims=['y', 'x'],
        attrs={
            'crs': 4326,
            'transform': (1.0, 0, 0, 0, -1.0, 32.0),
        },
    )
    out_path = str(tmp_path / 'tiled_cog_3373.tif')
    to_geotiff(raster, out_path, compression='deflate', tile_size=16)
    return out_path, pixels


@_gpu_only
def test_chunked_gpu_reads_and_parses_local_file_once_3373(
        tiled_cog_path_3373, monkeypatch):
    """Building the chunked GPU graph reads and parses a local file once.

    Before #3373 the per-tile cap check and the GDS qualification probe
    each read the whole file and parsed every IFD at graph-build time.
    This test wraps ``_FileSource.read_all``, ``parse_header`` and
    ``parse_all_ifds`` with call counters and requires each to fire
    exactly once, so the duplicate pass cannot return unnoticed.
    """
    from xrspatial.geotiff import _header as header_mod
    from xrspatial.geotiff import _read_geotiff_gpu
    from xrspatial.geotiff import _sources as sources_mod

    path, _ = tiled_cog_path_3373

    counts = {'read_all': 0, 'parse_header': 0, 'parse_all_ifds': 0}

    def _counted(label, wrapped):
        # Build a pass-through wrapper that bumps ``counts[label]`` on
        # every call. When installed on a class the wrapper receives
        # ``self`` as its first positional argument and forwards it.
        def wrapper(*args, **kwargs):
            counts[label] += 1
            return wrapped(*args, **kwargs)
        return wrapper

    patch_targets = (
        (sources_mod._FileSource, 'read_all'),
        (header_mod, 'parse_header'),
        (header_mod, 'parse_all_ifds'),
    )
    for owner, attr in patch_targets:
        monkeypatch.setattr(owner, attr, _counted(attr, getattr(owner, attr)))

    lazy = _read_geotiff_gpu(path, chunks=8)
    # Only the graph is built here -- nothing is computed. The redundant
    # read/parse happened at graph-build time, so that is what is counted.
    assert hasattr(lazy.data, 'dask')

    for label in ('read_all', 'parse_header', 'parse_all_ifds'):
        assert counts[label] == 1, counts


@_gpu_only
def test_chunked_gpu_per_tile_cap_still_raises_3373(
        tiled_cog_path_3373, monkeypatch):
    """The per-tile denial-of-service cap is still enforced at graph build."""
    from xrspatial.geotiff import _read_geotiff_gpu

    path = tiled_cog_path_3373[0]

    # With the cap set to a single byte every real tile is over the
    # limit, so the single-parse path has to raise before any GDS
    # qualification work is attempted.
    monkeypatch.setenv('XRSPATIAL_COG_MAX_TILE_BYTES', '1')

    expect_cap_error = pytest.raises(ValueError, match='per-tile safety cap')
    with expect_cap_error:
        _read_geotiff_gpu(path, chunks=8)


@_gpu_only
def test_chunked_gpu_malformed_metadata_falls_through_to_cpu_3373(
        tiled_cog_path_3373, monkeypatch):
    """A failed metadata parse falls back to the CPU path instead of raising.

    The single read+parse is guarded so that a parse error leaves
    raw/header/ifds as ``None``; the GDS qualification probe is then
    skipped, matching how the earlier two-parse version fell through.
    """
    import cupy

    from xrspatial.geotiff import _header as header_mod
    from xrspatial.geotiff import _read_geotiff_gpu

    path, expected = tiled_cog_path_3373

    genuine_parse = header_mod.parse_all_ifds
    seen = 0

    def fail_first_then_delegate(*args, **kwargs):
        # Only the first call -- the in-function probe parse -- is made
        # to fail, which skips GDS qualification. Later calls (the CPU
        # fallback, _read_geotiff_dask, parses through its own code
        # path) are forwarded to the real parser and must succeed.
        nonlocal seen
        seen += 1
        if seen == 1:
            raise ValueError("synthetic metadata parse failure 3373")
        return genuine_parse(*args, **kwargs)

    monkeypatch.setattr(
        header_mod, 'parse_all_ifds', fail_first_then_delegate)

    lazy = _read_geotiff_gpu(path, chunks=8)
    assert isinstance(lazy.data._meta, cupy.ndarray)

    computed = lazy.compute()
    np.testing.assert_array_equal(cupy.asnumpy(computed.data), expected)
Loading