diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 0000000..7ab8475 --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1,3 @@ +## 2024-06-28 - Optimize yEnc decoding +**Learning:** For Python byte string processing, such as yEnc decoding, using C-backed built-in methods like `bytes.translate()` and `bytes.find()` is significantly faster than manual byte-by-byte iteration. The `translate` method can handle the simple subtraction without looping in Python. When checking for dangling escapes, it is safer to check the bounds `escape_pos + 1 >= len(line)` rather than looking at the last byte. Furthermore, the standard logic for calculating escaped yEnc bytes `(char - 64 - 42) % 256` can be algebraically simplified to `(char - 106) % 256`. Avoid `bytes.split()` since it struggles with consecutive escapes (e.g., `==`). +**Action:** Always prioritize `bytes.translate()` and `bytes.find()` over manual loops when processing large binary strings in Python. Use simplified modular math when possible. diff --git a/verify_nzb.py b/verify_nzb.py index 953dccd..520d82b 100644 --- a/verify_nzb.py +++ b/verify_nzb.py @@ -115,19 +115,43 @@ def _parse_yenc_attrs(line: bytes) -> dict[str, str]: return attrs +_YENC_TRANS_TABLE = bytes((i - 42) % 256 for i in range(256)) + + def _decode_yenc_lines(lines: Iterable[bytes]) -> bytes: + """ + Decodes yEnc encoded lines into bytes. + Optimized to use C-backed bytes.translate() and bytes.find() + instead of iterating byte-by-byte in Python. + """ decoded = bytearray() + for line in lines: - index = 0 - while index < len(line): - byte = line[index] - if byte == 61: - index += 1 - if index >= len(line): + escape_pos = line.find(b"=") + if escape_pos == -1: + # Fast path: No escapes in this line + decoded.extend(line.translate(_YENC_TRANS_TABLE)) + else: + start = 0 + while True: + escape_pos = line.find(b"=", start) + if escape_pos == -1: + # Translate the remainder of the line + decoded.extend(line[start:].translate(_YENC_TRANS_TABLE)) + break + + if escape_pos > start: + # Translate characters before the escape + decoded.extend(line[start:escape_pos].translate(_YENC_TRANS_TABLE)) + + if escape_pos + 1 >= len(line): raise ValueError("dangling yEnc escape") - byte = (line[index] - 64) % 256 - decoded.append((byte - 42) % 256) - index += 1 + + # Handle the escaped character: + # (char - 64 - 42) % 256 algebraically simplifies to (char - 106) % 256 + decoded.append((line[escape_pos + 1] - 106) % 256) + start = escape_pos + 2 + return bytes(decoded) @@ -138,7 +162,9 @@ def validate_yenc_body(lines: Iterable[bytes | str]) -> YencValidationResult: data_lines: list[bytes] = [] for raw_line in lines: - line = raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line) + line = ( + raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line) + ) line = line.rstrip(b"\r\n") if line.startswith(b"=ybegin"): ybegin_attrs = _parse_yenc_attrs(line) @@ -235,7 +261,9 @@ def select_deep_sample( unique_ids = list(dict.fromkeys(message_ids)) if not unique_ids: return [] - sample_size = min(len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100))) + sample_size = min( + len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100)) + ) if sample_size == len(unique_ids): return unique_ids return random.Random(sample_seed).sample(unique_ids, sample_size) @@ -271,9 +299,13 @@ def load_config(path: str | Path) -> list[ServerConfig]: if value is None ] if missing: - raise ValueError(f"missing required options in [{section}]: {', '.join(missing)}") + raise ValueError( + f"missing required options in [{section}]: {', '.join(missing)}" + ) if (username is None) ^ (password is None): - raise ValueError(f"[{section}] must set both username and password or neither") + raise ValueError( + f"[{section}] must set both username and password or neither" + ) if max_connections < 1: raise ValueError(f"[{section}] max_connections must be at least 1") if port < 1: @@ -295,7 +327,9 @@ def load_config(path: str | Path) -> list[ServerConfig]: ) if not servers: - raise ValueError("configuration must contain at least one [server.] section") + raise ValueError( + "configuration must contain at least one [server.] section" + ) return servers @@ -413,13 +447,20 @@ async def _send_command(self, command: str) -> tuple[int, str]: raise ProtocolNntpError("NNTP command is not ASCII encodable") from exc except asyncio.TimeoutError as exc: raise TransientNntpError("command timeout") from exc - except (ConnectionResetError, BrokenPipeError, OSError, asyncio.IncompleteReadError) as exc: + except ( + ConnectionResetError, + BrokenPipeError, + OSError, + asyncio.IncompleteReadError, + ) as exc: raise TransientNntpError("connection lost") from exc async def _read_response(self) -> tuple[int, str]: assert self._reader is not None try: - line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout) + line = await asyncio.wait_for( + self._reader.readline(), timeout=self.config.timeout + ) except asyncio.TimeoutError as exc: raise TransientNntpError("read timeout") from exc if not line: @@ -434,7 +475,9 @@ async def _read_multiline(self) -> list[bytes]: lines: list[bytes] = [] while True: try: - line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout) + line = await asyncio.wait_for( + self._reader.readline(), timeout=self.config.timeout + ) except asyncio.TimeoutError as exc: raise TransientNntpError("read timeout") from exc if not line: @@ -494,7 +537,9 @@ def __init__( self._progress_was_written = False self._progress_failed = False - async def run(self, message_ids: Iterable[str], missing_output: str | Path | None = None) -> VerificationSummary: + async def run( + self, message_ids: Iterable[str], missing_output: str | Path | None = None + ) -> VerificationSummary: start = time.monotonic() try: await self._start_workers() @@ -522,7 +567,9 @@ async def run(self, message_ids: Iterable[str], missing_output: str | Path | Non present=self.present, missing=self.missing, error=self.error, - stat_requests=sum(conn.request_count for server in self.connections for conn in server), + stat_requests=sum( + conn.request_count for server in self.connections for conn in server + ), elapsed_seconds=elapsed, ) @@ -534,7 +581,11 @@ async def _start_workers(self) -> None: self.workers.append(task) await asyncio.gather( - *(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections), + *( + connection.connect(self.retries) + for server_connections in self.connections + for connection in server_connections + ), return_exceptions=True, ) @@ -548,14 +599,21 @@ async def _stop_workers(self) -> None: for connection in server_connections: await connection.close() - async def _worker_loop(self, server_index: int, connection: AsyncNntpConnection) -> None: + async def _worker_loop( + self, server_index: int, connection: AsyncNntpConnection + ) -> None: try: while True: job = await self._take_job(server_index) if job is None: return try: - await self._handle_job(server_index, connection, job.message_id, job.target_server_index) + await self._handle_job( + server_index, + connection, + job.message_id, + job.target_server_index, + ) except Exception: async with self.job_condition: state = self.states.get(job.message_id) @@ -581,7 +639,10 @@ async def _take_job(self, server_index: int) -> _Job | None: def _find_job_for_server(self, server_index: int) -> _Job | None: for job in self.jobs: - if job.target_server_index is None or job.target_server_index == server_index: + if ( + job.target_server_index is None + or job.target_server_index == server_index + ): return job return None @@ -632,7 +693,9 @@ async def _handle_job( if next_index is not None: self._defer_message_locked(message_id, next_index) return - await self._finalize_locked(message_id, "error" if state.had_error else "missing") + await self._finalize_locked( + message_id, "error" if state.had_error else "missing" + ) return state.had_error = True if next_index is not None: @@ -660,7 +723,9 @@ def _defer_message_locked(self, message_id: str, server_index: int) -> None: self.jobs.append(_Job(message_id=message_id, target_server_index=server_index)) self.job_condition.notify_all() - def _next_server_index_locked(self, state: _MessageState, current_server_index: int) -> int | None: + def _next_server_index_locked( + self, state: _MessageState, current_server_index: int + ) -> int | None: total = len(self.servers) for offset in range(1, total + 1): candidate = (current_server_index + offset) % total @@ -742,7 +807,10 @@ async def run( try: await self._start() results = await asyncio.gather( - *(self._check_one(index, message_id) for index, message_id in enumerate(message_ids)) + *( + self._check_one(index, message_id) + for index, message_id in enumerate(message_ids) + ) ) finally: await self._stop() @@ -751,7 +819,9 @@ async def run( with Path(deep_output).open("w", encoding="utf-8") as handle: for result in results: server = result.server or "-" - handle.write(f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n") + handle.write( + f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n" + ) elapsed = time.monotonic() - start return DeepCheckSummary( @@ -769,7 +839,11 @@ async def run( async def _start(self) -> None: await asyncio.gather( - *(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections), + *( + connection.connect(self.retries) + for server_connections in self.connections + for connection in server_connections + ), return_exceptions=True, ) for server_index, server_connections in enumerate(self.connections): @@ -814,7 +888,9 @@ async def _check_one(self, sample_index: int, message_id: str) -> DeepCheckResul finally: self.connection_queues[server_index].put_nowait(connection) - return DeepCheckResult(message_id=message_id, status="error", detail=last_error, server=None) + return DeepCheckResult( + message_id=message_id, status="error", detail=last_error, server=None + ) async def verify_nzb( @@ -827,11 +903,13 @@ async def verify_nzb( sample_percent: float = 1.0, sample_seed: int | None = None, deep_output: str | Path | None = None, - progress_stream = sys.stdout, + progress_stream=sys.stdout, ) -> VerificationSummary: servers = load_config(config_path) verifier = _Verifier(servers, retries=retries, progress_stream=progress_stream) - summary = await verifier.run(parse_nzb_message_ids(nzb_path), missing_output=missing_output) + summary = await verifier.run( + parse_nzb_message_ids(nzb_path), missing_output=missing_output + ) if deep_check: sampled_ids = select_deep_sample( verifier.present_message_ids, @@ -844,10 +922,19 @@ async def verify_nzb( def build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Verify NZB message IDs against NNTP servers.") + parser = argparse.ArgumentParser( + description="Verify NZB message IDs against NNTP servers." + ) parser.add_argument("nzb_path", type=Path, help="path to the NZB file") - parser.add_argument("--config", required=True, type=Path, help="path to the NNTP INI file") - parser.add_argument("--retries", type=int, default=1, help="retry count for transient network errors") + parser.add_argument( + "--config", required=True, type=Path, help="path to the NNTP INI file" + ) + parser.add_argument( + "--retries", + type=int, + default=1, + help="retry count for transient network errors", + ) parser.add_argument( "--missing-output", type=Path,