Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .jules/bolt.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
## 2024-06-25 - Optimize yEnc decoding via bytes.translate

**Learning:** Python's byte-by-byte iteration using a `while` loop is slow compared to C-backed functions like `bytes.translate()` and `bytes.find()`. Calculating `(byte - 42) % 256` in Python inside a loop is a performance bottleneck for intensive byte manipulations like yEnc decoding.
**Action:** When performing bulk byte transformations in Python, try to utilize `bytes.translate()` with a pre-computed translation table and use fast string search functions like `find()` rather than custom iterative algorithms.
147 changes: 112 additions & 35 deletions verify_nzb.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,19 +115,33 @@ def _parse_yenc_attrs(line: bytes) -> dict[str, str]:
return attrs


_yenc_decode_table = bytes((i - 42) % 256 for i in range(256))


def _decode_yenc_lines(lines: Iterable[bytes]) -> bytes:
decoded = bytearray()
for line in lines:
index = 0
while index < len(line):
byte = line[index]
if byte == 61:
index += 1
if index >= len(line):
raise ValueError("dangling yEnc escape")
byte = (line[index] - 64) % 256
decoded.append((byte - 42) % 256)
index += 1
escape_pos = line.find(b"=")
if escape_pos == -1:
decoded.extend(line.translate(_yenc_decode_table))
continue

pos = 0
while True:
escape_pos = line.find(b"=", pos)
if escape_pos == -1:
decoded.extend(line[pos:].translate(_yenc_decode_table))
break

if escape_pos > pos:
decoded.extend(line[pos:escape_pos].translate(_yenc_decode_table))

if escape_pos + 1 >= len(line):
raise ValueError("dangling yEnc escape")

char = line[escape_pos + 1]
decoded.append((char - 106) % 256)
pos = escape_pos + 2
return bytes(decoded)


Expand All @@ -138,7 +152,9 @@ def validate_yenc_body(lines: Iterable[bytes | str]) -> YencValidationResult:
data_lines: list[bytes] = []

for raw_line in lines:
line = raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line)
line = (
raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line)
)
line = line.rstrip(b"\r\n")
if line.startswith(b"=ybegin"):
ybegin_attrs = _parse_yenc_attrs(line)
Expand Down Expand Up @@ -235,7 +251,9 @@ def select_deep_sample(
unique_ids = list(dict.fromkeys(message_ids))
if not unique_ids:
return []
sample_size = min(len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100)))
sample_size = min(
len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100))
)
if sample_size == len(unique_ids):
return unique_ids
return random.Random(sample_seed).sample(unique_ids, sample_size)
Expand Down Expand Up @@ -271,9 +289,13 @@ def load_config(path: str | Path) -> list[ServerConfig]:
if value is None
]
if missing:
raise ValueError(f"missing required options in [{section}]: {', '.join(missing)}")
raise ValueError(
f"missing required options in [{section}]: {', '.join(missing)}"
)
if (username is None) ^ (password is None):
raise ValueError(f"[{section}] must set both username and password or neither")
raise ValueError(
f"[{section}] must set both username and password or neither"
)
if max_connections < 1:
raise ValueError(f"[{section}] max_connections must be at least 1")
if port < 1:
Expand All @@ -295,7 +317,9 @@ def load_config(path: str | Path) -> list[ServerConfig]:
)

if not servers:
raise ValueError("configuration must contain at least one [server.<name>] section")
raise ValueError(
"configuration must contain at least one [server.<name>] section"
)
return servers


Expand Down Expand Up @@ -413,13 +437,20 @@ async def _send_command(self, command: str) -> tuple[int, str]:
raise ProtocolNntpError("NNTP command is not ASCII encodable") from exc
except asyncio.TimeoutError as exc:
raise TransientNntpError("command timeout") from exc
except (ConnectionResetError, BrokenPipeError, OSError, asyncio.IncompleteReadError) as exc:
except (
ConnectionResetError,
BrokenPipeError,
OSError,
asyncio.IncompleteReadError,
) as exc:
raise TransientNntpError("connection lost") from exc

async def _read_response(self) -> tuple[int, str]:
assert self._reader is not None
try:
line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout)
line = await asyncio.wait_for(
self._reader.readline(), timeout=self.config.timeout
)
except asyncio.TimeoutError as exc:
raise TransientNntpError("read timeout") from exc
if not line:
Expand All @@ -434,7 +465,9 @@ async def _read_multiline(self) -> list[bytes]:
lines: list[bytes] = []
while True:
try:
line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout)
line = await asyncio.wait_for(
self._reader.readline(), timeout=self.config.timeout
)
except asyncio.TimeoutError as exc:
raise TransientNntpError("read timeout") from exc
if not line:
Expand Down Expand Up @@ -494,7 +527,9 @@ def __init__(
self._progress_was_written = False
self._progress_failed = False

async def run(self, message_ids: Iterable[str], missing_output: str | Path | None = None) -> VerificationSummary:
async def run(
self, message_ids: Iterable[str], missing_output: str | Path | None = None
) -> VerificationSummary:
start = time.monotonic()
try:
await self._start_workers()
Expand Down Expand Up @@ -522,7 +557,9 @@ async def run(self, message_ids: Iterable[str], missing_output: str | Path | Non
present=self.present,
missing=self.missing,
error=self.error,
stat_requests=sum(conn.request_count for server in self.connections for conn in server),
stat_requests=sum(
conn.request_count for server in self.connections for conn in server
),
elapsed_seconds=elapsed,
)

Expand All @@ -534,7 +571,11 @@ async def _start_workers(self) -> None:
self.workers.append(task)

await asyncio.gather(
*(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections),
*(
connection.connect(self.retries)
for server_connections in self.connections
for connection in server_connections
),
return_exceptions=True,
)

Expand All @@ -548,14 +589,21 @@ async def _stop_workers(self) -> None:
for connection in server_connections:
await connection.close()

async def _worker_loop(self, server_index: int, connection: AsyncNntpConnection) -> None:
async def _worker_loop(
self, server_index: int, connection: AsyncNntpConnection
) -> None:
try:
while True:
job = await self._take_job(server_index)
if job is None:
return
try:
await self._handle_job(server_index, connection, job.message_id, job.target_server_index)
await self._handle_job(
server_index,
connection,
job.message_id,
job.target_server_index,
)
except Exception:
async with self.job_condition:
state = self.states.get(job.message_id)
Expand All @@ -581,7 +629,10 @@ async def _take_job(self, server_index: int) -> _Job | None:

def _find_job_for_server(self, server_index: int) -> _Job | None:
for job in self.jobs:
if job.target_server_index is None or job.target_server_index == server_index:
if (
job.target_server_index is None
or job.target_server_index == server_index
):
return job
return None

Expand Down Expand Up @@ -632,7 +683,9 @@ async def _handle_job(
if next_index is not None:
self._defer_message_locked(message_id, next_index)
return
await self._finalize_locked(message_id, "error" if state.had_error else "missing")
await self._finalize_locked(
message_id, "error" if state.had_error else "missing"
)
return
state.had_error = True
if next_index is not None:
Expand Down Expand Up @@ -660,7 +713,9 @@ def _defer_message_locked(self, message_id: str, server_index: int) -> None:
self.jobs.append(_Job(message_id=message_id, target_server_index=server_index))
self.job_condition.notify_all()

def _next_server_index_locked(self, state: _MessageState, current_server_index: int) -> int | None:
def _next_server_index_locked(
self, state: _MessageState, current_server_index: int
) -> int | None:
total = len(self.servers)
for offset in range(1, total + 1):
candidate = (current_server_index + offset) % total
Expand Down Expand Up @@ -742,7 +797,10 @@ async def run(
try:
await self._start()
results = await asyncio.gather(
*(self._check_one(index, message_id) for index, message_id in enumerate(message_ids))
*(
self._check_one(index, message_id)
for index, message_id in enumerate(message_ids)
)
)
finally:
await self._stop()
Expand All @@ -751,7 +809,9 @@ async def run(
with Path(deep_output).open("w", encoding="utf-8") as handle:
for result in results:
server = result.server or "-"
handle.write(f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n")
handle.write(
f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n"
)

elapsed = time.monotonic() - start
return DeepCheckSummary(
Expand All @@ -769,7 +829,11 @@ async def run(

async def _start(self) -> None:
await asyncio.gather(
*(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections),
*(
connection.connect(self.retries)
for server_connections in self.connections
for connection in server_connections
),
return_exceptions=True,
)
for server_index, server_connections in enumerate(self.connections):
Expand Down Expand Up @@ -814,7 +878,9 @@ async def _check_one(self, sample_index: int, message_id: str) -> DeepCheckResul
finally:
self.connection_queues[server_index].put_nowait(connection)

return DeepCheckResult(message_id=message_id, status="error", detail=last_error, server=None)
return DeepCheckResult(
message_id=message_id, status="error", detail=last_error, server=None
)


async def verify_nzb(
Expand All @@ -827,11 +893,13 @@ async def verify_nzb(
sample_percent: float = 1.0,
sample_seed: int | None = None,
deep_output: str | Path | None = None,
progress_stream = sys.stdout,
progress_stream=sys.stdout,
) -> VerificationSummary:
servers = load_config(config_path)
verifier = _Verifier(servers, retries=retries, progress_stream=progress_stream)
summary = await verifier.run(parse_nzb_message_ids(nzb_path), missing_output=missing_output)
summary = await verifier.run(
parse_nzb_message_ids(nzb_path), missing_output=missing_output
)
if deep_check:
sampled_ids = select_deep_sample(
verifier.present_message_ids,
Expand All @@ -844,10 +912,19 @@ async def verify_nzb(


def build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Verify NZB message IDs against NNTP servers.")
parser = argparse.ArgumentParser(
description="Verify NZB message IDs against NNTP servers."
)
parser.add_argument("nzb_path", type=Path, help="path to the NZB file")
parser.add_argument("--config", required=True, type=Path, help="path to the NNTP INI file")
parser.add_argument("--retries", type=int, default=1, help="retry count for transient network errors")
parser.add_argument(
"--config", required=True, type=Path, help="path to the NNTP INI file"
)
parser.add_argument(
"--retries",
type=int,
default=1,
help="retry count for transient network errors",
)
parser.add_argument(
"--missing-output",
type=Path,
Expand Down