Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## 2024-06-28 - Optimize yEnc decoding
**Learning:** For Python byte string processing, such as yEnc decoding, using C-backed built-in methods like `bytes.translate()` and `bytes.find()` is significantly faster than manual byte-by-byte iteration. The `translate` method can handle the simple subtraction without looping in Python. When checking for dangling escapes, it is safer to check the bounds `escape_pos + 1 >= len(line)` rather than looking at the last byte. Furthermore, the standard logic for calculating escaped yEnc bytes `(char - 64 - 42) % 256` can be algebraically simplified to `(char - 106) % 256`. Avoid `bytes.split()` since it struggles with consecutive escapes (e.g., `==`).
**Action:** Always prioritize `bytes.translate()` and `bytes.find()` over manual loops when processing large binary strings in Python. Use simplified modular math when possible.
155 changes: 121 additions & 34 deletions verify_nzb.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,43 @@ def _parse_yenc_attrs(line: bytes) -> dict[str, str]:
return attrs


_YENC_TRANS_TABLE = bytes((i - 42) % 256 for i in range(256))


def _decode_yenc_lines(lines: Iterable[bytes]) -> bytes:
"""
Decodes yEnc encoded lines into bytes.
Optimized to use C-backed bytes.translate() and bytes.find()
instead of iterating byte-by-byte in Python.
"""
decoded = bytearray()

for line in lines:
index = 0
while index < len(line):
byte = line[index]
if byte == 61:
index += 1
if index >= len(line):
escape_pos = line.find(b"=")
if escape_pos == -1:
# Fast path: No escapes in this line
decoded.extend(line.translate(_YENC_TRANS_TABLE))
else:
start = 0
while True:
escape_pos = line.find(b"=", start)
if escape_pos == -1:
# Translate the remainder of the line
decoded.extend(line[start:].translate(_YENC_TRANS_TABLE))
break

if escape_pos > start:
# Translate characters before the escape
decoded.extend(line[start:escape_pos].translate(_YENC_TRANS_TABLE))

if escape_pos + 1 >= len(line):
raise ValueError("dangling yEnc escape")
byte = (line[index] - 64) % 256
decoded.append((byte - 42) % 256)
index += 1

# Handle the escaped character:
# (char - 64 - 42) % 256 algebraically simplifies to (char - 106) % 256
decoded.append((line[escape_pos + 1] - 106) % 256)
start = escape_pos + 2

return bytes(decoded)


Expand All @@ -138,7 +162,9 @@ def validate_yenc_body(lines: Iterable[bytes | str]) -> YencValidationResult:
data_lines: list[bytes] = []

for raw_line in lines:
line = raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line)
line = (
raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line)
)
line = line.rstrip(b"\r\n")
if line.startswith(b"=ybegin"):
ybegin_attrs = _parse_yenc_attrs(line)
Expand Down Expand Up @@ -235,7 +261,9 @@ def select_deep_sample(
unique_ids = list(dict.fromkeys(message_ids))
if not unique_ids:
return []
sample_size = min(len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100)))
sample_size = min(
len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100))
)
if sample_size == len(unique_ids):
return unique_ids
return random.Random(sample_seed).sample(unique_ids, sample_size)
Expand Down Expand Up @@ -271,9 +299,13 @@ def load_config(path: str | Path) -> list[ServerConfig]:
if value is None
]
if missing:
raise ValueError(f"missing required options in [{section}]: {', '.join(missing)}")
raise ValueError(
f"missing required options in [{section}]: {', '.join(missing)}"
)
if (username is None) ^ (password is None):
raise ValueError(f"[{section}] must set both username and password or neither")
raise ValueError(
f"[{section}] must set both username and password or neither"
)
if max_connections < 1:
raise ValueError(f"[{section}] max_connections must be at least 1")
if port < 1:
Expand All @@ -295,7 +327,9 @@ def load_config(path: str | Path) -> list[ServerConfig]:
)

if not servers:
raise ValueError("configuration must contain at least one [server.<name>] section")
raise ValueError(
"configuration must contain at least one [server.<name>] section"
)
return servers


Expand Down Expand Up @@ -413,13 +447,20 @@ async def _send_command(self, command: str) -> tuple[int, str]:
raise ProtocolNntpError("NNTP command is not ASCII encodable") from exc
except asyncio.TimeoutError as exc:
raise TransientNntpError("command timeout") from exc
except (ConnectionResetError, BrokenPipeError, OSError, asyncio.IncompleteReadError) as exc:
except (
ConnectionResetError,
BrokenPipeError,
OSError,
asyncio.IncompleteReadError,
) as exc:
raise TransientNntpError("connection lost") from exc

async def _read_response(self) -> tuple[int, str]:
assert self._reader is not None
try:
line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout)
line = await asyncio.wait_for(
self._reader.readline(), timeout=self.config.timeout
)
except asyncio.TimeoutError as exc:
raise TransientNntpError("read timeout") from exc
if not line:
Expand All @@ -434,7 +475,9 @@ async def _read_multiline(self) -> list[bytes]:
lines: list[bytes] = []
while True:
try:
line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout)
line = await asyncio.wait_for(
self._reader.readline(), timeout=self.config.timeout
)
except asyncio.TimeoutError as exc:
raise TransientNntpError("read timeout") from exc
if not line:
Expand Down Expand Up @@ -494,7 +537,9 @@ def __init__(
self._progress_was_written = False
self._progress_failed = False

async def run(self, message_ids: Iterable[str], missing_output: str | Path | None = None) -> VerificationSummary:
async def run(
self, message_ids: Iterable[str], missing_output: str | Path | None = None
) -> VerificationSummary:
start = time.monotonic()
try:
await self._start_workers()
Expand Down Expand Up @@ -522,7 +567,9 @@ async def run(self, message_ids: Iterable[str], missing_output: str | Path | Non
present=self.present,
missing=self.missing,
error=self.error,
stat_requests=sum(conn.request_count for server in self.connections for conn in server),
stat_requests=sum(
conn.request_count for server in self.connections for conn in server
),
elapsed_seconds=elapsed,
)

Expand All @@ -534,7 +581,11 @@ async def _start_workers(self) -> None:
self.workers.append(task)

await asyncio.gather(
*(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections),
*(
connection.connect(self.retries)
for server_connections in self.connections
for connection in server_connections
),
return_exceptions=True,
)

Expand All @@ -548,14 +599,21 @@ async def _stop_workers(self) -> None:
for connection in server_connections:
await connection.close()

async def _worker_loop(self, server_index: int, connection: AsyncNntpConnection) -> None:
async def _worker_loop(
self, server_index: int, connection: AsyncNntpConnection
) -> None:
try:
while True:
job = await self._take_job(server_index)
if job is None:
return
try:
await self._handle_job(server_index, connection, job.message_id, job.target_server_index)
await self._handle_job(
server_index,
connection,
job.message_id,
job.target_server_index,
)
except Exception:
async with self.job_condition:
state = self.states.get(job.message_id)
Expand All @@ -581,7 +639,10 @@ async def _take_job(self, server_index: int) -> _Job | None:

def _find_job_for_server(self, server_index: int) -> _Job | None:
for job in self.jobs:
if job.target_server_index is None or job.target_server_index == server_index:
if (
job.target_server_index is None
or job.target_server_index == server_index
):
return job
return None

Expand Down Expand Up @@ -632,7 +693,9 @@ async def _handle_job(
if next_index is not None:
self._defer_message_locked(message_id, next_index)
return
await self._finalize_locked(message_id, "error" if state.had_error else "missing")
await self._finalize_locked(
message_id, "error" if state.had_error else "missing"
)
return
state.had_error = True
if next_index is not None:
Expand Down Expand Up @@ -660,7 +723,9 @@ def _defer_message_locked(self, message_id: str, server_index: int) -> None:
self.jobs.append(_Job(message_id=message_id, target_server_index=server_index))
self.job_condition.notify_all()

def _next_server_index_locked(self, state: _MessageState, current_server_index: int) -> int | None:
def _next_server_index_locked(
self, state: _MessageState, current_server_index: int
) -> int | None:
total = len(self.servers)
for offset in range(1, total + 1):
candidate = (current_server_index + offset) % total
Expand Down Expand Up @@ -742,7 +807,10 @@ async def run(
try:
await self._start()
results = await asyncio.gather(
*(self._check_one(index, message_id) for index, message_id in enumerate(message_ids))
*(
self._check_one(index, message_id)
for index, message_id in enumerate(message_ids)
)
)
finally:
await self._stop()
Expand All @@ -751,7 +819,9 @@ async def run(
with Path(deep_output).open("w", encoding="utf-8") as handle:
for result in results:
server = result.server or "-"
handle.write(f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n")
handle.write(
f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n"
)

elapsed = time.monotonic() - start
return DeepCheckSummary(
Expand All @@ -769,7 +839,11 @@ async def run(

async def _start(self) -> None:
await asyncio.gather(
*(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections),
*(
connection.connect(self.retries)
for server_connections in self.connections
for connection in server_connections
),
return_exceptions=True,
)
for server_index, server_connections in enumerate(self.connections):
Expand Down Expand Up @@ -814,7 +888,9 @@ async def _check_one(self, sample_index: int, message_id: str) -> DeepCheckResul
finally:
self.connection_queues[server_index].put_nowait(connection)

return DeepCheckResult(message_id=message_id, status="error", detail=last_error, server=None)
return DeepCheckResult(
message_id=message_id, status="error", detail=last_error, server=None
)


async def verify_nzb(
Expand All @@ -827,11 +903,13 @@ async def verify_nzb(
sample_percent: float = 1.0,
sample_seed: int | None = None,
deep_output: str | Path | None = None,
progress_stream = sys.stdout,
progress_stream=sys.stdout,
) -> VerificationSummary:
servers = load_config(config_path)
verifier = _Verifier(servers, retries=retries, progress_stream=progress_stream)
summary = await verifier.run(parse_nzb_message_ids(nzb_path), missing_output=missing_output)
summary = await verifier.run(
parse_nzb_message_ids(nzb_path), missing_output=missing_output
)
if deep_check:
sampled_ids = select_deep_sample(
verifier.present_message_ids,
Expand All @@ -844,10 +922,19 @@ async def verify_nzb(


def build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Verify NZB message IDs against NNTP servers.")
parser = argparse.ArgumentParser(
description="Verify NZB message IDs against NNTP servers."
)
parser.add_argument("nzb_path", type=Path, help="path to the NZB file")
parser.add_argument("--config", required=True, type=Path, help="path to the NNTP INI file")
parser.add_argument("--retries", type=int, default=1, help="retry count for transient network errors")
parser.add_argument(
"--config", required=True, type=Path, help="path to the NNTP INI file"
)
parser.add_argument(
"--retries",
type=int,
default=1,
help="retry count for transient network errors",
)
parser.add_argument(
"--missing-output",
type=Path,
Expand Down