From b948f8522bee00f884ec99ebe9dd6fc6bb7ec5d0 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Thu, 25 Jun 2026 00:30:33 +0000 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20Bolt:=20Optimize=20yEnc=20decoding?= =?UTF-8?q?=20via=20bytes.translate?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: xbmc4lyfe <273732874+xbmc4lyfe@users.noreply.github.com> --- .jules/bolt.md | 4 ++ verify_nzb.py | 147 +++++++++++++++++++++++++++++++++++++------------ 2 files changed, 116 insertions(+), 35 deletions(-) create mode 100644 .jules/bolt.md diff --git a/.jules/bolt.md b/.jules/bolt.md new file mode 100644 index 0000000..d454cba --- /dev/null +++ b/.jules/bolt.md @@ -0,0 +1,4 @@ +## 2024-06-25 - Optimize yEnc decoding via bytes.translate + +**Learning:** Python's byte-by-byte iteration using a `while` loop is slow compared to C-backed functions like `bytes.translate()` and `bytes.find()`. Calculating `(byte - 42) % 256` in Python inside a loop is a performance bottleneck for intensive byte manipulations like yEnc decoding. +**Action:** When performing bulk byte transformations in Python, try to utilize `bytes.translate()` with a pre-computed translation table and use fast string search functions like `find()` rather than custom iterative algorithms. diff --git a/verify_nzb.py b/verify_nzb.py index 953dccd..fb2c011 100644 --- a/verify_nzb.py +++ b/verify_nzb.py @@ -115,19 +115,33 @@ def _parse_yenc_attrs(line: bytes) -> dict[str, str]: return attrs +_yenc_decode_table = bytes((i - 42) % 256 for i in range(256)) + + def _decode_yenc_lines(lines: Iterable[bytes]) -> bytes: decoded = bytearray() for line in lines: - index = 0 - while index < len(line): - byte = line[index] - if byte == 61: - index += 1 - if index >= len(line): - raise ValueError("dangling yEnc escape") - byte = (line[index] - 64) % 256 - decoded.append((byte - 42) % 256) - index += 1 + escape_pos = line.find(b"=") + if escape_pos == -1: + decoded.extend(line.translate(_yenc_decode_table)) + continue + + pos = 0 + while True: + escape_pos = line.find(b"=", pos) + if escape_pos == -1: + decoded.extend(line[pos:].translate(_yenc_decode_table)) + break + + if escape_pos > pos: + decoded.extend(line[pos:escape_pos].translate(_yenc_decode_table)) + + if escape_pos + 1 >= len(line): + raise ValueError("dangling yEnc escape") + + char = line[escape_pos + 1] + decoded.append((char - 106) % 256) + pos = escape_pos + 2 return bytes(decoded) @@ -138,7 +152,9 @@ def validate_yenc_body(lines: Iterable[bytes | str]) -> YencValidationResult: data_lines: list[bytes] = [] for raw_line in lines: - line = raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line) + line = ( + raw_line.encode("latin-1") if isinstance(raw_line, str) else bytes(raw_line) + ) line = line.rstrip(b"\r\n") if line.startswith(b"=ybegin"): ybegin_attrs = _parse_yenc_attrs(line) @@ -235,7 +251,9 @@ def select_deep_sample( unique_ids = list(dict.fromkeys(message_ids)) if not unique_ids: return [] - sample_size = min(len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100))) + sample_size = min( + len(unique_ids), max(1, math.ceil(len(unique_ids) * sample_percent / 100)) + ) if sample_size == len(unique_ids): return unique_ids return random.Random(sample_seed).sample(unique_ids, sample_size) @@ -271,9 +289,13 @@ def load_config(path: str | Path) -> list[ServerConfig]: if value is None ] if missing: - raise ValueError(f"missing required options in [{section}]: {', '.join(missing)}") + raise ValueError( + f"missing required options in [{section}]: {', '.join(missing)}" + ) if (username is None) ^ (password is None): - raise ValueError(f"[{section}] must set both username and password or neither") + raise ValueError( + f"[{section}] must set both username and password or neither" + ) if max_connections < 1: raise ValueError(f"[{section}] max_connections must be at least 1") if port < 1: @@ -295,7 +317,9 @@ def load_config(path: str | Path) -> list[ServerConfig]: ) if not servers: - raise ValueError("configuration must contain at least one [server.] section") + raise ValueError( + "configuration must contain at least one [server.] section" + ) return servers @@ -413,13 +437,20 @@ async def _send_command(self, command: str) -> tuple[int, str]: raise ProtocolNntpError("NNTP command is not ASCII encodable") from exc except asyncio.TimeoutError as exc: raise TransientNntpError("command timeout") from exc - except (ConnectionResetError, BrokenPipeError, OSError, asyncio.IncompleteReadError) as exc: + except ( + ConnectionResetError, + BrokenPipeError, + OSError, + asyncio.IncompleteReadError, + ) as exc: raise TransientNntpError("connection lost") from exc async def _read_response(self) -> tuple[int, str]: assert self._reader is not None try: - line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout) + line = await asyncio.wait_for( + self._reader.readline(), timeout=self.config.timeout + ) except asyncio.TimeoutError as exc: raise TransientNntpError("read timeout") from exc if not line: @@ -434,7 +465,9 @@ async def _read_multiline(self) -> list[bytes]: lines: list[bytes] = [] while True: try: - line = await asyncio.wait_for(self._reader.readline(), timeout=self.config.timeout) + line = await asyncio.wait_for( + self._reader.readline(), timeout=self.config.timeout + ) except asyncio.TimeoutError as exc: raise TransientNntpError("read timeout") from exc if not line: @@ -494,7 +527,9 @@ def __init__( self._progress_was_written = False self._progress_failed = False - async def run(self, message_ids: Iterable[str], missing_output: str | Path | None = None) -> VerificationSummary: + async def run( + self, message_ids: Iterable[str], missing_output: str | Path | None = None + ) -> VerificationSummary: start = time.monotonic() try: await self._start_workers() @@ -522,7 +557,9 @@ async def run(self, message_ids: Iterable[str], missing_output: str | Path | Non present=self.present, missing=self.missing, error=self.error, - stat_requests=sum(conn.request_count for server in self.connections for conn in server), + stat_requests=sum( + conn.request_count for server in self.connections for conn in server + ), elapsed_seconds=elapsed, ) @@ -534,7 +571,11 @@ async def _start_workers(self) -> None: self.workers.append(task) await asyncio.gather( - *(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections), + *( + connection.connect(self.retries) + for server_connections in self.connections + for connection in server_connections + ), return_exceptions=True, ) @@ -548,14 +589,21 @@ async def _stop_workers(self) -> None: for connection in server_connections: await connection.close() - async def _worker_loop(self, server_index: int, connection: AsyncNntpConnection) -> None: + async def _worker_loop( + self, server_index: int, connection: AsyncNntpConnection + ) -> None: try: while True: job = await self._take_job(server_index) if job is None: return try: - await self._handle_job(server_index, connection, job.message_id, job.target_server_index) + await self._handle_job( + server_index, + connection, + job.message_id, + job.target_server_index, + ) except Exception: async with self.job_condition: state = self.states.get(job.message_id) @@ -581,7 +629,10 @@ async def _take_job(self, server_index: int) -> _Job | None: def _find_job_for_server(self, server_index: int) -> _Job | None: for job in self.jobs: - if job.target_server_index is None or job.target_server_index == server_index: + if ( + job.target_server_index is None + or job.target_server_index == server_index + ): return job return None @@ -632,7 +683,9 @@ async def _handle_job( if next_index is not None: self._defer_message_locked(message_id, next_index) return - await self._finalize_locked(message_id, "error" if state.had_error else "missing") + await self._finalize_locked( + message_id, "error" if state.had_error else "missing" + ) return state.had_error = True if next_index is not None: @@ -660,7 +713,9 @@ def _defer_message_locked(self, message_id: str, server_index: int) -> None: self.jobs.append(_Job(message_id=message_id, target_server_index=server_index)) self.job_condition.notify_all() - def _next_server_index_locked(self, state: _MessageState, current_server_index: int) -> int | None: + def _next_server_index_locked( + self, state: _MessageState, current_server_index: int + ) -> int | None: total = len(self.servers) for offset in range(1, total + 1): candidate = (current_server_index + offset) % total @@ -742,7 +797,10 @@ async def run( try: await self._start() results = await asyncio.gather( - *(self._check_one(index, message_id) for index, message_id in enumerate(message_ids)) + *( + self._check_one(index, message_id) + for index, message_id in enumerate(message_ids) + ) ) finally: await self._stop() @@ -751,7 +809,9 @@ async def run( with Path(deep_output).open("w", encoding="utf-8") as handle: for result in results: server = result.server or "-" - handle.write(f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n") + handle.write( + f"{result.message_id}\t{result.status}\t{result.detail}\t{server}\n" + ) elapsed = time.monotonic() - start return DeepCheckSummary( @@ -769,7 +829,11 @@ async def run( async def _start(self) -> None: await asyncio.gather( - *(connection.connect(self.retries) for server_connections in self.connections for connection in server_connections), + *( + connection.connect(self.retries) + for server_connections in self.connections + for connection in server_connections + ), return_exceptions=True, ) for server_index, server_connections in enumerate(self.connections): @@ -814,7 +878,9 @@ async def _check_one(self, sample_index: int, message_id: str) -> DeepCheckResul finally: self.connection_queues[server_index].put_nowait(connection) - return DeepCheckResult(message_id=message_id, status="error", detail=last_error, server=None) + return DeepCheckResult( + message_id=message_id, status="error", detail=last_error, server=None + ) async def verify_nzb( @@ -827,11 +893,13 @@ async def verify_nzb( sample_percent: float = 1.0, sample_seed: int | None = None, deep_output: str | Path | None = None, - progress_stream = sys.stdout, + progress_stream=sys.stdout, ) -> VerificationSummary: servers = load_config(config_path) verifier = _Verifier(servers, retries=retries, progress_stream=progress_stream) - summary = await verifier.run(parse_nzb_message_ids(nzb_path), missing_output=missing_output) + summary = await verifier.run( + parse_nzb_message_ids(nzb_path), missing_output=missing_output + ) if deep_check: sampled_ids = select_deep_sample( verifier.present_message_ids, @@ -844,10 +912,19 @@ async def verify_nzb( def build_arg_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Verify NZB message IDs against NNTP servers.") + parser = argparse.ArgumentParser( + description="Verify NZB message IDs against NNTP servers." + ) parser.add_argument("nzb_path", type=Path, help="path to the NZB file") - parser.add_argument("--config", required=True, type=Path, help="path to the NNTP INI file") - parser.add_argument("--retries", type=int, default=1, help="retry count for transient network errors") + parser.add_argument( + "--config", required=True, type=Path, help="path to the NNTP INI file" + ) + parser.add_argument( + "--retries", + type=int, + default=1, + help="retry count for transient network errors", + ) parser.add_argument( "--missing-output", type=Path,