Skip to content

Lightning Trainer for MA OSS#1213

Merged
dkurra merged 9 commits into
mainfrom
trainer-snapshot-from-internal
May 22, 2026
Merged

Lightning Trainer for MA OSS#1213
dkurra merged 9 commits into
mainfrom
trainer-snapshot-from-internal

Conversation

@dkurra
Copy link
Copy Markdown
Contributor

@dkurra dkurra commented May 20, 2026

PyTorch Lightning trainer + MovieLens demo

Summary

One-time snapshot of the internal Uber Michelangelo sdk/trainer/torch/pytorch_lightning/ package into OSS, plus a runnable MovieLens-100k NCF example that exercises it end-to-end. Per the original Slack thread: snapshot only, no upstream sync.

Goal: give external users a thin Ray Train wrapper around PyTorch Lightning that the Michelangelo team uses internally, without the closed-source orchestration layer (workflow/, model_manager/, artifact/, CFS filesystem, M3 metrics) it normally sits behind.

Screenshot 2026-05-20 at 4 33 58 PM

What's in the snapshot

8 files, ~1,400 LOC under python/michelangelo/lib/trainer/torch/:

File LOC Role
pytorch_lightning/lightning_trainer.py 184 Public API: LightningTrainer (Ray TorchTrainer subclass), LightningTrainerParam, CometParam, LightningTrainerWithStateDict
pytorch_lightning/_private/util.py 452 _train_loop_per_worker — the per-Ray-worker training loop. Builds model, applies warm-start, resolves strategy/logger/callbacks, wraps Ray dataset shards, fits via pl.Trainer
pytorch_lightning/_private/callbacks.py 114 RayTrainReportCallback + per-node variant — bridges Lightning's epoch-end hooks to ray.train.report(...) for checkpointing
pytorch_lightning/schema.py 76 Warm-start types: TransferLearningSpec, IncrementalTrainingSpec, ModelSpec, TrainingType/LearningMode enums
data_collate_functions.py 259 Ray Data → torch.Tensor collation. Default literal_eval_data_collate_function + ragged-list padding
_numpy_utils.py 167 Inlined pad_ragged_tensor + sentinel-dtype helpers (replaces internal uber.shared.utils.numpy_utils.*)
utils.py 131 Standalone memory-footprint estimators for transformers + generic nn.Module (used by callers to size batches)

Execution flow (one Ray worker)

LightningTrainer(...).train()
  └── ray.train.torch.TorchTrainer spawns N workers
      └── _train_loop_per_worker(config)        ◄── _private/util.py
          ├── model = create_model_fn(**kwargs)
          ├── _load_weights_from_path / _apply_layer_freeze   (if specs set)
          ├── strategy  = _resolve_strategy()    # RayDDPStrategy by default
          ├── logger    = _resolve_logger(comet_param OR lightning_trainer_kwargs["logger"])
          ├── callbacks = [..., RayTrainReportCallback]   ◄── _private/callbacks.py
          ├── data_iter = wrap(get_dataset_shard("train"), collate=literal_eval_data_collate_function)
          ├── pl_trainer = ray.train.lightning.prepare_trainer(pl.Trainer(...))
          └── pl_trainer.fit(model, train_iter, val_iter)
              └── each epoch end → ray.train.report(metrics, checkpoint)

What's intentionally stripped from the internal version

To make the snapshot standalone:

  • ArtifactStore epoch-resume — internal trainer polls CFS for the latest checkpoint and resumes mid-run. OSS users start from local paths only.
  • M3 metric countersMichelangeloStatsLogger, CANVAS_V2_TRAINING_JOB_RESULT_METRIC, CHECKPOINT_UPLOAD_METRIC all removed.
  • register_cfs_into_fsspec — Uber-internal filesystem; OSS users use local / S3 / GCS via standard fsspec.
  • Rich UserException — replaced with a minimal local class.
  • Internal numpy_utils / sentinel packages — relevant helpers inlined into _numpy_utils.py.

Result: zero uber.* imports in the snapshot.

MovieLens-100k demo (python/examples/movielens/)

Smallest viable smoke test for the trainer. Trains a tiny NCF (~92K params) on CPU with one Ray Train worker, 3 epochs, ~5s wall time on the validation run.

File Role
data.py Downloads ml-100k (~5 MB), remaps user/item IDs to dense indices, returns Ray Datasets. Has a GitHub mirror fallback for sandboxed environments.
model.py NCFLightningModule: user + item embeddings → 2-layer MLP → sigmoid, MSE on [0,1]-normalized ratings.
train.py Wires up LightningTrainerParam, LightningTrainer, RunConfig, ScalingConfig. Optionally enables Comet or MLflow tracking via env vars (Comet wins if both set).
README.md One-shot run instructions + env-var recipes for each tracking backend.

What the demo exercises end-to-end:

  • Loading LightningTrainer + LightningTrainerParam from the snapshot.
  • _train_loop_per_worker for a non-trivial Lightning fit, including RayTrainReportCallback epoch checkpointing.
  • Default Ray Data → torch tensor collation (no custom data_collate_fn).
  • Resolving the default RayDDPStrategy even with a single worker.
  • The _resolve_logger Comet path (when Comet env vars are set).
  • The _resolve_logger pre-built-pl.Logger path (when MLflow env vars are set).

pyproject changes

Added two optional dependencies + a trainer extras group:

comet_ml  = { version = "^3.49.0", optional = true }
deepspeed = { version = "^0.14.0", optional = true }

trainer = ["ray", "torch", "pytorch_lightning", "transformers", "numpy", "comet_ml", "deepspeed"]

How to run

cd python/
poetry install --extras "trainer example"
python -m examples.movielens.train

First run downloads MovieLens-100k to /tmp/movielens_data/; checkpoints land in /tmp/movielens_runs/ncf_movielens100k/.

Opt-in tracking:

# Comet
export COMET_API_KEY=... COMET_WORKSPACE=...
python -m examples.movielens.train

# OR MLflow (Comet wins if both env-sets are present)
export MLFLOW_TRACKING_URI=file:///tmp/mlflow_movielens
python -m examples.movielens.train
mlflow ui --backend-store-uri file:///tmp/mlflow_movielens

Test plan

  • python -m examples.movielens.train completes in ~5s, 3 epochs, train loss 0.063 → 0.054, val checkpoints written.
  • MLflow path verified: experiment created, run FINISHED, all metrics + hyperparameters + tags logged.
  • Comet path verified by external reviewer with a Comet workspace.
  • poetry install --extras "trainer example" resolves cleanly on a fresh checkout.

Commits

  1. Snapshot internal Lightning trainer into lib/trainer/torch — the core port + pyproject deps.
  2. Add MovieLens-100k NCF example exercising lib/trainer/torchexamples/movielens/{data,model,train,README}.py.
  3. examples/movielens: wire optional Comet logging via env vars.
  4. examples/movielens: add MLflow as an alternate optional tracker.

Out of scope (deliberately)

  • Sync mechanism with internal repo. This is a one-time snapshot.
  • The non-Lightning subpackages of the internal trainer (comet/, custom/, huggingface/).
  • CFS / ArtifactStore equivalents — out of scope for the Lightning trainer alone.

Imports the LightningTrainer + supporting code from the internal
Michelangelo SDK (uber/ai/michelangelo/sdk/trainer/torch/*) as a
one-time snapshot. Goal is to keep CanvasFlex (internal) and OSS on a
common trainer surface, per discussion in #ml-platform.

New files:
- python/michelangelo/lib/trainer/torch/_numpy_utils.py
    Inlined pad_ragged_tensor / sentinel_for_numpy_dtype / infer_dtype +
    sentinel constants (ported from shared/utils/numpy_utils).
- python/michelangelo/lib/trainer/torch/data_collate_functions.py
    LiteralEvalFloat32Collate + collate helpers used by Ray Data.
- python/michelangelo/lib/trainer/torch/utils.py
    Training-memory estimators (transformers and nn.Module).
- python/michelangelo/lib/trainer/torch/pytorch_lightning/schema.py
    TransferLearningSpec / IncrementalTrainingSpec dataclasses + enums.
- python/michelangelo/lib/trainer/torch/pytorch_lightning/_private/__init__.py
- python/michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py
    RayTrainReportCallback + RayTrainReportPerNodeCallback.
- python/michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py
    _train_loop_per_worker, strategy/plugin/logger/callbacks resolvers,
    Comet logger, init-weights loader, layer-freeze re-application.

Replaced:
- python/michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py
    The previous OSS skeleton (LightningTrainer composing TorchTrainer
    with create_run_config/create_scaling_config helpers) is replaced
    with the internal LightningTrainer + LightningTrainerWithStateDict
    that subclass ray.train.torch.TorchTrainer directly. Callers now
    construct ray.train.RunConfig / ScalingConfig themselves.

Stripped from the internal source (no OSS equivalent):
- ArtifactStoreTrainUtils / epoch-resume bookkeeping
- M3 metric counters (MichelangeloStatsLogger, m3_gauge_time, ...)
- CFS fsspec registration (uber.core.fsspec_cfs)
- workflow.framework.exceptions.UserException -> local class
- model_manager reflection_utils.get_module_attr -> inlined

Packaging:
- pyproject.toml: adds comet_ml + deepspeed as optional deps and a
  new `trainer` extras group covering the full runtime needed to
  import the lib/trainer/torch tree.

Out of scope (separate PR per ownership split):
- sdk/native_transform/ray/* (separate owner)
- Tests under sdk/trainer/torch/**/tests/
@CLAassistant
Copy link
Copy Markdown

CLAassistant commented May 20, 2026

CLA assistant check
All committers have signed the CLA.

@dkurra dkurra marked this pull request as draft May 20, 2026 20:32
dkurra added 2 commits May 20, 2026 21:05
A minimal end-to-end smoke test for the LightningTrainer snapshot just
added in the prior commit. Trains a small Neural Collaborative Filtering
model (user + item embeddings -> 2-layer MLP -> sigmoid, MSE loss) on
MovieLens-100k with a single Ray Train worker on CPU.

Files:
- python/examples/movielens/__init__.py
- python/examples/movielens/data.py
    Downloads ml-100k from grouplens; falls back to a github mirror if
    the canonical host is unreachable. Builds dense user/item indices,
    normalizes ratings to [0, 1], splits 80/20 train/val, returns Ray
    datasets.
- python/examples/movielens/model.py
    NCFLightningModule + create_ncf_model factory used as
    LightningTrainerParam.create_model_fn.
- python/examples/movielens/train.py
    Wires LightningTrainerParam + LightningTrainer +
    ray.train.{RunConfig, ScalingConfig} for a 3-epoch CPU run.
- python/examples/movielens/README.md
    Install + run instructions.

Verified locally: ml-100k downloads, model builds (92.4K params), Ray
Train worker spins up, 3 epochs complete in ~5s on CPU, three
checkpoints saved to /tmp/movielens_runs/, final result returned with
checkpoint_path / path / metrics keys matching internal API.
`LightningTrainer` only attaches a `CometLogger` when `comet_param` is
provided on `LightningTrainerParam`. Until now the MovieLens demo never
set it, so the Comet code path in lib/trainer/torch was untested by the
example.

Adds `_build_comet_param()` in train.py that constructs a `CometParam`
from COMET_API_KEY + COMET_WORKSPACE (required), with optional
COMET_PROJECT_NAME / COMET_EXPERIMENT_NAME / COMET_TAGS overrides. When
the required vars are unset the function returns None and the trainer
falls back to Lightning's default local logger (prior behavior).

README documents the env-var contract.

Verified locally:
- With no env vars set: logs "Comet logging disabled (...)" and completes
  training using Lightning's default logger.
@dkurra dkurra force-pushed the trainer-snapshot-from-internal branch from 220c096 to 387c729 Compare May 20, 2026 21:47
Comet was the only tracking backend the demo exposed. Adds MLflow as a
second opt-in path. Users pick at most one per run — Comet wins if both
env-sets are present, otherwise MLflow if MLFLOW_TRACKING_URI is set,
otherwise Lightning's default local logger.

`_build_mlflow_logger()` constructs `pytorch_lightning.loggers.MLFlowLogger`
directly from env vars (lazy-imported so MLflow isn't required when only
Comet or neither is in use). The logger instance is passed through
`lightning_trainer_kwargs["logger"]` — `_resolve_logger` already forwards
a pre-built Logger instance unchanged, so no trainer changes are needed.

Env contract for MLflow:
- MLFLOW_TRACKING_URI         required (e.g. file:///tmp/mlflow_run, http://...)
- MLFLOW_EXPERIMENT_NAME      optional, default ncf-movielens100k
- MLFLOW_RUN_NAME             optional
- MLFLOW_TAGS                 optional, comma-separated key=value pairs

README documents both Comet and MLflow paths with copy-paste env-var
recipes and notes that Comet wins precedence.

Verified locally:
- No env vars: "Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)"
- MLFLOW_TRACKING_URI=file:///tmp/mlflow_movielens: experiment created,
  run FINISHED, val_loss/train_loss/epoch + all 5 hyperparameters logged,
  custom tags applied. Verified via mlflow.search_runs() against the
  file store.

mlflow is already in OSS pyproject's `example` extras; no pyproject change.
@dkurra dkurra marked this pull request as ready for review May 20, 2026 23:41
@dkurra dkurra marked this pull request as draft May 20, 2026 23:41
Comment thread python/pyproject.toml
pytorch_lightning = { version = "2.2.0", optional = true }
einops = { version = "0.8.0", optional = true }
transformers = { version = "4.48.2", optional = true }
comet_ml = { version = "^3.49.0", optional = true }
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't have comet, can we not use it here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uber uses comet, how would you recommend to have backward compatibility?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good questions, i think we need to refactor the trainer for different observability. Stamp for now to refactor later

The previous commit (5abd603) added comet_ml + deepspeed as optional deps
and a trainer extras group to pyproject.toml but did not refresh the lock.
CI's poetry install steps (Coverage, Ruff, python-build, build-and-push)
all fail with 'pyproject.toml changed significantly since poetry.lock was
last generated'.

Regenerated with Poetry 2.2.1 to match CI's pinned version.
@github-actions
Copy link
Copy Markdown

🛠 Ruff Check & Format Results

⚠️ Format Issues Detected

Show details
--- examples/movielens/data.py
+++ examples/movielens/data.py
@@ -23,7 +23,9 @@
 # Canonical source. Some sandboxed environments can't reach files.grouplens.org;
 # fall back to a github mirror of the same u.data when the canonical URL fails.
 _DATA_URL = "https://files.grouplens.org/datasets/movielens/ml-100k.zip"
-_UDATA_MIRROR_URL = "https://raw.githubusercontent.com/vinjn/MLStudy/master/data/movielens-100k/u.data"
+_UDATA_MIRROR_URL = (
+    "https://raw.githubusercontent.com/vinjn/MLStudy/master/data/movielens-100k/u.data"
+)
 _DEFAULT_CACHE_DIR = "/tmp/movielens_data"
 _NETWORK_TIMEOUT_SECONDS = 30
 
@@ -55,17 +57,23 @@
 
     try:
         _logger.info("Downloading MovieLens-100k from %s", _DATA_URL)
-        with urllib.request.urlopen(_DATA_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
+        with urllib.request.urlopen(
+            _DATA_URL, timeout=_NETWORK_TIMEOUT_SECONDS
+        ) as resp:
             data = resp.read()
         with zipfile.ZipFile(io.BytesIO(data)) as z:
             z.extractall(cache_dir)
         _logger.info("Extracted MovieLens-100k to %s", cache_dir)
         return udata_path
     except (urllib.error.URLError, TimeoutError, OSError) as exc:
-        _logger.warning("Canonical URL failed (%s); falling back to %s", exc, _UDATA_MIRROR_URL)
+        _logger.warning(
+            "Canonical URL failed (%s); falling back to %s", exc, _UDATA_MIRROR_URL
+        )
 
     os.makedirs(extracted_dir, exist_ok=True)
-    with urllib.request.urlopen(_UDATA_MIRROR_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
+    with urllib.request.urlopen(
+        _UDATA_MIRROR_URL, timeout=_NETWORK_TIMEOUT_SECONDS
+    ) as resp:
         udata_bytes = resp.read()
     with open(udata_path, "wb") as f:
         f.write(udata_bytes)
@@ -91,15 +99,26 @@
         sep="\t",
         header=None,
         names=["user_id", "item_id", "rating", "timestamp"],
-        dtype={"user_id": np.int64, "item_id": np.int64, "rating": np.int64, "timestamp": np.int64},
+        dtype={
+            "user_id": np.int64,
+            "item_id": np.int64,
+            "rating": np.int64,
+            "timestamp": np.int64,
+        },
     )
     _logger.info("Loaded %d ratings", len(df))
 
-    user_id_to_idx = {uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))}
-    item_id_to_idx = {iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))}
+    user_id_to_idx = {
+        uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))
+    }
+    item_id_to_idx = {
+        iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))
+    }
     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
     df["item_idx"] = df["item_id"].map(item_id_to_idx).astype(np.int64)
-    df["rating_norm"] = ((df["rating"].astype(np.float32) - 1.0) / 4.0).astype(np.float32)
+    df["rating_norm"] = ((df["rating"].astype(np.float32) - 1.0) / 4.0).astype(
+        np.float32
+    )
 
     num_users = len(user_id_to_idx)
     num_items = len(item_id_to_idx)

--- examples/movielens/model.py
+++ examples/movielens/model.py
@@ -52,7 +52,14 @@
         preds = self(user_idx, item_idx)
         loss = F.mse_loss(preds, target)
         # sync_dist=True so the metric is averaged across Ray Train workers.
-        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+        self.log(
+            f"{stage}_loss",
+            loss,
+            prog_bar=True,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
         return loss
 
     def training_step(self, batch, batch_idx):  # noqa: ARG002

--- examples/movielens/train.py
+++ examples/movielens/train.py
@@ -35,7 +35,9 @@
     LightningTrainerParam,
 )
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
+)
 log = logging.getLogger("examples.movielens.train")
 
 _STORAGE_DIR = "/tmp/movielens_runs"
@@ -60,7 +62,9 @@
         api_key=api_key,
         workspace=workspace,
         project_name=os.environ.get("COMET_PROJECT_NAME", _DEFAULT_COMET_PROJECT),
-        experiment_name=os.environ.get("COMET_EXPERIMENT_NAME", _DEFAULT_COMET_EXPERIMENT),
+        experiment_name=os.environ.get(
+            "COMET_EXPERIMENT_NAME", _DEFAULT_COMET_EXPERIMENT
+        ),
         tags=tags,
     )
 
@@ -95,7 +99,9 @@
     from pytorch_lightning.loggers import MLFlowLogger  # noqa: PLC0415
 
     return MLFlowLogger(
-        experiment_name=os.environ.get("MLFLOW_EXPERIMENT_NAME", _DEFAULT_MLFLOW_EXPERIMENT),
+        experiment_name=os.environ.get(
+            "MLFLOW_EXPERIMENT_NAME", _DEFAULT_MLFLOW_EXPERIMENT
+        ),
         tracking_uri=tracking_uri,
         run_name=os.environ.get("MLFLOW_RUN_NAME"),
         tags=_parse_mlflow_tags(os.environ.get("MLFLOW_TAGS", "")),
@@ -116,7 +122,9 @@
             comet_param.experiment_name,
         )
         if os.environ.get("MLFLOW_TRACKING_URI"):
-            log.info("MLFLOW_TRACKING_URI is also set but Comet takes precedence; MLflow logging skipped.")
+            log.info(
+                "MLFLOW_TRACKING_URI is also set but Comet takes precedence; MLflow logging skipped."
+            )
     else:
         mlflow_logger = _build_mlflow_logger()
         if mlflow_logger is not None:
@@ -126,7 +134,9 @@
                 mlflow_logger.experiment_id,
             )
         else:
-            log.info("Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)")
+            log.info(
+                "Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)"
+            )
 
     lightning_trainer_kwargs = {
         # Don't pass accelerator/devices here: ray.train.lightning.prepare_trainer

--- michelangelo/lib/trainer/torch/_numpy_utils.py
+++ michelangelo/lib/trainer/torch/_numpy_utils.py
@@ -64,9 +64,17 @@
         try:
             if all(isinstance(elem, np.ndarray) and elem.ndim == 1 for elem in arr):
                 first_non_empty = next((a for a in arr if a.size > 0), None)
-                if first_non_empty is not None and first_non_empty.dtype.kind not in ("U", "S", "O"):
+                if first_non_empty is not None and first_non_empty.dtype.kind not in (
+                    "U",
+                    "S",
+                    "O",
+                ):
                     dtype = next((a.dtype for a in arr if a.size > 0), np.int32)
-                    pad_value = sentinel_for_numpy_dtype(dtype) if pad_value is None else pad_value
+                    pad_value = (
+                        sentinel_for_numpy_dtype(dtype)
+                        if pad_value is None
+                        else pad_value
+                    )
                     return _pad_1d_arrays_fast(arr, pad_value, dtype)
         except (AttributeError, TypeError):
             pass
@@ -146,7 +154,9 @@
         if i < arr_len:
             elem = arr[i] if isinstance(arr, (np.ndarray, list)) else arr
             if isinstance(elem, (list, np.ndarray)):
-                padded_elem = _pad_array_recursive(elem, target_shape, pad_value, dtype, level + 1)
+                padded_elem = _pad_array_recursive(
+                    elem, target_shape, pad_value, dtype, level + 1
+                )
             else:
                 padded_elem = elem
             padded_list.append(padded_elem)

--- michelangelo/lib/trainer/torch/data_collate_functions.py
+++ michelangelo/lib/trainer/torch/data_collate_functions.py
@@ -21,7 +21,10 @@
 import numpy as np
 import torch
 
-from michelangelo.lib.trainer.torch._numpy_utils import pad_ragged_tensor, sentinel_for_numpy_dtype
+from michelangelo.lib.trainer.torch._numpy_utils import (
+    pad_ragged_tensor,
+    sentinel_for_numpy_dtype,
+)
 
 # Default dtypes for all collate paths (subclass / kwargs may override per call).
 DEFAULT_COLLATE_NUMPY_DTYPE: np.dtype = np.dtype(np.float32)
@@ -97,7 +100,9 @@
     if isinstance(obj, np.ndarray) and obj.dtype == np.dtype(object):
         if obj.ndim == 0:
             return _literal_eval_str_cells_in_object_array(obj.item())
-        return [_literal_eval_str_cells_in_object_array(obj[i]) for i in range(obj.shape[0])]
+        return [
+            _literal_eval_str_cells_in_object_array(obj[i]) for i in range(obj.shape[0])
+        ]
     return obj
 
 
@@ -108,7 +113,11 @@
     numpy_dtype: np.dtype | None = None,
 ) -> np.ndarray:
     """Pad nested lists to a rectangular array of *numpy_dtype* (default: :data:`DEFAULT_COLLATE_NUMPY_DTYPE`)."""
-    target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
+    target = (
+        np.dtype(numpy_dtype)
+        if numpy_dtype is not None
+        else DEFAULT_COLLATE_NUMPY_DTYPE
+    )
 
     if not items:
         return np.array([], dtype=target)
@@ -125,12 +134,16 @@
 
     flat0 = items[0]
     if row_is_list_of_nested_cells(flat0):
-        normalized = [[np.asarray(sub, dtype=target).ravel() for sub in row] for row in items]
+        normalized = [
+            [np.asarray(sub, dtype=target).ravel() for sub in row] for row in items
+        ]
     else:
         normalized = [np.asarray(seq, dtype=target).ravel() for seq in items]
 
     obj = np.asarray(normalized, dtype=object)
-    effective_pad = pad_value if pad_value is not None else sentinel_for_numpy_dtype(target)
+    effective_pad = (
+        pad_value if pad_value is not None else sentinel_for_numpy_dtype(target)
+    )
     padded = pad_ragged_tensor(obj, effective_pad)
     if padded.dtype == np.object_:
         return np.array(padded.tolist(), dtype=target)
@@ -145,7 +158,11 @@
     numpy_dtype: np.dtype | None = None,
 ) -> np.ndarray:
     """Convert a single batch column value to a :class:`numpy.ndarray` of *numpy_dtype*."""
-    target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
+    target = (
+        np.dtype(numpy_dtype)
+        if numpy_dtype is not None
+        else DEFAULT_COLLATE_NUMPY_DTYPE
+    )
 
     if parse_string_with_literal_eval and isinstance(value, str):
         try:
@@ -179,7 +196,11 @@
     numpy_dtype: np.dtype | None = None,
 ) -> torch.Tensor:
     """Convert one column value to :class:`torch.Tensor` on *device* (see :func:`collate_value_to_float32_numpy`)."""
-    target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
+    target = (
+        np.dtype(numpy_dtype)
+        if numpy_dtype is not None
+        else DEFAULT_COLLATE_NUMPY_DTYPE
+    )
     arr = collate_value_to_float32_numpy(
         value,
         reshape_1d_features=reshape_1d_features,
@@ -199,7 +220,11 @@
     numpy_dtype: np.dtype | None = None,
 ) -> dict[str, torch.Tensor]:
     """Map a batch dict of Python / NumPy values to tensors (default element dtype: float32)."""
-    target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
+    target = (
+        np.dtype(numpy_dtype)
+        if numpy_dtype is not None
+        else DEFAULT_COLLATE_NUMPY_DTYPE
+    )
     return {
         k: collate_value_to_float32_tensor(
             v,
@@ -226,7 +251,11 @@
         self.device = device
         self.reshape_1d_features = reshape_1d_features
         self.parse_string_with_literal_eval = parse_string_with_literal_eval
-        self.numpy_dtype = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
+        self.numpy_dtype = (
+            np.dtype(numpy_dtype)
+            if numpy_dtype is not None
+            else DEFAULT_COLLATE_NUMPY_DTYPE
+        )
 
     def collate_value_to_numpy(self, value) -> np.ndarray:
         """Convert one column value to :class:`~numpy.ndarray` (override in subclasses)."""

--- michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py
+++ michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py
@@ -72,17 +72,26 @@
         """Called when the train batch ends."""
         if self.step_checkpoint_frequency > 0:
             current_step = trainer.global_step
-            if current_step - self.last_step_checkpoint >= self.step_checkpoint_frequency:
+            if (
+                current_step - self.last_step_checkpoint
+                >= self.step_checkpoint_frequency
+            ):
                 checkpoint_id = f"step_{trainer.global_step}"
-                self._create_and_report_checkpoint(trainer, checkpoint_id, is_step_checkpoint=True)
+                self._create_and_report_checkpoint(
+                    trainer, checkpoint_id, is_step_checkpoint=True
+                )
                 self.last_step_checkpoint = current_step
 
     def on_train_epoch_end(self, trainer, pl_module) -> None:  # noqa: ARG002
         """Called when the train epoch ends."""
         checkpoint_id = f"epoch_{trainer.current_epoch}"
-        self._create_and_report_checkpoint(trainer, checkpoint_id, is_step_checkpoint=False)
+        self._create_and_report_checkpoint(
+            trainer, checkpoint_id, is_step_checkpoint=False
+        )
 
-    def _create_and_report_checkpoint(self, trainer, checkpoint_id: str, is_step_checkpoint: bool) -> None:
+    def _create_and_report_checkpoint(
+        self, trainer, checkpoint_id: str, is_step_checkpoint: bool
+    ) -> None:
         """Creates a checkpoint and reports it to Ray Train.
 
         Args:
@@ -96,7 +105,13 @@
 
         metrics = trainer.callback_metrics
         metrics = {k: v.item() for k, v in metrics.items()}
-        metrics.update({"epoch": trainer.current_epoch, "step": trainer.global_step, "is_step_checkpoint": is_step_checkpoint})
+        metrics.update(
+            {
+                "epoch": trainer.current_epoch,
+                "step": trainer.global_step,
+                "is_step_checkpoint": is_step_checkpoint,
+            }
+        )
 
         # Save checkpoint and report to Ray Train
         ckpt_path = Path(tmpdir, self.CHECKPOINT_NAME).as_posix()

--- michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py
+++ michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py
@@ -15,13 +15,23 @@
 
 from pytorch_lightning.callbacks import Callback, ModelCheckpoint
 from pytorch_lightning.loggers import CometLogger, Logger
-from pytorch_lightning.plugins import CheckpointIO, ClusterEnvironment, LayerSync, Precision
+from pytorch_lightning.plugins import (
+    CheckpointIO,
+    ClusterEnvironment,
+    LayerSync,
+    Precision,
+)
 from pytorch_lightning.strategies import Strategy
 from michelangelo.lib.trainer.torch.pytorch_lightning._private.callbacks import (
     RayTrainReportCallback,
     RayTrainReportPerNodeCallback,
 )
-from ray.train.lightning import RayDDPStrategy, RayDeepSpeedStrategy, RayFSDPStrategy, RayLightningEnvironment
+from ray.train.lightning import (
+    RayDDPStrategy,
+    RayDeepSpeedStrategy,
+    RayFSDPStrategy,
+    RayLightningEnvironment,
+)
 
 
 class UserException(Exception):
@@ -83,7 +93,9 @@
     - layer_names: substring match (pattern in layer_name)
     - layer_names_regex: re.search (matches anywhere in the string)
     """
-    print(f"Applying layer freeze based on transfer_learning_spec: {transfer_learning_spec}")
+    print(
+        f"Applying layer freeze based on transfer_learning_spec: {transfer_learning_spec}"
+    )
     names_to_freeze = transfer_learning_spec.get("layer_names_to_freeze") or []
     regex_to_freeze = transfer_learning_spec.get("layer_names_to_freeze_regex") or []
 
@@ -92,7 +104,9 @@
     # in layers_to_freeze but are correctly skipped in the named_parameters() loop below,
     # since buffers have no requires_grad. Actual parameters are always frozen correctly.
     model_layer_names = list(model.state_dict().keys())
-    print(f"[freeze] Model layer names ({len(model_layer_names)}): {model_layer_names!r}")
+    print(
+        f"[freeze] Model layer names ({len(model_layer_names)}): {model_layer_names!r}"
+    )
 
     layers_to_freeze = set()
     for available_name in model_layer_names:
@@ -113,7 +127,9 @@
             frozen_count += 1
 
     rank = ray.train.get_context().get_world_rank()
-    print(f"[freeze] [Rank {rank}] Layer freeze re-applied: {frozen_count} params frozen")
+    print(
+        f"[freeze] [Rank {rank}] Layer freeze re-applied: {frozen_count} params frozen"
+    )
 
 
 def _get_comet_logger(
@@ -140,7 +156,9 @@
         api_experiment = api.get_experiment_by_key(experiment_id)
         if api_experiment is None:
             # Create an experiment object
-            comet_ml.Experiment(api_key=api_key, project_name=project_name, workspace=workspace)
+            comet_ml.Experiment(
+                api_key=api_key, project_name=project_name, workspace=workspace
+            )
 
     torch.distributed.barrier()
     # Attach logger with existing experiment_id
@@ -164,12 +182,19 @@
     return comet_logger
 
 
-def _resolve_strategy(strategy: Optional[Union[str, Strategy]] = None, strategy_kwargs: Optional[dict[str, Any]] = None) -> Strategy:
+def _resolve_strategy(
+    strategy: Optional[Union[str, Strategy]] = None,
+    strategy_kwargs: Optional[dict[str, Any]] = None,
+) -> Strategy:
     """Factory to create the correct Ray/Lightning strategy based on strategy name or instance."""
     if strategy is not None and not isinstance(strategy, (str, Strategy)):
-        raise TypeError(f"strategy must be a str, Strategy instance, or None, got {type(strategy)!r}")
+        raise TypeError(
+            f"strategy must be a str, Strategy instance, or None, got {type(strategy)!r}"
+        )
     if strategy_kwargs is not None and not isinstance(strategy_kwargs, dict):
-        raise TypeError(f"strategy_kwargs must be a dict or None, got {type(strategy_kwargs)!r}")
+        raise TypeError(
+            f"strategy_kwargs must be a dict or None, got {type(strategy_kwargs)!r}"
+        )
 
     if isinstance(strategy, Strategy):
         return strategy
@@ -183,7 +208,9 @@
     elif strategy.lower() == "fsdp":
         return RayFSDPStrategy(**strategy_kwargs)
     else:
-        raise ValueError(f"Unsupported strategy: {strategy!r}; expected 'ddp', 'deepspeed', 'fsdp', or None")
+        raise ValueError(
+            f"Unsupported strategy: {strategy!r}; expected 'ddp', 'deepspeed', 'fsdp', or None"
+        )
 
 
 def _resolve_plugins(
@@ -191,12 +218,20 @@
     plugins_kwargs: Optional[dict[str, Any]] = None,
 ) -> list:
     """Resolve plugins for the Lightning Trainer, always ensuring RayLightningEnvironment is present."""
-    if plugins is not None and not isinstance(plugins, (str, list, tuple, *_PLUGIN_INPUT.__args__)):
-        raise TypeError(f"plugins must be a str import path, a plugin instance, a list of plugin instances, or None; got {type(plugins)!r}")
+    if plugins is not None and not isinstance(
+        plugins, (str, list, tuple, *_PLUGIN_INPUT.__args__)
+    ):
+        raise TypeError(
+            f"plugins must be a str import path, a plugin instance, a list of plugin instances, or None; got {type(plugins)!r}"
+        )
     if plugins_kwargs is not None and not isinstance(plugins_kwargs, dict):
-        raise TypeError(f"plugins_kwargs must be a dict or None, got {type(plugins_kwargs)!r}")
+        raise TypeError(
+            f"plugins_kwargs must be a dict or None, got {type(plugins_kwargs)!r}"
+        )
     if plugins_kwargs is not None and not isinstance(plugins, str):
-        raise TypeError("plugins_kwargs can only be used when plugins is a str import path")
+        raise TypeError(
+            "plugins_kwargs can only be used when plugins is a str import path"
+        )
 
     plugin_kwargs = plugins_kwargs or {}
 
@@ -206,7 +241,11 @@
         # Create the plugin instances from the provided plugins function
         plugins_fn = _get_module_attr(plugins)
         plugin_instances = plugins_fn(**plugin_kwargs)
-        result = list(plugin_instances) if isinstance(plugin_instances, (list, tuple)) else [plugin_instances]
+        result = (
+            list(plugin_instances)
+            if isinstance(plugin_instances, (list, tuple))
+            else [plugin_instances]
+        )
     elif isinstance(plugins, (list, tuple)):
         result = list(plugins)
     else:
@@ -233,11 +272,17 @@
 ) -> Optional[Union[bool, Logger, list[Logger]]]:
     """Resolve the logger for the Lightning Trainer."""
     if logger_kwargs is not None and not isinstance(logger_kwargs, dict):
-        raise TypeError(f"logger_kwargs must be a dict or None, got {type(logger_kwargs)!r}")
+        raise TypeError(
+            f"logger_kwargs must be a dict or None, got {type(logger_kwargs)!r}"
+        )
     if logger_kwargs is not None and not isinstance(logger, str):
-        raise TypeError("logger_kwargs can only be used when logger is a str import path")
+        raise TypeError(
+            "logger_kwargs can only be used when logger is a str import path"
+        )
     if comet_param is not None and not isinstance(comet_param, dict):
-        raise TypeError(f"comet_param must be a dict or None, got {type(comet_param)!r}")
+        raise TypeError(
+            f"comet_param must be a dict or None, got {type(comet_param)!r}"
+        )
 
     if isinstance(logger, bool):
         return logger
@@ -245,14 +290,18 @@
         return logger
     if isinstance(logger, (list, tuple)):
         if any(not isinstance(elem, Logger) for elem in logger):
-            raise TypeError(f"All elements of logger list must be Logger instances, got {logger!r}")
+            raise TypeError(
+                f"All elements of logger list must be Logger instances, got {logger!r}"
+            )
         return list(logger)
     if isinstance(logger, str):
         logger_fn = _get_module_attr(logger)
         result = logger_fn(**(logger_kwargs or {}))
         return list(result) if isinstance(result, (list, tuple)) else result
     if logger is not None:
-        raise TypeError(f"logger must be a str, bool, Logger instance, list of Logger instances, or None, got {type(logger)!r}")
+        raise TypeError(
+            f"logger must be a str, bool, Logger instance, list of Logger instances, or None, got {type(logger)!r}"
+        )
     if comet_param and run_id:
         return _get_comet_logger(
             run_id,
@@ -275,12 +324,22 @@
 
     A RayTrainReportCallback or RayTrainReportPerNodeCallback is always appended to the list.
     """
-    if callbacks is not None and not isinstance(callbacks, (str, Callback, list, tuple)):
-        raise TypeError(f"callbacks must be a str import path, a Callback instance, a list of Callback instances, or None; got {type(callbacks)!r}")
+    if callbacks is not None and not isinstance(
+        callbacks, (str, Callback, list, tuple)
+    ):
+        raise TypeError(
+            f"callbacks must be a str import path, a Callback instance, a list of Callback instances, or None; got {type(callbacks)!r}"
+        )
     if callback_kwargs is not None and not isinstance(callback_kwargs, dict):
-        raise TypeError(f"callback_kwargs must be a dict or None, got {type(callback_kwargs)!r}")
-    if per_node_callback_kwargs is not None and not isinstance(per_node_callback_kwargs, dict):
-        raise TypeError(f"per_node_callback_kwargs must be a dict or None, got {type(per_node_callback_kwargs)!r}")
+        raise TypeError(
+            f"callback_kwargs must be a dict or None, got {type(callback_kwargs)!r}"
+        )
+    if per_node_callback_kwargs is not None and not isinstance(
+        per_node_callback_kwargs, dict
+    ):
+        raise TypeError(
+            f"per_node_callback_kwargs must be a dict or None, got {type(per_node_callback_kwargs)!r}"
+        )
 
     callback_kwargs = callback_kwargs or {}
     resolved_callbacks: list[Callback] = []
@@ -292,29 +351,41 @@
         if isinstance(result, (list, tuple)):
             for obj in result:
                 if not isinstance(obj, Callback):
-                    raise TypeError(f"Expected Callback instances from {callbacks!r}, got {type(obj)!r}")
+                    raise TypeError(
+                        f"Expected Callback instances from {callbacks!r}, got {type(obj)!r}"
+                    )
                 resolved_callbacks.append(obj)
         elif isinstance(result, Callback):
             resolved_callbacks.append(result)
         else:
-            raise TypeError(f"Expected a Callback instance or list of Callback instances from {callbacks!r}, got {type(result)!r}")
+            raise TypeError(
+                f"Expected a Callback instance or list of Callback instances from {callbacks!r}, got {type(result)!r}"
+            )
     elif isinstance(callbacks, (list, tuple)):
         for obj in callbacks:
             if not isinstance(obj, Callback):
-                raise TypeError(f"All callbacks must be Callback instances, got {type(obj)!r}")
+                raise TypeError(
+                    f"All callbacks must be Callback instances, got {type(obj)!r}"
+                )
             resolved_callbacks.append(obj)
     elif callbacks is not None:
         resolved_callbacks.append(callbacks)
 
-    has_model_checkpoint = any(isinstance(c, ModelCheckpoint) for c in resolved_callbacks)
+    has_model_checkpoint = any(
+        isinstance(c, ModelCheckpoint) for c in resolved_callbacks
+    )
 
     # Always append a callback that calls ray.train.report() to report metrics and checkpoint.
     # Per-node reporting is required for model-parallel strategies (DeepSpeed ZeRO, FSDP) because
     # each node holds shards of the model and must upload its own checkpoint shard.
-    _use_per_node = per_node_callback_kwargs is not None or isinstance(strategy, (RayDeepSpeedStrategy, RayFSDPStrategy))
+    _use_per_node = per_node_callback_kwargs is not None or isinstance(
+        strategy, (RayDeepSpeedStrategy, RayFSDPStrategy)
+    )
     if _use_per_node:
         per_node_callback_kwargs = per_node_callback_kwargs or {}
-        resolved_callbacks.append(RayTrainReportPerNodeCallback(**per_node_callback_kwargs))
+        resolved_callbacks.append(
+            RayTrainReportPerNodeCallback(**per_node_callback_kwargs)
+        )
     else:
         resolved_callbacks.append(RayTrainReportCallback())
 
@@ -331,7 +402,10 @@
     and calls trainer.fit.
     """
     if torch.cuda.is_available():
-        print("CUDA is available with torch, training on GPU with CUDA version:", torch.version.cuda)
+        print(
+            "CUDA is available with torch, training on GPU with CUDA version:",
+            torch.version.cuda,
+        )
     else:
         print("CUDA is not available with torch, training on CPU.")
 
@@ -361,7 +435,9 @@
     train_dataloader = train_dataset_shard.iter_torch_batches(
         batch_size=batch_size,
         collate_fn=collate_fn_to_torch,
-        local_shuffle_buffer_size=None if num_shuffle_batches == 0 else num_shuffle_batches * batch_size,
+        local_shuffle_buffer_size=None
+        if num_shuffle_batches == 0
+        else num_shuffle_batches * batch_size,
     )
     val_dataloader = val_dataset_shard.iter_torch_batches(
         batch_size=batch_size,
@@ -380,10 +456,14 @@
     # it must be re-applied here using transfer_learning_spec.
     # =========================================================
     initial_weights_path = train_loop_config.get("initial_weights_path")
-    print(f"[init_weights] [Rank {rank}] Initial weights path: {initial_weights_path!r}")
+    print(
+        f"[init_weights] [Rank {rank}] Initial weights path: {initial_weights_path!r}"
+    )
     if initial_weights_path:
         if rank == 0:
-            print(f"[init_weights] [Rank 0] Loading initial weights from: {initial_weights_path!r}")
+            print(
+                f"[init_weights] [Rank 0] Loading initial weights from: {initial_weights_path!r}"
+            )
             try:
                 _load_weights_from_path(model, initial_weights_path)
                 print("[init_weights] [Rank 0] Weights loaded successfully.")
@@ -412,8 +492,13 @@
 
     # Convert values from trainer_kwargs to their corresponding arguments for the Lightning Trainer.
     # We pop the values from trainer_kwargs to avoid passing invalid values to the Lightning Trainer.
-    strategy = _resolve_strategy(trainer_kwargs.pop("strategy", None), trainer_kwargs.pop("strategy_kwargs", None))
-    plugins = _resolve_plugins(trainer_kwargs.pop("plugins", None), trainer_kwargs.pop("plugins_kwargs", None))
+    strategy = _resolve_strategy(
+        trainer_kwargs.pop("strategy", None),
+        trainer_kwargs.pop("strategy_kwargs", None),
+    )
+    plugins = _resolve_plugins(
+        trainer_kwargs.pop("plugins", None), trainer_kwargs.pop("plugins_kwargs", None)
+    )
     logger = _resolve_logger(
         trainer_kwargs.pop("logger", None),
         trainer_kwargs.pop("logger_kwargs", None),
@@ -432,7 +517,9 @@
     trainer_kwargs["plugins"] = plugins
     trainer_kwargs["logger"] = logger
     trainer_kwargs["callbacks"] = callbacks
-    trainer_kwargs["enable_checkpointing"] = has_model_checkpoint  # enable_checkpointing must be set to True if a ModelCheckpoint callback is used
+    trainer_kwargs["enable_checkpointing"] = (
+        has_model_checkpoint  # enable_checkpointing must be set to True if a ModelCheckpoint callback is used
+    )
 
     trainer = pl.Trainer(
         **trainer_kwargs,
@@ -449,4 +536,9 @@
     if checkpoint:
         local_ckpt_dir = checkpoint.to_directory()
         ckpt_path = os.path.join(local_ckpt_dir, CHECKPOINT_FILENAME)
-    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=ckpt_path)
+    trainer.fit(
+        model,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+        ckpt_path=ckpt_path,
+    )

--- michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py
+++ michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py
@@ -20,7 +20,9 @@
 from michelangelo.lib.trainer.torch.pytorch_lightning._private.util import (
     _train_loop_per_worker,
 )
-from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
+from pytorch_lightning.utilities.deepspeed import (
+    convert_zero_checkpoint_to_fp32_state_dict,
+)
 from contextlib import contextmanager
 
 _logger = logging.getLogger(__name__)
@@ -45,7 +47,9 @@
     train_data: ray.data.Dataset
     val_data: ray.data.Dataset
     batch_size: int = 8
-    num_shuffle_batches: int = 10  # By default we reserve 10 batches in ray data shuffle buffer.
+    num_shuffle_batches: int = (
+        10  # By default we reserve 10 batches in ray data shuffle buffer.
+    )
     num_epochs: Optional[int] = field(default=_UNSET)  # type: ignore[assignment]  # sentinel replaced in __post_init__
     data_collate_fn: Callable = None
     comet_param: CometParam = None
@@ -73,7 +77,9 @@
         scaling_config: Optional[ray.train.ScalingConfig] = None,
     ):
         self.trainer_param = trainer_param
-        _logger.info("LightningTrainer initialized with trainer_param: %r", trainer_param)
+        _logger.info(
+            "LightningTrainer initialized with trainer_param: %r", trainer_param
+        )
         train_loop_config = asdict(trainer_param)
         # Unique run id for Comet experiment
         train_loop_config["run_id"] = str(uuid.uuid4())
@@ -142,11 +148,15 @@
         Update the model state dict with the local checkpoint.
         """
         if not hasattr(self, "checkpoint") or self.checkpoint is None:
-            raise ValueError("No checkpoint available. Please call train() first to generate a checkpoint.")
+            raise ValueError(
+                "No checkpoint available. Please call train() first to generate a checkpoint."
+            )
         used_deepspeed = self._is_deepspeed_strategy()
         # use the ray checkpoint as_directory() to get the local temp checkpoint directory
         with self.checkpoint.as_directory() as d:
-            _logger.info(f"Saving Ray Checkpoint to local temp Checkpoint directory: {d}")
+            _logger.info(
+                f"Saving Ray Checkpoint to local temp Checkpoint directory: {d}"
+            )
             data_dir_contents = os.listdir(d)
             _logger.info(f"Data directory contents: {data_dir_contents}")
             lightning_ckpt_path = os.path.join(d, CHECKPOINT_NAME)
@@ -158,15 +168,21 @@
                 # explicitly pass weights_only, covering both pytorch_lightning and deepspeed internals.
                 # TODO: Remove this once we upgrade to Lightning 2.6+ https://github.com/Lightning-AI/pytorch-lightning/pull/21194
                 with _torch_weights_only_disabled():
-                    model_state_dict = convert_zero_checkpoint_to_fp32_state_dict(lightning_ckpt_path, local_model_path)
-                _logger.info(f"Loaded DeepSpeed checkpoint from {lightning_ckpt_path} to {local_model_path}")
+                    model_state_dict = convert_zero_checkpoint_to_fp32_state_dict(
+                        lightning_ckpt_path, local_model_path
+                    )
+                _logger.info(
+                    f"Loaded DeepSpeed checkpoint from {lightning_ckpt_path} to {local_model_path}"
+                )
             else:
                 # DDP checkpoint
                 checkpoint = torch.load(lightning_ckpt_path, map_location="cpu")
                 model_state_dict = checkpoint["state_dict"]
                 _logger.info(f"Loaded DDP checkpoint from {lightning_ckpt_path}")
             torch_model.load_state_dict(model_state_dict, strict=False)
-            _logger.info("Updated the state dict of the torch model in the ModelVariable")
+            _logger.info(
+                "Updated the state dict of the torch model in the ModelVariable"
+            )
 
 
 @contextmanager

--- michelangelo/lib/trainer/torch/utils.py
+++ michelangelo/lib/trainer/torch/utils.py
@@ -5,7 +5,9 @@
 from transformers import AutoModel
 
 
-def get_total_training_memory_transformers(model: AutoModel, batch_size: int, sequence_length: int) -> float:
+def get_total_training_memory_transformers(
+    model: AutoModel, batch_size: int, sequence_length: int
+) -> float:
     """
     Get the total memory (in MB) required for training the model.
     This function is specific to transformers models.
@@ -19,7 +21,9 @@
     dtype = model.config.torch_dtype  # Parameter type
     tensor_parallelism = 1  # Number of tensor parallelism
 
-    bytes_per_parameter = torch.tensor([1]).to(dtype).element_size()  # Bytes per parameter
+    bytes_per_parameter = (
+        torch.tensor([1]).to(dtype).element_size()
+    )  # Bytes per parameter
 
     # Calculating each memory component in MB
     # 1. Parameter Memory
@@ -40,22 +44,32 @@
         batch_size
         * sequence_length
         * hidden_size
-        * (10 + 24 / tensor_parallelism + 5 * num_atten_heads * sequence_length / hidden_size / tensor_parallelism)
+        * (
+            10
+            + 24 / tensor_parallelism
+            + 5 * num_atten_heads * sequence_length / hidden_size / tensor_parallelism
+        )
         / (1024**2)
     )
 
     # fp16 uses 2 bytes
-    activation_memory_per_layer = bytes_per_parameter / 2 * fp16_activation_memory_per_layer
+    activation_memory_per_layer = (
+        bytes_per_parameter / 2 * fp16_activation_memory_per_layer
+    )
     activation_memory_total = activation_memory_per_layer * num_layers
 
     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
-    total_memory = (parameter_memory + activation_memory_total + gradient_memory + optimizer_memory) * 1.2
+    total_memory = (
+        parameter_memory + activation_memory_total + gradient_memory + optimizer_memory
+    ) * 1.2
 
     return total_memory
 
 
 # Function to estimate activation memory on non-transformer layers
-def estimate_activation_memory_non_transformer(layer_output_dims, batch_size, bytes_per_value):
+def estimate_activation_memory_non_transformer(
+    layer_output_dims, batch_size, bytes_per_value
+):
     total_activation_memory_mb = 0
 
     for output_shape in layer_output_dims.values():
@@ -68,7 +82,9 @@
     return total_activation_memory_mb
 
 
-def get_total_training_memory_nn_module(model: torch.nn.Module, batch_size: int, input_size: int) -> float:
+def get_total_training_memory_nn_module(
+    model: torch.nn.Module, batch_size: int, input_size: int
+) -> float:
     """
     Get the total memory (in MB) required for training the model.
     This function is specific to non-transformers models.
@@ -81,7 +97,9 @@
         dtype = param.dtype
         break
 
-    bytes_per_parameter = torch.tensor([1]).to(dtype).element_size()  # Bytes per parameter
+    bytes_per_parameter = (
+        torch.tensor([1]).to(dtype).element_size()
+    )  # Bytes per parameter
 
     # Calculating each memory component in MB
     # 1. Parameter Memory
@@ -105,7 +123,12 @@
     # Register hooks for each layer in the model
     # We only count Linear layers, Conv layers, Norm layers, and RNN layers
     hooks = []
-    supported_layer_types = (nn.Linear, nn.modules.conv._ConvNd, nn.modules.batchnorm._NormBase, nn.modules.rnn.RNNBase)
+    supported_layer_types = (
+        nn.Linear,
+        nn.modules.conv._ConvNd,
+        nn.modules.batchnorm._NormBase,
+        nn.modules.rnn.RNNBase,
+    )
 
     for layer in model.children():
         if isinstance(layer, supported_layer_types):
@@ -123,9 +146,13 @@
         hook.remove()
 
     # Use captured output dimensions to estimate activation memory
-    total_activation_memory = estimate_activation_memory_non_transformer(layer_output_dims, batch_size, bytes_per_parameter)
+    total_activation_memory = estimate_activation_memory_non_transformer(
+        layer_output_dims, batch_size, bytes_per_parameter
+    )
 
     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
-    total_memory_mb = (parameter_memory + total_activation_memory + gradient_memory + optimizer_memory) * 1.2
+    total_memory_mb = (
+        parameter_memory + total_activation_memory + gradient_memory + optimizer_memory
+    ) * 1.2
 
     return total_memory_mb

📌 Please run poetry run ruff format . locally to fix formatting issues.

🚨 Lint Issues Detected

Show details
examples/movielens/data.py:58:89: E501 Line too long (89 > 88)
   |
56 |     try:
57 |         _logger.info("Downloading MovieLens-100k from %s", _DATA_URL)
58 |         with urllib.request.urlopen(_DATA_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
   |                                                                                         ^ E501
59 |             data = resp.read()
60 |         with zipfile.ZipFile(io.BytesIO(data)) as z:
   |

examples/movielens/data.py:65:89: E501 Line too long (96 > 88)
   |
63 |         return udata_path
64 |     except (urllib.error.URLError, TimeoutError, OSError) as exc:
65 |         _logger.warning("Canonical URL failed (%s); falling back to %s", exc, _UDATA_MIRROR_URL)
   |                                                                                         ^^^^^^^^ E501
66 |
67 |     os.makedirs(extracted_dir, exist_ok=True)
   |

examples/movielens/data.py:68:89: E501 Line too long (93 > 88)
   |
67 |     os.makedirs(extracted_dir, exist_ok=True)
68 |     with urllib.request.urlopen(_UDATA_MIRROR_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
   |                                                                                         ^^^^^ E501
69 |         udata_bytes = resp.read()
70 |     with open(udata_path, "wb") as f:
   |

examples/movielens/data.py:94:89: E501 Line too long (100 > 88)
   |
92 |         header=None,
93 |         names=["user_id", "item_id", "rating", "timestamp"],
94 |         dtype={"user_id": np.int64, "item_id": np.int64, "rating": np.int64, "timestamp": np.int64},
   |                                                                                         ^^^^^^^^^^^^ E501
95 |     )
96 |     _logger.info("Loaded %d ratings", len(df))
   |

examples/movielens/data.py:98:89: E501 Line too long (89 > 88)
    |
 96 |     _logger.info("Loaded %d ratings", len(df))
 97 |
 98 |     user_id_to_idx = {uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))}
    |                                                                                         ^ E501
 99 |     item_id_to_idx = {iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))}
100 |     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
    |

examples/movielens/data.py:99:89: E501 Line too long (89 > 88)
    |
 98 |     user_id_to_idx = {uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))}
 99 |     item_id_to_idx = {iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))}
    |                                                                                         ^ E501
100 |     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
101 |     df["item_idx"] = df["item_id"].map(item_id_to_idx).astype(np.int64)
    |

examples/movielens/data.py:102:89: E501 Line too long (90 > 88)
    |
100 |     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
101 |     df["item_idx"] = df["item_id"].map(item_id_to_idx).astype(np.int64)
102 |     df["rating_norm"] = ((df["rating"].astype(np.float32) - 1.0) / 4.0).astype(np.float32)
    |                                                                                         ^^ E501
103 |
104 |     num_users = len(user_id_to_idx)
    |

examples/movielens/model.py:8:8: N812 Lowercase `functional` imported as non-lowercase `F`
  |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
  |        ^^^^^^^^^^^^^^^^^^^^^^^^ N812
  |

examples/movielens/model.py:19:9: D107 Missing docstring in `__init__`
   |
17 |     """
18 |
19 |     def __init__(
   |         ^^^^^^^^ D107
20 |         self,
21 |         num_users: int,
   |

examples/movielens/model.py:41:9: D102 Missing docstring in public method
   |
39 |         nn.init.normal_(self.item_emb.weight, std=0.01)
40 |
41 |     def forward(self, user_idx: torch.Tensor, item_idx: torch.Tensor) -> torch.Tensor:
   |         ^^^^^^^ D102
42 |         u = self.user_emb(user_idx)
43 |         i = self.item_emb(item_idx)
   |

examples/movielens/model.py:55:89: E501 Line too long (100 > 88)
   |
53 |         loss = F.mse_loss(preds, target)
54 |         # sync_dist=True so the metric is averaged across Ray Train workers.
55 |         self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
   |                                                                                         ^^^^^^^^^^^^ E501
56 |         return loss
   |

examples/movielens/model.py:58:9: D102 Missing docstring in public method
   |
56 |         return loss
57 |
58 |     def training_step(self, batch, batch_idx):  # noqa: ARG002
   |         ^^^^^^^^^^^^^ D102
59 |         return self._step(batch, "train")
   |

examples/movielens/model.py:58:49: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
56 |         return loss
57 |
58 |     def training_step(self, batch, batch_idx):  # noqa: ARG002
   |                                                 ^^^^^^^^^^^^^^ RUF100
59 |         return self._step(batch, "train")
   |
   = help: Remove unused `noqa` directive

examples/movielens/model.py:61:9: D102 Missing docstring in public method
   |
59 |         return self._step(batch, "train")
60 |
61 |     def validation_step(self, batch, batch_idx):  # noqa: ARG002
   |         ^^^^^^^^^^^^^^^ D102
62 |         return self._step(batch, "val")
   |

examples/movielens/model.py:61:51: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
59 |         return self._step(batch, "train")
60 |
61 |     def validation_step(self, batch, batch_idx):  # noqa: ARG002
   |                                                   ^^^^^^^^^^^^^^ RUF100
62 |         return self._step(batch, "val")
   |
   = help: Remove unused `noqa` directive

examples/movielens/model.py:64:9: D102 Missing docstring in public method
   |
62 |         return self._step(batch, "val")
63 |
64 |     def configure_optimizers(self):
   |         ^^^^^^^^^^^^^^^^^^^^ D102
65 |         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
   |

examples/movielens/train.py:38:89: E501 Line too long (97 > 88)
   |
36 | )
37 |
38 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
   |                                                                                         ^^^^^^^^^ E501
39 | log = logging.getLogger("examples.movielens.train")
   |

examples/movielens/train.py:47:29: UP007 Use `X | Y` for type annotations
   |
47 | def _build_comet_param() -> Optional[CometParam]:
   |                             ^^^^^^^^^^^^^^^^^^^^ UP007
48 |     """Build a CometParam from env vars, or return None to skip Comet logging.
   |
   = help: Convert to `X | Y`

examples/movielens/train.py:63:89: E501 Line too long (91 > 88)
   |
61 |         workspace=workspace,
62 |         project_name=os.environ.get("COMET_PROJECT_NAME", _DEFAULT_COMET_PROJECT),
63 |         experiment_name=os.environ.get("COMET_EXPERIMENT_NAME", _DEFAULT_COMET_EXPERIMENT),
   |                                                                                         ^^^ E501
64 |         tags=tags,
65 |     )
   |

examples/movielens/train.py:68:42: UP007 Use `X | Y` for type annotations
   |
68 | def _parse_mlflow_tags(tags_env: str) -> Optional[dict]:
   |                                          ^^^^^^^^^^^^^^ UP007
69 |     """Parse ``key1=val1,key2=val2`` into a dict; return None if empty/malformed."""
70 |     if not tags_env.strip():
   |
   = help: Convert to `X | Y`

examples/movielens/train.py:95:57: RUF100 [*] Unused `noqa` directive (non-enabled: `PLC0415`)
   |
93 |         return None
94 |     # Import lazily so an unused MLflow path doesn't force the dependency.
95 |     from pytorch_lightning.loggers import MLFlowLogger  # noqa: PLC0415
   |                                                         ^^^^^^^^^^^^^^^ RUF100
96 |
97 |     return MLFlowLogger(
   |
   = help: Remove unused `noqa` directive

examples/movielens/train.py:98:89: E501 Line too long (93 > 88)
    |
 97 |     return MLFlowLogger(
 98 |         experiment_name=os.environ.get("MLFLOW_EXPERIMENT_NAME", _DEFAULT_MLFLOW_EXPERIMENT),
    |                                                                                         ^^^^^ E501
 99 |         tracking_uri=tracking_uri,
100 |         run_name=os.environ.get("MLFLOW_RUN_NAME"),
    |

examples/movielens/train.py:105:5: D103 Missing docstring in public function
    |
105 | def main() -> dict:
    |     ^^^^ D103
106 |     splits = load_movielens_100k()
    |

examples/movielens/train.py:119:89: E501 Line too long (107 > 88)
    |
117 |         )
118 |         if os.environ.get("MLFLOW_TRACKING_URI"):
119 |             log.info("MLFLOW_TRACKING_URI is also set but Comet takes precedence; MLflow logging skipped.")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^ E501
120 |     else:
121 |         mlflow_logger = _build_mlflow_logger()
    |

examples/movielens/train.py:129:89: E501 Line too long (101 > 88)
    |
127 |             )
128 |         else:
129 |             log.info("Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)")
    |                                                                                         ^^^^^^^^^^^^^ E501
130 |
131 |     lightning_trainer_kwargs = {
    |

michelangelo/lib/trainer/torch/_numpy_utils.py:6:89: E501 Line too long (89 > 88)
  |
4 | ``uber.ai.michelangelo.shared.utils.numpy_utils.{pad,sentinel,type}`` and
5 | ``uber.ai.michelangelo.shared.constants.sentinel``. Kept private to the trainer
6 | package; callers should use :mod:`michelangelo.lib.trainer.torch.data_collate_functions`.
  |                                                                                         ^ E501
7 | """
  |

michelangelo/lib/trainer/torch/_numpy_utils.py:22:50: UP007 Use `X | Y` for type annotations
   |
22 | def sentinel_for_numpy_dtype(dtype: np.dtype) -> Union[float, int, str, bytes, bool]:
   |                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ UP007
23 |     """Return the type-native sentinel value for *dtype*."""
24 |     if np.issubdtype(dtype, np.floating):
   |
   = help: Convert to `X | Y`

michelangelo/lib/trainer/torch/_numpy_utils.py:67:89: E501 Line too long (101 > 88)
   |
65 |             if all(isinstance(elem, np.ndarray) and elem.ndim == 1 for elem in arr):
66 |                 first_non_empty = next((a for a in arr if a.size > 0), None)
67 |                 if first_non_empty is not None and first_non_empty.dtype.kind not in ("U", "S", "O"):
   |                                                                                         ^^^^^^^^^^^^^ E501
68 |                     dtype = next((a.dtype for a in arr if a.size > 0), np.int32)
69 |                     pad_value = sentinel_for_numpy_dtype(dtype) if pad_value is None else pad_value
   |

michelangelo/lib/trainer/torch/_numpy_utils.py:69:89: E501 Line too long (99 > 88)
   |
67 |                 if first_non_empty is not None and first_non_empty.dtype.kind not in ("U", "S", "O"):
68 |                     dtype = next((a.dtype for a in arr if a.size > 0), np.int32)
69 |                     pad_value = sentinel_for_numpy_dtype(dtype) if pad_value is None else pad_value
   |                                                                                         ^^^^^^^^^^^ E501
70 |                     return _pad_1d_arrays_fast(arr, pad_value, dtype)
71 |         except (AttributeError, TypeError):
   |

michelangelo/lib/trainer/torch/_numpy_utils.py:149:89: E501 Line too long (99 > 88)
    |
147 |             elem = arr[i] if isinstance(arr, (np.ndarray, list)) else arr
148 |             if isinstance(elem, (list, np.ndarray)):
149 |                 padded_elem = _pad_array_recursive(elem, target_shape, pad_value, dtype, level + 1)
    |                                                                                         ^^^^^^^^^^^ E501
150 |             else:
151 |                 padded_elem = elem
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:3:89: E501 Line too long (90 > 88)
  |
1 | """Collate helpers for Ray Data / PyTorch training.
2 |
3 | This module exposes small building blocks so callers can compose custom collate functions:
  |                                                                                         ^^ E501
4 |
5 | - :data:`DEFAULT_COLLATE_NUMPY_DTYPE` / :data:`DEFAULT_COLLATE_TORCH_DTYPE` — default dtypes
  |

michelangelo/lib/trainer/torch/data_collate_functions.py:5:89: E501 Line too long (92 > 88)
  |
3 | This module exposes small building blocks so callers can compose custom collate functions:
4 |
5 | - :data:`DEFAULT_COLLATE_NUMPY_DTYPE` / :data:`DEFAULT_COLLATE_TORCH_DTYPE` — default dtypes
  |                                                                                         ^^^^ E501
6 |   (``float32`` unless overridden via function or :class:`LiteralEvalFloat32Collate` kwargs).
7 | - :func:`pad_ragged_lists` — pad nested Python lists to a dense array of *numpy_dtype*.
  |

michelangelo/lib/trainer/torch/data_collate_functions.py:6:89: E501 Line too long (92 > 88)
  |
5 | - :data:`DEFAULT_COLLATE_NUMPY_DTYPE` / :data:`DEFAULT_COLLATE_TORCH_DTYPE` — default dtypes
6 |   (``float32`` unless overridden via function or :class:`LiteralEvalFloat32Collate` kwargs).
  |                                                                                         ^^^^ E501
7 | - :func:`pad_ragged_lists` — pad nested Python lists to a dense array of *numpy_dtype*.
8 | - :func:`cell_is_nested_subsequence` / :func:`row_is_list_of_nested_cells` — structure checks.
  |

michelangelo/lib/trainer/torch/data_collate_functions.py:8:89: E501 Line too long (94 > 88)
   |
 6 |   (``float32`` unless overridden via function or :class:`LiteralEvalFloat32Collate` kwargs).
 7 | - :func:`pad_ragged_lists` — pad nested Python lists to a dense array of *numpy_dtype*.
 8 | - :func:`cell_is_nested_subsequence` / :func:`row_is_list_of_nested_cells` — structure checks.
   |                                                                                         ^^^^^^ E501
 9 | - :func:`collate_value_to_float32_numpy` — one feature column → :class:`numpy.ndarray`.
10 | - :func:`collate_value_to_float32_tensor` — one feature column → :class:`torch.Tensor`.
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:14:89: E501 Line too long (98 > 88)
   |
13 | The default :func:`literal_eval_data_collate_function` is implemented on top of these.
14 | :class:`LiteralEvalFloat32Collate` wraps the same behavior for subclassing (custom device, hooks).
   |                                                                                         ^^^^^^^^^^ E501
15 | """
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:17:1: I001 [*] Import block is un-sorted or un-formatted
   |
15 |   """
16 |
17 | / from __future__ import annotations
18 | |
19 | | import ast
20 | |
21 | | import numpy as np
22 | | import torch
23 | |
24 | | from michelangelo.lib.trainer.torch._numpy_utils import pad_ragged_tensor, sentinel_for_numpy_dtype
   | |___________________________________________________________________________________________________^ I001
25 |
26 |   # Default dtypes for all collate paths (subclass / kwargs may override per call).
   |
   = help: Organize imports

michelangelo/lib/trainer/torch/data_collate_functions.py:24:89: E501 Line too long (99 > 88)
   |
22 | import torch
23 |
24 | from michelangelo.lib.trainer.torch._numpy_utils import pad_ragged_tensor, sentinel_for_numpy_dtype
   |                                                                                         ^^^^^^^^^^^ E501
25 |
26 | # Default dtypes for all collate paths (subclass / kwargs may override per call).
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:61:89: E501 Line too long (92 > 88)
   |
60 | def cell_is_nested_subsequence(cell) -> bool:
61 |     """Return True if *cell* is a vector-valued slot (list/tuple or ndarray with ndim >= 1).
   |                                                                                         ^^^^ E501
62 |
63 |     Scalars and 0-D ndarrays are leaves for the 2-D-ragged path (one flat vector per row).
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:63:89: E501 Line too long (90 > 88)
   |
61 |     """Return True if *cell* is a vector-valued slot (list/tuple or ndarray with ndim >= 1).
62 |
63 |     Scalars and 0-D ndarrays are leaves for the 2-D-ragged path (one flat vector per row).
   |                                                                                         ^^ E501
64 |     """
65 |     if isinstance(cell, (list, tuple)):
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:73:89: E501 Line too long (103 > 88)
   |
72 | def row_is_list_of_nested_cells(flat0: list | np.ndarray) -> bool:
73 |     """Return True when *flat0* is a row of cells where at least one cell is a sub-sequence (3-D path).
   |                                                                                         ^^^^^^^^^^^^^^^ E501
74 |
75 |     Uses every cell, not only ``flat0[0]``, so a leading scalar with later list cells still
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:75:89: E501 Line too long (91 > 88)
   |
73 |     """Return True when *flat0* is a row of cells where at least one cell is a sub-sequence (3-D path).
74 |
75 |     Uses every cell, not only ``flat0[0]``, so a leading scalar with later list cells still
   |                                                                                         ^^^ E501
76 |     selects the 3-D normalization branch.
77 |     """
   |

michelangelo/lib/trainer/torch/data_collate_functions.py:88:5: SIM103 Return the condition `not isinstance(flat0, (list, tuple, np.ndarray))` directly
   |
86 |       if isinstance(flat0, np.ndarray) and flat0.ndim == 0:
87 |           return True
88 | /     if isinstance(flat0, (list, tuple, np.ndarray)):
89 | |         return False
90 | |     return True
   | |_______________^ SIM103
   |
   = help: Replace with `return not isinstance(flat0, (list, tuple, np.ndarray))`

michelangelo/lib/trainer/torch/data_collate_functions.py:100:89: E501 Line too long (93 > 88)
    |
 98 |         if obj.ndim == 0:
 99 |             return _literal_eval_str_cells_in_object_array(obj.item())
100 |         return [_literal_eval_str_cells_in_object_array(obj[i]) for i in range(obj.shape[0])]
    |                                                                                         ^^^^^ E501
101 |     return obj
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:110:89: E501 Line too long (114 > 88)
    |
108 |     numpy_dtype: np.dtype | None = None,
109 | ) -> np.ndarray:
110 |     """Pad nested lists to a rectangular array of *numpy_dtype* (default: :data:`DEFAULT_COLLATE_NUMPY_DTYPE`)."""
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
111 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:111:89: E501 Line too long (94 > 88)
    |
109 | ) -> np.ndarray:
110 |     """Pad nested lists to a rectangular array of *numpy_dtype* (default: :data:`DEFAULT_COLLATE_NUMPY_DTYPE`)."""
111 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |                                                                                         ^^^^^^ E501
112 |
113 |     if not items:
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:128:89: E501 Line too long (94 > 88)
    |
126 |     flat0 = items[0]
127 |     if row_is_list_of_nested_cells(flat0):
128 |         normalized = [[np.asarray(sub, dtype=target).ravel() for sub in row] for row in items]
    |                                                                                         ^^^^^^ E501
129 |     else:
130 |         normalized = [np.asarray(seq, dtype=target).ravel() for seq in items]
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:133:89: E501 Line too long (92 > 88)
    |
132 |     obj = np.asarray(normalized, dtype=object)
133 |     effective_pad = pad_value if pad_value is not None else sentinel_for_numpy_dtype(target)
    |                                                                                         ^^^^ E501
134 |     padded = pad_ragged_tensor(obj, effective_pad)
135 |     if padded.dtype == np.object_:
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:147:89: E501 Line too long (91 > 88)
    |
145 |     numpy_dtype: np.dtype | None = None,
146 | ) -> np.ndarray:
147 |     """Convert a single batch column value to a :class:`numpy.ndarray` of *numpy_dtype*."""
    |                                                                                         ^^^ E501
148 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:148:89: E501 Line too long (94 > 88)
    |
146 | ) -> np.ndarray:
147 |     """Convert a single batch column value to a :class:`numpy.ndarray` of *numpy_dtype*."""
148 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |                                                                                         ^^^^^^ E501
149 |
150 |     if parse_string_with_literal_eval and isinstance(value, str):
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:151:9: SIM105 Use `contextlib.suppress(ValueError, SyntaxError)` instead of `try`-`except`-`pass`
    |
150 |       if parse_string_with_literal_eval and isinstance(value, str):
151 | /         try:
152 | |             value = ast.literal_eval(value)
153 | |         except (ValueError, SyntaxError):
154 | |             pass
    | |________________^ SIM105
155 |
156 |       if not isinstance(value, np.ndarray):
    |
    = help: Replace with `contextlib.suppress(ValueError, SyntaxError)`

michelangelo/lib/trainer/torch/data_collate_functions.py:181:89: E501 Line too long (117 > 88)
    |
179 |     numpy_dtype: np.dtype | None = None,
180 | ) -> torch.Tensor:
181 |     """Convert one column value to :class:`torch.Tensor` on *device* (see :func:`collate_value_to_float32_numpy`)."""
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
182 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
183 |     arr = collate_value_to_float32_numpy(
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:182:89: E501 Line too long (94 > 88)
    |
180 | ) -> torch.Tensor:
181 |     """Convert one column value to :class:`torch.Tensor` on *device* (see :func:`collate_value_to_float32_numpy`)."""
182 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |                                                                                         ^^^^^^ E501
183 |     arr = collate_value_to_float32_numpy(
184 |         value,
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:201:89: E501 Line too long (96 > 88)
    |
199 |     numpy_dtype: np.dtype | None = None,
200 | ) -> dict[str, torch.Tensor]:
201 |     """Map a batch dict of Python / NumPy values to tensors (default element dtype: float32)."""
    |                                                                                         ^^^^^^^^ E501
202 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
203 |     return {
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:202:89: E501 Line too long (94 > 88)
    |
200 | ) -> dict[str, torch.Tensor]:
201 |     """Map a batch dict of Python / NumPy values to tensors (default element dtype: float32)."""
202 |     target = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |                                                                                         ^^^^^^ E501
203 |     return {
204 |         k: collate_value_to_float32_tensor(
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:218:9: D107 Missing docstring in `__init__`
    |
216 |     """Default collate with :func:`ast.literal_eval` for stringified arrays."""
217 |
218 |     def __init__(
    |         ^^^^^^^^ D107
219 |         self,
220 |         *,
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:229:89: E501 Line too long (108 > 88)
    |
227 |         self.reshape_1d_features = reshape_1d_features
228 |         self.parse_string_with_literal_eval = parse_string_with_literal_eval
229 |         self.numpy_dtype = np.dtype(numpy_dtype) if numpy_dtype is not None else DEFAULT_COLLATE_NUMPY_DTYPE
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^ E501
230 |
231 |     def collate_value_to_numpy(self, value) -> np.ndarray:
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:232:89: E501 Line too long (91 > 88)
    |
231 |     def collate_value_to_numpy(self, value) -> np.ndarray:
232 |         """Convert one column value to :class:`~numpy.ndarray` (override in subclasses)."""
    |                                                                                         ^^^ E501
233 |         return collate_value_to_float32_numpy(
234 |             value,
    |

michelangelo/lib/trainer/torch/data_collate_functions.py:250:9: D102 Missing docstring in public method
    |
248 |         return {k: self.collate_value_to_tensor(v) for k, v in batch_data.items()}
249 |
250 |     def __call__(self, batch_data: dict) -> dict[str, torch.Tensor]:
    |         ^^^^^^^^ D102
251 |         return self.collate_batch(batch_data)
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:12:5: D205 1 blank line required between summary line and description
   |
11 |   class RayTrainReportCallback(ray.train.lightning.RayTrainReportCallback):
12 | /     """
13 | |     We follow existing implementation of RayTrainReportCallback,
14 | |     only force rank zero to report checkpoint.
15 | |     Reference: https://docs.ray.io/en/latest/_modules/ray/train/lightning/_lightning_utils.html#RayTrainReportCallback
16 | |     """
   | |_______^ D205
17 |
18 |       def __init__(self) -> None:
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:12:5: D212 [*] Multi-line docstring summary should start at the first line
   |
11 |   class RayTrainReportCallback(ray.train.lightning.RayTrainReportCallback):
12 | /     """
13 | |     We follow existing implementation of RayTrainReportCallback,
14 | |     only force rank zero to report checkpoint.
15 | |     Reference: https://docs.ray.io/en/latest/_modules/ray/train/lightning/_lightning_utils.html#RayTrainReportCallback
16 | |     """
   | |_______^ D212
17 |
18 |       def __init__(self) -> None:
   |
   = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:12:5: D415 First line should end with a period, question mark, or exclamation point
   |
11 |   class RayTrainReportCallback(ray.train.lightning.RayTrainReportCallback):
12 | /     """
13 | |     We follow existing implementation of RayTrainReportCallback,
14 | |     only force rank zero to report checkpoint.
15 | |     Reference: https://docs.ray.io/en/latest/_modules/ray/train/lightning/_lightning_utils.html#RayTrainReportCallback
16 | |     """
   | |_______^ D415
17 |
18 |       def __init__(self) -> None:
   |
   = help: Add closing punctuation

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:22:64: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
20 |         self.world_rank = ray.train.get_context().get_world_rank()
21 |
22 |     def on_train_epoch_end(self, trainer, pl_module) -> None:  # noqa: ARG002
   |                                                                ^^^^^^^^^^^^^^ RUF100
23 |         # Creates a checkpoint dir with fixed name
24 |         tmpdir = Path(self.tmpdir_prefix, str(trainer.current_epoch)).as_posix()
   |
   = help: Remove unused `noqa` directive

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:55:5: D205 1 blank line required between summary line and description
   |
54 |   class RayTrainReportPerNodeCallback(RayTrainReportCallback):
55 | /     """
56 | |     We derive from RayTrainReportCallback, but report checkpoint per node instead of on head rank.
57 | |     Report per node is necessary for model parallelism for deepspeed zeros and FSDP.
58 | |     Also supports step-wise checkpointing in addition to epoch-based checkpointing.
59 | |     """
   | |_______^ D205
60 |
61 |       def __init__(self, step_checkpoint_frequency: int = 0) -> None:
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:55:5: D212 [*] Multi-line docstring summary should start at the first line
   |
54 |   class RayTrainReportPerNodeCallback(RayTrainReportCallback):
55 | /     """
56 | |     We derive from RayTrainReportCallback, but report checkpoint per node instead of on head rank.
57 | |     Report per node is necessary for model parallelism for deepspeed zeros and FSDP.
58 | |     Also supports step-wise checkpointing in addition to epoch-based checkpointing.
59 | |     """
   | |_______^ D212
60 |
61 |       def __init__(self, step_checkpoint_frequency: int = 0) -> None:
   |
   = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:56:89: E501 Line too long (98 > 88)
   |
54 | class RayTrainReportPerNodeCallback(RayTrainReportCallback):
55 |     """
56 |     We derive from RayTrainReportCallback, but report checkpoint per node instead of on head rank.
   |                                                                                         ^^^^^^^^^^ E501
57 |     Report per node is necessary for model parallelism for deepspeed zeros and FSDP.
58 |     Also supports step-wise checkpointing in addition to epoch-based checkpointing.
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:62:9: D205 1 blank line required between summary line and description
   |
61 |       def __init__(self, step_checkpoint_frequency: int = 0) -> None:
62 | /         """
63 | |         Args:
64 | |             step_checkpoint_frequency: How often to create checkpoints during training steps.
65 | |                 Default is 0 steps. Set to 0 to disable step-wise checkpointing.
66 | |         """
   | |___________^ D205
67 |           super().__init__()
68 |           self.step_checkpoint_frequency = step_checkpoint_frequency
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:62:9: D212 [*] Multi-line docstring summary should start at the first line
   |
61 |       def __init__(self, step_checkpoint_frequency: int = 0) -> None:
62 | /         """
63 | |         Args:
64 | |             step_checkpoint_frequency: How often to create checkpoints during training steps.
65 | |                 Default is 0 steps. Set to 0 to disable step-wise checkpointing.
66 | |         """
   | |___________^ D212
67 |           super().__init__()
68 |           self.step_checkpoint_frequency = step_checkpoint_frequency
   |
   = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:64:89: E501 Line too long (93 > 88)
   |
62 |         """
63 |         Args:
64 |             step_checkpoint_frequency: How often to create checkpoints during training steps.
   |                                                                                         ^^^^^ E501
65 |                 Default is 0 steps. Set to 0 to disable step-wise checkpointing.
66 |         """
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:71:70: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
69 |         self.last_step_checkpoint = 0
70 |
71 |     def on_train_batch_end(self, trainer, *args, **kwargs) -> None:  # noqa: ARG002
   |                                                                      ^^^^^^^^^^^^^^ RUF100
72 |         """Called when the train batch ends."""
73 |         if self.step_checkpoint_frequency > 0:
   |
   = help: Remove unused `noqa` directive

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:75:89: E501 Line too long (90 > 88)
   |
73 |         if self.step_checkpoint_frequency > 0:
74 |             current_step = trainer.global_step
75 |             if current_step - self.last_step_checkpoint >= self.step_checkpoint_frequency:
   |                                                                                         ^^ E501
76 |                 checkpoint_id = f"step_{trainer.global_step}"
77 |                 self._create_and_report_checkpoint(trainer, checkpoint_id, is_step_checkpoint=True)
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:77:89: E501 Line too long (99 > 88)
   |
75 |             if current_step - self.last_step_checkpoint >= self.step_checkpoint_frequency:
76 |                 checkpoint_id = f"step_{trainer.global_step}"
77 |                 self._create_and_report_checkpoint(trainer, checkpoint_id, is_step_checkpoint=True)
   |                                                                                         ^^^^^^^^^^^ E501
78 |                 self.last_step_checkpoint = current_step
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:80:64: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
78 |                 self.last_step_checkpoint = current_step
79 |
80 |     def on_train_epoch_end(self, trainer, pl_module) -> None:  # noqa: ARG002
   |                                                                ^^^^^^^^^^^^^^ RUF100
81 |         """Called when the train epoch ends."""
82 |         checkpoint_id = f"epoch_{trainer.current_epoch}"
   |
   = help: Remove unused `noqa` directive

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:83:89: E501 Line too long (92 > 88)
   |
81 |         """Called when the train epoch ends."""
82 |         checkpoint_id = f"epoch_{trainer.current_epoch}"
83 |         self._create_and_report_checkpoint(trainer, checkpoint_id, is_step_checkpoint=False)
   |                                                                                         ^^^^ E501
84 |
85 |     def _create_and_report_checkpoint(self, trainer, checkpoint_id: str, is_step_checkpoint: bool) -> None:
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:85:89: E501 Line too long (107 > 88)
   |
83 |         self._create_and_report_checkpoint(trainer, checkpoint_id, is_step_checkpoint=False)
84 |
85 |     def _create_and_report_checkpoint(self, trainer, checkpoint_id: str, is_step_checkpoint: bool) -> None:
   |                                                                                         ^^^^^^^^^^^^^^^^^^^ E501
86 |         """Creates a checkpoint and reports it to Ray Train.
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:90:89: E501 Line too long (99 > 88)
   |
88 |         Args:
89 |             trainer: The PyTorch Lightning trainer instance
90 |             checkpoint_id: Unique identifier for the checkpoint (e.g., epoch number or step number)
   |                                                                                         ^^^^^^^^^^^ E501
91 |             is_step_checkpoint: Whether this is a step-wise checkpoint (True) or epoch checkpoint (False)
92 |         """
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:91:89: E501 Line too long (105 > 88)
   |
89 |             trainer: The PyTorch Lightning trainer instance
90 |             checkpoint_id: Unique identifier for the checkpoint (e.g., epoch number or step number)
91 |             is_step_checkpoint: Whether this is a step-wise checkpoint (True) or epoch checkpoint (False)
   |                                                                                         ^^^^^^^^^^^^^^^^^ E501
92 |         """
93 |         # Create checkpoint directory and prepare metrics
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/callbacks.py:99:89: E501 Line too long (127 > 88)
    |
 97 |         metrics = trainer.callback_metrics
 98 |         metrics = {k: v.item() for k, v in metrics.items()}
 99 |         metrics.update({"epoch": trainer.current_epoch, "step": trainer.global_step, "is_step_checkpoint": is_step_checkpoint})
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
100 |
101 |         # Save checkpoint and report to Ray Train
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:18:89: E501 Line too long (92 > 88)
   |
16 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint
17 | from pytorch_lightning.loggers import CometLogger, Logger
18 | from pytorch_lightning.plugins import CheckpointIO, ClusterEnvironment, LayerSync, Precision
   |                                                                                         ^^^^ E501
19 | from pytorch_lightning.strategies import Strategy
20 | from michelangelo.lib.trainer.torch.pytorch_lightning._private.callbacks import (
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:24:89: E501 Line too long (110 > 88)
   |
22 |     RayTrainReportPerNodeCallback,
23 | )
24 | from ray.train.lightning import RayDDPStrategy, RayDeepSpeedStrategy, RayFSDPStrategy, RayLightningEnvironment
   |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^ E501
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:27:7: N818 Exception name `UserException` should be named with an Error suffix
   |
27 | class UserException(Exception):
   |       ^^^^^^^^^^^^^ N818
28 |     """Raised when a user-supplied input or path causes training to fail."""
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:53:5: D205 1 blank line required between summary line and description
   |
52 |   def _load_weights_from_path(model: torch.nn.Module, path: str) -> None:
53 | /     """Download a state-dict file from any supported storage (local, s3://, gs://, ...)
54 | |     and load it into *model* with strict=True.
55 | |     """
   | |_______^ D205
56 |       fs, fs_path = url_to_fs(path)
57 |       with TemporaryDirectory() as tmp_dir:
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:60:89: E501 Line too long (97 > 88)
   |
58 |         local_path = os.path.join(tmp_dir, "init_weights.pt")
59 |         fs.get(fs_path, local_path)
60 |         # Load to CPU first; DDP/DeepSpeed will move tensors to the correct GPU during broadcast.
   |                                                                                         ^^^^^^^^^ E501
61 |         state_dict = torch.load(local_path, map_location="cpu", weights_only=True)
62 |         # strict=True is intentional: initial_weights_path is expected to point to a complete
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:62:89: E501 Line too long (93 > 88)
   |
60 |         # Load to CPU first; DDP/DeepSpeed will move tensors to the correct GPU during broadcast.
61 |         state_dict = torch.load(local_path, map_location="cpu", weights_only=True)
62 |         # strict=True is intentional: initial_weights_path is expected to point to a complete
   |                                                                                         ^^^^^ E501
63 |         # state dict produced upstream for the same model architecture.
64 |         model.load_state_dict(state_dict, strict=True)
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:68:89: E501 Line too long (101 > 88)
   |
67 | def _print_layer_weights(model: torch.nn.Module, limit: int = 50) -> None:
68 |     """Print a summary of each parameter tensor's name, shape, and first `limit` chars of weights."""
   |                                                                                         ^^^^^^^^^^^^^ E501
69 |     print("=== Layer weights summary ===")
70 |     for name, param in model.named_parameters():
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:77:5: D205 1 blank line required between summary line and description
   |
76 |   def _apply_layer_freeze(model: torch.nn.Module, transfer_learning_spec: dict):
77 | /     """
78 | |     Re-apply layer freezing from transfer_learning_spec after loading state_dict.
79 | |     state_dict does not store requires_grad, so freezing applied upstream must be
80 | |     re-applied in each worker.
81 | |
82 | |     Matching logic:
83 | |     - layer_names: substring match (pattern in layer_name)
84 | |     - layer_names_regex: re.search (matches anywhere in the string)
85 | |     """
   | |_______^ D205
86 |       print(f"Applying layer freeze based on transfer_learning_spec: {transfer_learning_spec}")
87 |       names_to_freeze = transfer_learning_spec.get("layer_names_to_freeze") or []
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:77:5: D212 [*] Multi-line docstring summary should start at the first line
   |
76 |   def _apply_layer_freeze(model: torch.nn.Module, transfer_learning_spec: dict):
77 | /     """
78 | |     Re-apply layer freezing from transfer_learning_spec after loading state_dict.
79 | |     state_dict does not store requires_grad, so freezing applied upstream must be
80 | |     re-applied in each worker.
81 | |
82 | |     Matching logic:
83 | |     - layer_names: substring match (pattern in layer_name)
84 | |     - layer_names_regex: re.search (matches anywhere in the string)
85 | |     """
   | |_______^ D212
86 |       print(f"Applying layer freeze based on transfer_learning_spec: {transfer_learning_spec}")
87 |       names_to_freeze = transfer_learning_spec.get("layer_names_to_freeze") or []
   |
   = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:86:89: E501 Line too long (93 > 88)
   |
84 |     - layer_names_regex: re.search (matches anywhere in the string)
85 |     """
86 |     print(f"Applying layer freeze based on transfer_learning_spec: {transfer_learning_spec}")
   |                                                                                         ^^^^^ E501
87 |     names_to_freeze = transfer_learning_spec.get("layer_names_to_freeze") or []
88 |     regex_to_freeze = transfer_learning_spec.get("layer_names_to_freeze_regex") or []
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:90:89: E501 Line too long (90 > 88)
   |
88 |     regex_to_freeze = transfer_learning_spec.get("layer_names_to_freeze_regex") or []
89 |
90 |     # state_dict().keys() is used intentionally as a superset to show the full model state
   |                                                                                         ^^ E501
91 |     # (parameters + buffers) in debug output. Buffers (e.g., bn.running_mean) may appear
92 |     # in layers_to_freeze but are correctly skipped in the named_parameters() loop below,
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:92:89: E501 Line too long (89 > 88)
   |
90 |     # state_dict().keys() is used intentionally as a superset to show the full model state
91 |     # (parameters + buffers) in debug output. Buffers (e.g., bn.running_mean) may appear
92 |     # in layers_to_freeze but are correctly skipped in the named_parameters() loop below,
   |                                                                                         ^ E501
93 |     # since buffers have no requires_grad. Actual parameters are always frozen correctly.
94 |     model_layer_names = list(model.state_dict().keys())
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:93:89: E501 Line too long (89 > 88)
   |
91 |     # (parameters + buffers) in debug output. Buffers (e.g., bn.running_mean) may appear
92 |     # in layers_to_freeze but are correctly skipped in the named_parameters() loop below,
93 |     # since buffers have no requires_grad. Actual parameters are always frozen correctly.
   |                                                                                         ^ E501
94 |     model_layer_names = list(model.state_dict().keys())
95 |     print(f"[freeze] Model layer names ({len(model_layer_names)}): {model_layer_names!r}")
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:95:89: E501 Line too long (90 > 88)
   |
93 |     # since buffers have no requires_grad. Actual parameters are always frozen correctly.
94 |     model_layer_names = list(model.state_dict().keys())
95 |     print(f"[freeze] Model layer names ({len(model_layer_names)}): {model_layer_names!r}")
   |                                                                                         ^^ E501
96 |
97 |     layers_to_freeze = set()
   |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:116:89: E501 Line too long (90 > 88)
    |
115 |     rank = ray.train.get_context().get_world_rank()
116 |     print(f"[freeze] [Rank {rank}] Layer freeze re-applied: {frozen_count} params frozen")
    |                                                                                         ^^ E501
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:143:89: E501 Line too long (96 > 88)
    |
141 |         if api_experiment is None:
142 |             # Create an experiment object
143 |             comet_ml.Experiment(api_key=api_key, project_name=project_name, workspace=workspace)
    |                                                                                         ^^^^^^^^ E501
144 |
145 |     torch.distributed.barrier()
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:167:89: E501 Line too long (133 > 88)
    |
167 | def _resolve_strategy(strategy: Optional[Union[str, Strategy]] = None, strategy_kwargs: Optional[dict[str, Any]] = None) -> Strategy:
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
168 |     """Factory to create the correct Ray/Lightning strategy based on strategy name or instance."""
169 |     if strategy is not None and not isinstance(strategy, (str, Strategy)):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:168:89: E501 Line too long (98 > 88)
    |
167 | def _resolve_strategy(strategy: Optional[Union[str, Strategy]] = None, strategy_kwargs: Optional[dict[str, Any]] = None) -> Strategy:
168 |     """Factory to create the correct Ray/Lightning strategy based on strategy name or instance."""
    |                                                                                         ^^^^^^^^^^ E501
169 |     if strategy is not None and not isinstance(strategy, (str, Strategy)):
170 |         raise TypeError(f"strategy must be a str, Strategy instance, or None, got {type(strategy)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:170:89: E501 Line too long (102 > 88)
    |
168 |     """Factory to create the correct Ray/Lightning strategy based on strategy name or instance."""
169 |     if strategy is not None and not isinstance(strategy, (str, Strategy)):
170 |         raise TypeError(f"strategy must be a str, Strategy instance, or None, got {type(strategy)!r}")
    |                                                                                         ^^^^^^^^^^^^^^ E501
171 |     if strategy_kwargs is not None and not isinstance(strategy_kwargs, dict):
172 |         raise TypeError(f"strategy_kwargs must be a dict or None, got {type(strategy_kwargs)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:172:89: E501 Line too long (97 > 88)
    |
170 |         raise TypeError(f"strategy must be a str, Strategy instance, or None, got {type(strategy)!r}")
171 |     if strategy_kwargs is not None and not isinstance(strategy_kwargs, dict):
172 |         raise TypeError(f"strategy_kwargs must be a dict or None, got {type(strategy_kwargs)!r}")
    |                                                                                         ^^^^^^^^^ E501
173 |
174 |     if isinstance(strategy, Strategy):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:186:89: E501 Line too long (109 > 88)
    |
184 |         return RayFSDPStrategy(**strategy_kwargs)
185 |     else:
186 |         raise ValueError(f"Unsupported strategy: {strategy!r}; expected 'ddp', 'deepspeed', 'fsdp', or None")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^ E501
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:193:89: E501 Line too long (104 > 88)
    |
191 |     plugins_kwargs: Optional[dict[str, Any]] = None,
192 | ) -> list:
193 |     """Resolve plugins for the Lightning Trainer, always ensuring RayLightningEnvironment is present."""
    |                                                                                         ^^^^^^^^^^^^^^^^ E501
194 |     if plugins is not None and not isinstance(plugins, (str, list, tuple, *_PLUGIN_INPUT.__args__)):
195 |         raise TypeError(f"plugins must be a str import path, a plugin instance, a list of plugin instances, or None; got {type(plugin…
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:194:89: E501 Line too long (100 > 88)
    |
192 | ) -> list:
193 |     """Resolve plugins for the Lightning Trainer, always ensuring RayLightningEnvironment is present."""
194 |     if plugins is not None and not isinstance(plugins, (str, list, tuple, *_PLUGIN_INPUT.__args__)):
    |                                                                                         ^^^^^^^^^^^^ E501
195 |         raise TypeError(f"plugins must be a str import path, a plugin instance, a list of plugin instances, or None; got {type(plugin…
196 |     if plugins_kwargs is not None and not isinstance(plugins_kwargs, dict):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:195:89: E501 Line too long (140 > 88)
    |
193 | …lways ensuring RayLightningEnvironment is present."""
194 | …ns, (str, list, tuple, *_PLUGIN_INPUT.__args__)):
195 | …ort path, a plugin instance, a list of plugin instances, or None; got {type(plugins)!r}")
    |                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
196 | …e(plugins_kwargs, dict):
197 | …dict or None, got {type(plugins_kwargs)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:197:89: E501 Line too long (95 > 88)
    |
195 |         raise TypeError(f"plugins must be a str import path, a plugin instance, a list of plugin instances, or None; got {type(plugin…
196 |     if plugins_kwargs is not None and not isinstance(plugins_kwargs, dict):
197 |         raise TypeError(f"plugins_kwargs must be a dict or None, got {type(plugins_kwargs)!r}")
    |                                                                                         ^^^^^^^ E501
198 |     if plugins_kwargs is not None and not isinstance(plugins, str):
199 |         raise TypeError("plugins_kwargs can only be used when plugins is a str import path")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:199:89: E501 Line too long (92 > 88)
    |
197 |         raise TypeError(f"plugins_kwargs must be a dict or None, got {type(plugins_kwargs)!r}")
198 |     if plugins_kwargs is not None and not isinstance(plugins, str):
199 |         raise TypeError("plugins_kwargs can only be used when plugins is a str import path")
    |                                                                                         ^^^^ E501
200 |
201 |     plugin_kwargs = plugins_kwargs or {}
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:209:89: E501 Line too long (110 > 88)
    |
207 |         plugins_fn = _get_module_attr(plugins)
208 |         plugin_instances = plugins_fn(**plugin_kwargs)
209 |         result = list(plugin_instances) if isinstance(plugin_instances, (list, tuple)) else [plugin_instances]
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^ E501
210 |     elif isinstance(plugins, (list, tuple)):
211 |         result = list(plugins)
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:218:89: E501 Line too long (152 > 88)
    |
216 | …
217 | …
218 | …__ for t in _PLUGIN_INPUT.__args__]}; got invalid types: {[type(p).__name__ for p in invalid]}"
    |                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
219 | …
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:221:89: E501 Line too long (100 > 88)
    |
219 |         )
220 |
221 |     # We always need to use the RayLightningEnvironment plugin for lightning training with Ray Train
    |                                                                                         ^^^^^^^^^^^^ E501
222 |     if not any(isinstance(p, RayLightningEnvironment) for p in result):
223 |         result.append(RayLightningEnvironment())
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:236:89: E501 Line too long (93 > 88)
    |
234 |     """Resolve the logger for the Lightning Trainer."""
235 |     if logger_kwargs is not None and not isinstance(logger_kwargs, dict):
236 |         raise TypeError(f"logger_kwargs must be a dict or None, got {type(logger_kwargs)!r}")
    |                                                                                         ^^^^^ E501
237 |     if logger_kwargs is not None and not isinstance(logger, str):
238 |         raise TypeError("logger_kwargs can only be used when logger is a str import path")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:238:89: E501 Line too long (90 > 88)
    |
236 |         raise TypeError(f"logger_kwargs must be a dict or None, got {type(logger_kwargs)!r}")
237 |     if logger_kwargs is not None and not isinstance(logger, str):
238 |         raise TypeError("logger_kwargs can only be used when logger is a str import path")
    |                                                                                         ^^ E501
239 |     if comet_param is not None and not isinstance(comet_param, dict):
240 |         raise TypeError(f"comet_param must be a dict or None, got {type(comet_param)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:240:89: E501 Line too long (89 > 88)
    |
238 |         raise TypeError("logger_kwargs can only be used when logger is a str import path")
239 |     if comet_param is not None and not isinstance(comet_param, dict):
240 |         raise TypeError(f"comet_param must be a dict or None, got {type(comet_param)!r}")
    |                                                                                         ^ E501
241 |
242 |     if isinstance(logger, bool):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:248:89: E501 Line too long (100 > 88)
    |
246 |     if isinstance(logger, (list, tuple)):
247 |         if any(not isinstance(elem, Logger) for elem in logger):
248 |             raise TypeError(f"All elements of logger list must be Logger instances, got {logger!r}")
    |                                                                                         ^^^^^^^^^^^^ E501
249 |         return list(logger)
250 |     if isinstance(logger, str):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:255:89: E501 Line too long (128 > 88)
    |
253 |         return list(result) if isinstance(result, (list, tuple)) else result
254 |     if logger is not None:
255 |         raise TypeError(f"logger must be a str, bool, Logger instance, list of Logger instances, or None, got {type(logger)!r}")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
256 |     if comet_param and run_id:
257 |         return _get_comet_logger(
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:276:89: E501 Line too long (93 > 88)
    |
274 |     """Build callback list for the Lightning Trainer.
275 |
276 |     A RayTrainReportCallback or RayTrainReportPerNodeCallback is always appended to the list.
    |                                                                                         ^^^^^ E501
277 |     """
278 |     if callbacks is not None and not isinstance(callbacks, (str, Callback, list, tuple)):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:278:89: E501 Line too long (89 > 88)
    |
276 |     A RayTrainReportCallback or RayTrainReportPerNodeCallback is always appended to the list.
277 |     """
278 |     if callbacks is not None and not isinstance(callbacks, (str, Callback, list, tuple)):
    |                                                                                         ^ E501
279 |         raise TypeError(f"callbacks must be a str import path, a Callback instance, a list of Callback instances, or None; got {type(…
280 |     if callback_kwargs is not None and not isinstance(callback_kwargs, dict):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:279:89: E501 Line too long (148 > 88)
    |
277 | …
278 | …ks, (str, Callback, list, tuple)):
279 | …t path, a Callback instance, a list of Callback instances, or None; got {type(callbacks)!r}")
    |                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
280 | …allback_kwargs, dict):
281 | …t or None, got {type(callback_kwargs)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:281:89: E501 Line too long (97 > 88)
    |
279 |         raise TypeError(f"callbacks must be a str import path, a Callback instance, a list of Callback instances, or None; got {type(…
280 |     if callback_kwargs is not None and not isinstance(callback_kwargs, dict):
281 |         raise TypeError(f"callback_kwargs must be a dict or None, got {type(callback_kwargs)!r}")
    |                                                                                         ^^^^^^^^^ E501
282 |     if per_node_callback_kwargs is not None and not isinstance(per_node_callback_kwargs, dict):
283 |         raise TypeError(f"per_node_callback_kwargs must be a dict or None, got {type(per_node_callback_kwargs)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:282:89: E501 Line too long (95 > 88)
    |
280 |     if callback_kwargs is not None and not isinstance(callback_kwargs, dict):
281 |         raise TypeError(f"callback_kwargs must be a dict or None, got {type(callback_kwargs)!r}")
282 |     if per_node_callback_kwargs is not None and not isinstance(per_node_callback_kwargs, dict):
    |                                                                                         ^^^^^^^ E501
283 |         raise TypeError(f"per_node_callback_kwargs must be a dict or None, got {type(per_node_callback_kwargs)!r}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:283:89: E501 Line too long (115 > 88)
    |
281 |         raise TypeError(f"callback_kwargs must be a dict or None, got {type(callback_kwargs)!r}")
282 |     if per_node_callback_kwargs is not None and not isinstance(per_node_callback_kwargs, dict):
283 |         raise TypeError(f"per_node_callback_kwargs must be a dict or None, got {type(per_node_callback_kwargs)!r}")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
284 |
285 |     callback_kwargs = callback_kwargs or {}
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:289:89: E501 Line too long (105 > 88)
    |
288 |     if isinstance(callbacks, str):
289 |         # Import the callable and invoke it — may be a Callback class or a factory returning one or more.
    |                                                                                         ^^^^^^^^^^^^^^^^^ E501
290 |         fn = _get_module_attr(callbacks)
291 |         result = fn(**callback_kwargs)
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:295:89: E501 Line too long (105 > 88)
    |
293 |             for obj in result:
294 |                 if not isinstance(obj, Callback):
295 |                     raise TypeError(f"Expected Callback instances from {callbacks!r}, got {type(obj)!r}")
    |                                                                                         ^^^^^^^^^^^^^^^^^ E501
296 |                 resolved_callbacks.append(obj)
297 |         elif isinstance(result, Callback):
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:300:89: E501 Line too long (131 > 88)
    |
298 |             resolved_callbacks.append(result)
299 |         else:
300 |             raise TypeError(f"Expected a Callback instance or list of Callback instances from {callbacks!r}, got {type(result)!r}")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
301 |     elif isinstance(callbacks, (list, tuple)):
302 |         for obj in callbacks:
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:304:89: E501 Line too long (95 > 88)
    |
302 |         for obj in callbacks:
303 |             if not isinstance(obj, Callback):
304 |                 raise TypeError(f"All callbacks must be Callback instances, got {type(obj)!r}")
    |                                                                                         ^^^^^^^ E501
305 |             resolved_callbacks.append(obj)
306 |     elif callbacks is not None:
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:309:89: E501 Line too long (90 > 88)
    |
307 |         resolved_callbacks.append(callbacks)
308 |
309 |     has_model_checkpoint = any(isinstance(c, ModelCheckpoint) for c in resolved_callbacks)
    |                                                                                         ^^ E501
310 |
311 |     # Always append a callback that calls ray.train.report() to report metrics and checkpoint.
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:311:89: E501 Line too long (94 > 88)
    |
309 |     has_model_checkpoint = any(isinstance(c, ModelCheckpoint) for c in resolved_callbacks)
310 |
311 |     # Always append a callback that calls ray.train.report() to report metrics and checkpoint.
    |                                                                                         ^^^^^^ E501
312 |     # Per-node reporting is required for model-parallel strategies (DeepSpeed ZeRO, FSDP) because
313 |     # each node holds shards of the model and must upload its own checkpoint shard.
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:312:89: E501 Line too long (97 > 88)
    |
311 |     # Always append a callback that calls ray.train.report() to report metrics and checkpoint.
312 |     # Per-node reporting is required for model-parallel strategies (DeepSpeed ZeRO, FSDP) because
    |                                                                                         ^^^^^^^^^ E501
313 |     # each node holds shards of the model and must upload its own checkpoint shard.
314 |     _use_per_node = per_node_callback_kwargs is not None or isinstance(strategy, (RayDeepSpeedStrategy, RayFSDPStrategy))
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:314:89: E501 Line too long (121 > 88)
    |
312 |     # Per-node reporting is required for model-parallel strategies (DeepSpeed ZeRO, FSDP) because
313 |     # each node holds shards of the model and must upload its own checkpoint shard.
314 |     _use_per_node = per_node_callback_kwargs is not None or isinstance(strategy, (RayDeepSpeedStrategy, RayFSDPStrategy))
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
315 |     if _use_per_node:
316 |         per_node_callback_kwargs = per_node_callback_kwargs or {}
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:334:89: E501 Line too long (101 > 88)
    |
332 |     """
333 |     if torch.cuda.is_available():
334 |         print("CUDA is available with torch, training on GPU with CUDA version:", torch.version.cuda)
    |                                                                                         ^^^^^^^^^^^^^ E501
335 |     else:
336 |         print("CUDA is not available with torch, training on CPU.")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:344:89: E501 Line too long (136 > 88)
    |
342 | …
343 | …
344 | …case when the LightningTrainer is used directly without using the tabular_trainer task,
    |                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
345 | …g_trainer_kwargs.max_epochs was set.
346 | …
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:364:89: E501 Line too long (105 > 88)
    |
362 |         batch_size=batch_size,
363 |         collate_fn=collate_fn_to_torch,
364 |         local_shuffle_buffer_size=None if num_shuffle_batches == 0 else num_shuffle_batches * batch_size,
    |                                                                                         ^^^^^^^^^^^^^^^^^ E501
365 |     )
366 |     val_dataloader = val_dataset_shard.iter_torch_batches(
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:383:89: E501 Line too long (89 > 88)
    |
381 |     # =========================================================
382 |     initial_weights_path = train_loop_config.get("initial_weights_path")
383 |     print(f"[init_weights] [Rank {rank}] Initial weights path: {initial_weights_path!r}")
    |                                                                                         ^ E501
384 |     if initial_weights_path:
385 |         if rank == 0:
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:386:89: E501 Line too long (100 > 88)
    |
384 |     if initial_weights_path:
385 |         if rank == 0:
386 |             print(f"[init_weights] [Rank 0] Loading initial weights from: {initial_weights_path!r}")
    |                                                                                         ^^^^^^^^^^^^ E501
387 |             try:
388 |                 _load_weights_from_path(model, initial_weights_path)
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:410:89: E501 Line too long (145 > 88)
    |
408 | …
409 | …
410 | …r_kwargs is ignored; its value is determined by the presence of a ModelCheckpoint callback."
    |                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
411 | …
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:413:89: E501 Line too long (100 > 88)
    |
411 |         )
412 |
413 |     # Convert values from trainer_kwargs to their corresponding arguments for the Lightning Trainer.
    |                                                                                         ^^^^^^^^^^^^ E501
414 |     # We pop the values from trainer_kwargs to avoid passing invalid values to the Lightning Trainer.
415 |     strategy = _resolve_strategy(trainer_kwargs.pop("strategy", None), trainer_kwargs.pop("strategy_kwargs", None))
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:414:89: E501 Line too long (101 > 88)
    |
413 |     # Convert values from trainer_kwargs to their corresponding arguments for the Lightning Trainer.
414 |     # We pop the values from trainer_kwargs to avoid passing invalid values to the Lightning Trainer.
    |                                                                                         ^^^^^^^^^^^^^ E501
415 |     strategy = _resolve_strategy(trainer_kwargs.pop("strategy", None), trainer_kwargs.pop("strategy_kwargs", None))
416 |     plugins = _resolve_plugins(trainer_kwargs.pop("plugins", None), trainer_kwargs.pop("plugins_kwargs", None))
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:415:89: E501 Line too long (115 > 88)
    |
413 |     # Convert values from trainer_kwargs to their corresponding arguments for the Lightning Trainer.
414 |     # We pop the values from trainer_kwargs to avoid passing invalid values to the Lightning Trainer.
415 |     strategy = _resolve_strategy(trainer_kwargs.pop("strategy", None), trainer_kwargs.pop("strategy_kwargs", None))
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
416 |     plugins = _resolve_plugins(trainer_kwargs.pop("plugins", None), trainer_kwargs.pop("plugins_kwargs", None))
417 |     logger = _resolve_logger(
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:416:89: E501 Line too long (111 > 88)
    |
414 |     # We pop the values from trainer_kwargs to avoid passing invalid values to the Lightning Trainer.
415 |     strategy = _resolve_strategy(trainer_kwargs.pop("strategy", None), trainer_kwargs.pop("strategy_kwargs", None))
416 |     plugins = _resolve_plugins(trainer_kwargs.pop("plugins", None), trainer_kwargs.pop("plugins_kwargs", None))
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^ E501
417 |     logger = _resolve_logger(
418 |         trainer_kwargs.pop("logger", None),
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:435:89: E501 Line too long (147 > 88)
    |
433 | …
434 | …
435 | …_checkpoint  # enable_checkpointing must be set to True if a ModelCheckpoint callback is used
    |                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
436 | …
437 | …
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:446:89: E501 Line too long (91 > 88)
    |
445 |     # Download checkpoint locally to support both DDP and DeepSpeed strategies.
446 |     # DDP checkpoints are single files; DeepSpeed ZeRO checkpoints are sharded directories.
    |                                                                                         ^^^ E501
447 |     # Using to_directory() handles both cases by downloading the full checkpoint to a local path.
448 |     ckpt_path = None
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:447:89: E501 Line too long (97 > 88)
    |
445 |     # Download checkpoint locally to support both DDP and DeepSpeed strategies.
446 |     # DDP checkpoints are single files; DeepSpeed ZeRO checkpoints are sharded directories.
447 |     # Using to_directory() handles both cases by downloading the full checkpoint to a local path.
    |                                                                                         ^^^^^^^^^ E501
448 |     ckpt_path = None
449 |     if checkpoint:
    |

michelangelo/lib/trainer/torch/pytorch_lightning/_private/util.py:452:89: E501 Line too long (111 > 88)
    |
450 |         local_ckpt_dir = checkpoint.to_directory()
451 |         ckpt_path = os.path.join(local_ckpt_dir, CHECKPOINT_FILENAME)
452 |     trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, ckpt_path=ckpt_path)
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^ E501
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:1:1: D100 Missing docstring in public module
michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:23:89: E501 Line too long (92 > 88)
   |
21 |     _train_loop_per_worker,
22 | )
23 | from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
   |                                                                                         ^^^^ E501
24 | from contextlib import contextmanager
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:33:7: D101 Missing docstring in public class
   |
32 | @dataclass
33 | class CometParam:
   |       ^^^^^^^^^^ D101
34 |     api_key: str
35 |     project_name: str
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:42:7: D101 Missing docstring in public class
   |
41 | @dataclass
42 | class LightningTrainerParam:
   |       ^^^^^^^^^^^^^^^^^^^^^ D101
43 |     create_model_fn: Callable
44 |     create_model_fn_kwargs: dict
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:48:89: E501 Line too long (97 > 88)
   |
46 |     val_data: ray.data.Dataset
47 |     batch_size: int = 8
48 |     num_shuffle_batches: int = 10  # By default we reserve 10 batches in ray data shuffle buffer.
   |                                                                                         ^^^^^^^^^ E501
49 |     num_epochs: Optional[int] = field(default=_UNSET)  # type: ignore[assignment]  # sentinel replaced in __post_init__
50 |     data_collate_fn: Callable = None
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:58:89: E501 Line too long (113 > 88)
   |
56 |     initial_weights_path: Optional[str] = None
57 |
58 |     # Raise warning if the deprecated num_epochs field is set. We default to 1 epoch for backwards compatibility.
   |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^ E501
59 |     def __post_init__(self):
60 |         if self.num_epochs is _UNSET:
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:59:9: D105 Missing docstring in magic method
   |
58 |     # Raise warning if the deprecated num_epochs field is set. We default to 1 epoch for backwards compatibility.
59 |     def __post_init__(self):
   |         ^^^^^^^^^^^^^ D105
60 |         if self.num_epochs is _UNSET:
61 |             self.num_epochs = 1
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:64:89: E501 Line too long (143 > 88)
   |
62 | …
63 | …
64 | … deprecated. Use LightningTrainerParam.lightning_trainer_kwargs={'max_epochs': N} instead."
   |                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
65 | …
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:68:7: D101 Missing docstring in public class
   |
68 | class LightningTrainer(TorchTrainer):
   |       ^^^^^^^^^^^^^^^^ D101
69 |     def __init__(
70 |         self,
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:69:9: D107 Missing docstring in `__init__`
   |
68 | class LightningTrainer(TorchTrainer):
69 |     def __init__(
   |         ^^^^^^^^ D107
70 |         self,
71 |         trainer_param: LightningTrainerParam,
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:76:89: E501 Line too long (90 > 88)
   |
74 |     ):
75 |         self.trainer_param = trainer_param
76 |         _logger.info("LightningTrainer initialized with trainer_param: %r", trainer_param)
   |                                                                                         ^^ E501
77 |         train_loop_config = asdict(trainer_param)
78 |         # Unique run id for Comet experiment
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:80:89: E501 Line too long (104 > 88)
   |
78 |         # Unique run id for Comet experiment
79 |         train_loop_config["run_id"] = str(uuid.uuid4())
80 |         # Pop out train and val data since we have to pass them into datasets parameter of TorchTrainer.
   |                                                                                         ^^^^^^^^^^^^^^^^ E501
81 |         train_data = train_loop_config.pop("train_data")
82 |         val_data = train_loop_config.pop("val_data")
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:92:9: D102 Missing docstring in public method
   |
90 |         )
91 |
92 |     def train(
   |         ^^^^^ D102
93 |         self,
94 |         run_config: Optional[ray.train.RunConfig] = None,
   |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:106:89: E501 Line too long (110 > 88)
    |
104 |             raise result.error
105 |
106 |         # User-specified LightningModule is saved in config field and cannot be serialized on uniflow for now.
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^ E501
107 |         # We take the config out.
108 |         result.metrics.pop("config", None)
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:109:89: E501 Line too long (102 > 88)
    |
107 |         # We take the config out.
108 |         result.metrics.pop("config", None)
109 |         # Keep the checkpoint object for subclasses that need it (e.g., LightningTrainerWithStateDict)
    |                                                                                         ^^^^^^^^^^^^^^ E501
110 |         self.checkpoint = result.checkpoint
111 |         return {
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:119:5: D200 One-line docstring should fit on one line
    |
118 |   class LightningTrainerWithStateDict(LightningTrainer):
119 | /     """
120 | |     LightningTrainer that provides functions to update model state dict from checkpoint.
121 | |     """
    | |_______^ D200
122 |
123 |       def _is_deepspeed_strategy(self) -> bool:
    |
    = help: Reformat to one line

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:119:5: D212 [*] Multi-line docstring summary should start at the first line
    |
118 |   class LightningTrainerWithStateDict(LightningTrainer):
119 | /     """
120 | |     LightningTrainer that provides functions to update model state dict from checkpoint.
121 | |     """
    | |_______^ D212
122 |
123 |       def _is_deepspeed_strategy(self) -> bool:
    |
    = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:129:89: E501 Line too long (94 > 88)
    |
127 |             return False
128 |
129 |         # DeepSpeed was used if the strategy is "deepspeed" or a RayDeepSpeedStrategy instance
    |                                                                                         ^^^^^^ E501
130 |         if isinstance(strategy, str):
131 |             return strategy.lower() == "deepspeed"
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:134:67: RUF100 [*] Unused `noqa` directive (non-enabled: `PLC0415`)
    |
133 |         try:
134 |             from ray.train.lightning import RayDeepSpeedStrategy  # noqa: PLC0415
    |                                                                   ^^^^^^^^^^^^^^^ RUF100
135 |
136 |             return isinstance(strategy, RayDeepSpeedStrategy)
    |
    = help: Remove unused `noqa` directive

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:141:9: D200 One-line docstring should fit on one line
    |
140 |       def update_model_state_dict(self, torch_model: torch.nn.Module):
141 | /         """
142 | |         Update the model state dict with the local checkpoint.
143 | |         """
    | |___________^ D200
144 |           if not hasattr(self, "checkpoint") or self.checkpoint is None:
145 |               raise ValueError("No checkpoint available. Please call train() first to generate a checkpoint.")
    |
    = help: Reformat to one line

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:141:9: D212 [*] Multi-line docstring summary should start at the first line
    |
140 |       def update_model_state_dict(self, torch_model: torch.nn.Module):
141 | /         """
142 | |         Update the model state dict with the local checkpoint.
143 | |         """
    | |___________^ D212
144 |           if not hasattr(self, "checkpoint") or self.checkpoint is None:
145 |               raise ValueError("No checkpoint available. Please call train() first to generate a checkpoint.")
    |
    = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:145:89: E501 Line too long (108 > 88)
    |
143 |         """
144 |         if not hasattr(self, "checkpoint") or self.checkpoint is None:
145 |             raise ValueError("No checkpoint available. Please call train() first to generate a checkpoint.")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^ E501
146 |         used_deepspeed = self._is_deepspeed_strategy()
147 |         # use the ray checkpoint as_directory() to get the local temp checkpoint directory
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:147:89: E501 Line too long (90 > 88)
    |
145 |             raise ValueError("No checkpoint available. Please call train() first to generate a checkpoint.")
146 |         used_deepspeed = self._is_deepspeed_strategy()
147 |         # use the ray checkpoint as_directory() to get the local temp checkpoint directory
    |                                                                                         ^^ E501
148 |         with self.checkpoint.as_directory() as d:
149 |             _logger.info(f"Saving Ray Checkpoint to local temp Checkpoint directory: {d}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:149:89: E501 Line too long (90 > 88)
    |
147 |         # use the ray checkpoint as_directory() to get the local temp checkpoint directory
148 |         with self.checkpoint.as_directory() as d:
149 |             _logger.info(f"Saving Ray Checkpoint to local temp Checkpoint directory: {d}")
    |                                                                                         ^^ E501
150 |             data_dir_contents = os.listdir(d)
151 |             _logger.info(f"Data directory contents: {data_dir_contents}")
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:155:89: E501 Line too long (97 > 88)
    |
153 |             if used_deepspeed:
154 |                 local_model_path = os.path.join(lightning_ckpt_path, "model.pt")
155 |                 # PyTorch 2.6+ defaults weights_only=True, which rejects arbitrary Python classes
    |                                                                                         ^^^^^^^^^ E501
156 |                 # (LossScaler, DynamicLossScaler, optimizer states, etc.) embedded in DeepSpeed ZeRO
157 |                 # checkpoints. The env var reverts the default for any torch.load call that doesn't
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:156:89: E501 Line too long (100 > 88)
    |
154 |                 local_model_path = os.path.join(lightning_ckpt_path, "model.pt")
155 |                 # PyTorch 2.6+ defaults weights_only=True, which rejects arbitrary Python classes
156 |                 # (LossScaler, DynamicLossScaler, optimizer states, etc.) embedded in DeepSpeed ZeRO
    |                                                                                         ^^^^^^^^^^^^ E501
157 |                 # checkpoints. The env var reverts the default for any torch.load call that doesn't
158 |                 # explicitly pass weights_only, covering both pytorch_lightning and deepspeed internals.
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:157:89: E501 Line too long (99 > 88)
    |
155 |                 # PyTorch 2.6+ defaults weights_only=True, which rejects arbitrary Python classes
156 |                 # (LossScaler, DynamicLossScaler, optimizer states, etc.) embedded in DeepSpeed ZeRO
157 |                 # checkpoints. The env var reverts the default for any torch.load call that doesn't
    |                                                                                         ^^^^^^^^^^^ E501
158 |                 # explicitly pass weights_only, covering both pytorch_lightning and deepspeed internals.
159 |                 # TODO: Remove this once we upgrade to Lightning 2.6+ https://github.com/Lightning-AI/pytorch-lightning/pull/21194
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:158:89: E501 Line too long (104 > 88)
    |
156 |                 # (LossScaler, DynamicLossScaler, optimizer states, etc.) embedded in DeepSpeed ZeRO
157 |                 # checkpoints. The env var reverts the default for any torch.load call that doesn't
158 |                 # explicitly pass weights_only, covering both pytorch_lightning and deepspeed internals.
    |                                                                                         ^^^^^^^^^^^^^^^^ E501
159 |                 # TODO: Remove this once we upgrade to Lightning 2.6+ https://github.com/Lightning-AI/pytorch-lightning/pull/21194
160 |                 with _torch_weights_only_disabled():
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:161:89: E501 Line too long (120 > 88)
    |
159 |                 # TODO: Remove this once we upgrade to Lightning 2.6+ https://github.com/Lightning-AI/pytorch-lightning/pull/21194
160 |                 with _torch_weights_only_disabled():
161 |                     model_state_dict = convert_zero_checkpoint_to_fp32_state_dict(lightning_ckpt_path, local_model_path)
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
162 |                 _logger.info(f"Loaded DeepSpeed checkpoint from {lightning_ckpt_path} to {local_model_path}")
163 |             else:
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:162:89: E501 Line too long (109 > 88)
    |
160 |                 with _torch_weights_only_disabled():
161 |                     model_state_dict = convert_zero_checkpoint_to_fp32_state_dict(lightning_ckpt_path, local_model_path)
162 |                 _logger.info(f"Loaded DeepSpeed checkpoint from {lightning_ckpt_path} to {local_model_path}")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^ E501
163 |             else:
164 |                 # DDP checkpoint
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:169:89: E501 Line too long (90 > 88)
    |
167 |                 _logger.info(f"Loaded DDP checkpoint from {lightning_ckpt_path}")
168 |             torch_model.load_state_dict(model_state_dict, strict=False)
169 |             _logger.info("Updated the state dict of the torch model in the ModelVariable")
    |                                                                                         ^^ E501
    |

michelangelo/lib/trainer/torch/pytorch_lightning/lightning_trainer.py:174:89: E501 Line too long (100 > 88)
    |
172 | @contextmanager
173 | def _torch_weights_only_disabled():
174 |     """Force torch.load() to use weights_only=False for call sites that don't pass it explicitly."""
    |                                                                                         ^^^^^^^^^^^^ E501
175 |     key = "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"
176 |     old = os.environ.pop(key, None)
    |

michelangelo/lib/trainer/torch/pytorch_lightning/schema.py:31:89: E501 Line too long (98 > 88)
   |
29 | @dataclass
30 | class ModelSpec:
31 |     """A reference to a model that may be loaded for incremental training or transfer learning."""
   |                                                                                         ^^^^^^^^^^ E501
32 |
33 |     project_name: str
   |

michelangelo/lib/trainer/torch/pytorch_lightning/schema.py:35:18: UP007 Use `X | Y` for type annotations
   |
33 |     project_name: str
34 |     model_name: str
35 |     revision_id: Optional[str] = None
   |                  ^^^^^^^^^^^^^ UP007
   |
   = help: Convert to `X | Y`

michelangelo/lib/trainer/torch/pytorch_lightning/schema.py:44:22: UP007 Use `X | Y` for type annotations
   |
42 |     training_type: TrainingType
43 |     baseline_model: ModelSpec
44 |     deployment_name: Optional[str] = None
   |                      ^^^^^^^^^^^^^ UP007
45 |     skip_training: bool = False
46 |     log_layer_weights: bool = False
   |
   = help: Convert to `X | Y`

michelangelo/lib/trainer/torch/pytorch_lightning/schema.py:55:42: UP007 Use `X | Y` for type annotations
   |
53 |     metadata: IncrementalTrainingMetadata
54 |     load_optimizer_weights: bool = False
55 |     override_incremental_training_epoch: Optional[int] = None
   |                                          ^^^^^^^^^^^^^ UP007
   |
   = help: Convert to `X | Y`

michelangelo/lib/trainer/torch/pytorch_lightning/schema.py:63:21: UP007 Use `X | Y` for type annotations
   |
62 |     learning_mode: LearningMode
63 |     baseline_model: Optional[ModelSpec]
   |                     ^^^^^^^^^^^^^^^^^^^ UP007
   |
   = help: Convert to `X | Y`

michelangelo/lib/trainer/torch/pytorch_lightning/schema.py:72:28: UP007 Use `X | Y` for type annotations
   |
70 |     metadata: TransferLearningMetadata
71 |
72 |     model_loader_function: Optional[str] = None
   |                            ^^^^^^^^^^^^^ UP007
73 |     layer_names_to_inherit: list[str] = field(default_factory=list)
74 |     layer_names_to_inherit_regex: list[str] = field(default_factory=list)
   |
   = help: Convert to `X | Y`

michelangelo/lib/trainer/torch/utils.py:1:1: D100 Missing docstring in public module
michelangelo/lib/trainer/torch/utils.py:8:89: E501 Line too long (109 > 88)
   |
 8 | def get_total_training_memory_transformers(model: AutoModel, batch_size: int, sequence_length: int) -> float:
   |                                                                                         ^^^^^^^^^^^^^^^^^^^^^ E501
 9 |     """
10 |     Get the total memory (in MB) required for training the model.
   |

michelangelo/lib/trainer/torch/utils.py:9:5: D205 1 blank line required between summary line and description
   |
 8 |   def get_total_training_memory_transformers(model: AutoModel, batch_size: int, sequence_length: int) -> float:
 9 | /     """
10 | |     Get the total memory (in MB) required for training the model.
11 | |     This function is specific to transformers models.
12 | |
13 | |     Reference: https://blog.eleuther.ai/transformer-math/
14 | |     """
   | |_______^ D205
15 |       hidden_size = model.config.hidden_size  # Hidden size for activations
16 |       num_layers = model.config.num_hidden_layers  # Number of layers in the model
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/utils.py:9:5: D212 [*] Multi-line docstring summary should start at the first line
   |
 8 |   def get_total_training_memory_transformers(model: AutoModel, batch_size: int, sequence_length: int) -> float:
 9 | /     """
10 | |     Get the total memory (in MB) required for training the model.
11 | |     This function is specific to transformers models.
12 | |
13 | |     Reference: https://blog.eleuther.ai/transformer-math/
14 | |     """
   | |_______^ D212
15 |       hidden_size = model.config.hidden_size  # Hidden size for activations
16 |       num_layers = model.config.num_hidden_layers  # Number of layers in the model
   |
   = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/utils.py:22:89: E501 Line too long (91 > 88)
   |
20 |     tensor_parallelism = 1  # Number of tensor parallelism
21 |
22 |     bytes_per_parameter = torch.tensor([1]).to(dtype).element_size()  # Bytes per parameter
   |                                                                                         ^^^ E501
23 |
24 |     # Calculating each memory component in MB
   |

michelangelo/lib/trainer/torch/utils.py:33:89: E501 Line too long (102 > 88)
   |
31 |     # 3. Optimizer Memory (assuming two states per parameter for AdamW optimizer)
32 |     # Adam is magic, but it is highly memory inefficient.
33 |     # In addition to requiring you to have a copy of the model parameters and the gradient parameters,
   |                                                                                         ^^^^^^^^^^^^^^ E501
34 |     # you also need to keep an additional three copies of the gradient parameters.
35 |     optimizer_memory = 3 * parameter_memory
   |

michelangelo/lib/trainer/torch/utils.py:43:89: E501 Line too long (115 > 88)
   |
41 |         * sequence_length
42 |         * hidden_size
43 |         * (10 + 24 / tensor_parallelism + 5 * num_atten_heads * sequence_length / hidden_size / tensor_parallelism)
   |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
44 |         / (1024**2)
45 |     )
   |

michelangelo/lib/trainer/torch/utils.py:48:89: E501 Line too long (92 > 88)
   |
47 |     # fp16 uses 2 bytes
48 |     activation_memory_per_layer = bytes_per_parameter / 2 * fp16_activation_memory_per_layer
   |                                                                                         ^^^^ E501
49 |     activation_memory_total = activation_memory_per_layer * num_layers
   |

michelangelo/lib/trainer/torch/utils.py:51:89: E501 Line too long (98 > 88)
   |
49 |     activation_memory_total = activation_memory_per_layer * num_layers
50 |
51 |     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
   |                                                                                         ^^^^^^^^^^ E501
52 |     total_memory = (parameter_memory + activation_memory_total + gradient_memory + optimizer_memory) * 1.2
   |

michelangelo/lib/trainer/torch/utils.py:52:89: E501 Line too long (106 > 88)
   |
51 |     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
52 |     total_memory = (parameter_memory + activation_memory_total + gradient_memory + optimizer_memory) * 1.2
   |                                                                                         ^^^^^^^^^^^^^^^^^^ E501
53 |
54 |     return total_memory
   |

michelangelo/lib/trainer/torch/utils.py:58:5: D103 Missing docstring in public function
   |
57 | # Function to estimate activation memory on non-transformer layers
58 | def estimate_activation_memory_non_transformer(layer_output_dims, batch_size, bytes_per_value):
   |     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ D103
59 |     total_activation_memory_mb = 0
   |

michelangelo/lib/trainer/torch/utils.py:58:89: E501 Line too long (95 > 88)
   |
57 | # Function to estimate activation memory on non-transformer layers
58 | def estimate_activation_memory_non_transformer(layer_output_dims, batch_size, bytes_per_value):
   |                                                                                         ^^^^^^^ E501
59 |     total_activation_memory_mb = 0
   |

michelangelo/lib/trainer/torch/utils.py:71:89: E501 Line too long (107 > 88)
   |
71 | def get_total_training_memory_nn_module(model: torch.nn.Module, batch_size: int, input_size: int) -> float:
   |                                                                                         ^^^^^^^^^^^^^^^^^^^ E501
72 |     """
73 |     Get the total memory (in MB) required for training the model.
   |

michelangelo/lib/trainer/torch/utils.py:72:5: D202 [*] No blank lines allowed after function docstring (found 1)
   |
71 |   def get_total_training_memory_nn_module(model: torch.nn.Module, batch_size: int, input_size: int) -> float:
72 | /     """
73 | |     Get the total memory (in MB) required for training the model.
74 | |     This function is specific to non-transformers models.
75 | |     """
   | |_______^ D202
76 |
77 |       num_parameters = sum(p.numel() for p in model.parameters())
   |
   = help: Remove blank line(s) after function docstring

michelangelo/lib/trainer/torch/utils.py:72:5: D205 1 blank line required between summary line and description
   |
71 |   def get_total_training_memory_nn_module(model: torch.nn.Module, batch_size: int, input_size: int) -> float:
72 | /     """
73 | |     Get the total memory (in MB) required for training the model.
74 | |     This function is specific to non-transformers models.
75 | |     """
   | |_______^ D205
76 |
77 |       num_parameters = sum(p.numel() for p in model.parameters())
   |
   = help: Insert single blank line

michelangelo/lib/trainer/torch/utils.py:72:5: D212 [*] Multi-line docstring summary should start at the first line
   |
71 |   def get_total_training_memory_nn_module(model: torch.nn.Module, batch_size: int, input_size: int) -> float:
72 | /     """
73 | |     Get the total memory (in MB) required for training the model.
74 | |     This function is specific to non-transformers models.
75 | |     """
   | |_______^ D212
76 |
77 |       num_parameters = sum(p.numel() for p in model.parameters())
   |
   = help: Remove whitespace after opening quotes

michelangelo/lib/trainer/torch/utils.py:84:89: E501 Line too long (91 > 88)
   |
82 |         break
83 |
84 |     bytes_per_parameter = torch.tensor([1]).to(dtype).element_size()  # Bytes per parameter
   |                                                                                         ^^^ E501
85 |
86 |     # Calculating each memory component in MB
   |

michelangelo/lib/trainer/torch/utils.py:95:89: E501 Line too long (102 > 88)
   |
93 |     # 3. Optimizer Memory (assuming two states per parameter for AdamW optimizer)
94 |     # Adam is magic, but it is highly memory inefficient.
95 |     # In addition to requiring you to have a copy of the model parameters and the gradient parameters,
   |                                                                                         ^^^^^^^^^^^^^^ E501
96 |     # you also need to keep an additional three copies of the gradient parameters.
97 |     optimizer_memory = 3 * parameter_memory
   |

michelangelo/lib/trainer/torch/utils.py:108:89: E501 Line too long (120 > 88)
    |
106 |     # We only count Linear layers, Conv layers, Norm layers, and RNN layers
107 |     hooks = []
108 |     supported_layer_types = (nn.Linear, nn.modules.conv._ConvNd, nn.modules.batchnorm._NormBase, nn.modules.rnn.RNNBase)
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
109 |
110 |     for layer in model.children():
    |

michelangelo/lib/trainer/torch/utils.py:126:89: E501 Line too long (124 > 88)
    |
125 |     # Use captured output dimensions to estimate activation memory
126 |     total_activation_memory = estimate_activation_memory_non_transformer(layer_output_dims, batch_size, bytes_per_parameter)
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E501
127 |
128 |     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
    |

michelangelo/lib/trainer/torch/utils.py:128:89: E501 Line too long (98 > 88)
    |
126 |     total_activation_memory = estimate_activation_memory_non_transformer(layer_output_dims, batch_size, bytes_per_parameter)
127 |
128 |     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
    |                                                                                         ^^^^^^^^^^ E501
129 |     total_memory_mb = (parameter_memory + total_activation_memory + gradient_memory + optimizer_memory) * 1.2
    |

michelangelo/lib/trainer/torch/utils.py:129:89: E501 Line too long (109 > 88)
    |
128 |     # Summing up and adding 20% for additional buffers and overheads for GPU memory fragmentation.
129 |     total_memory_mb = (parameter_memory + total_activation_memory + gradient_memory + optimizer_memory) * 1.2
    |                                                                                         ^^^^^^^^^^^^^^^^^^^^^ E501
130 |
131 |     return total_memory_mb
    |

Found 199 errors.
[*] 17 fixable with the `--fix` option (13 hidden fixes can be enabled with the `--unsafe-fixes` option).

🛠 Run poetry install -E dev to set up pre-commit hooks and apply fixes locally. (or run poetry run pre-commit before run git commit)

@dkurra dkurra changed the title Snapshot internal Lightning trainer into lib/trainer/torch Lightning Trainer for MA OSS May 21, 2026
- Rewrite test_lightning_trainer.py to match the snapshot API:
  CometParam, LightningTrainerParam, LightningTrainer init/train,
  LightningTrainerWithStateDict strategy detection, and the
  _torch_weights_only_disabled env-var context (24 tests, all pass).

- Trainer source cleanups:
  - Lazy comet_ml import inside _get_comet_logger() so the package can
    be installed without comet_ml.
  - Replace all print() calls with module loggers.
  - Add full Google-style docstrings and type hints across public API.
  - Rename UserException -> UserInputError (ruff N818).
  - Snapshot disclosure on top-level package docstring; package
    __init__ now re-exports CometParam, LightningTrainer,
    LightningTrainerParam, LightningTrainerWithStateDict, and the
    schema dataclasses.

- pyproject:
  - Split trainer extras into trainer / trainer-comet / trainer-deepspeed
    so users pull only what they need.
  - Add per-file-ignore for E501 on the snapshot trainer package to
    preserve upstream line shape.

- Docs:
  - Mention the trainer package and MovieLens example in the top-level
    README.
  - Drop internal Slack URL and personal owner tags from PR/example docs.
@github-actions
Copy link
Copy Markdown

🛠 Ruff Check & Format Results

⚠️ Format Issues Detected

Show details
--- examples/movielens/data.py
+++ examples/movielens/data.py
@@ -23,7 +23,9 @@
 # Canonical source. Some sandboxed environments can't reach files.grouplens.org;
 # fall back to a github mirror of the same u.data when the canonical URL fails.
 _DATA_URL = "https://files.grouplens.org/datasets/movielens/ml-100k.zip"
-_UDATA_MIRROR_URL = "https://raw.githubusercontent.com/vinjn/MLStudy/master/data/movielens-100k/u.data"
+_UDATA_MIRROR_URL = (
+    "https://raw.githubusercontent.com/vinjn/MLStudy/master/data/movielens-100k/u.data"
+)
 _DEFAULT_CACHE_DIR = "/tmp/movielens_data"
 _NETWORK_TIMEOUT_SECONDS = 30
 
@@ -55,17 +57,23 @@
 
     try:
         _logger.info("Downloading MovieLens-100k from %s", _DATA_URL)
-        with urllib.request.urlopen(_DATA_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
+        with urllib.request.urlopen(
+            _DATA_URL, timeout=_NETWORK_TIMEOUT_SECONDS
+        ) as resp:
             data = resp.read()
         with zipfile.ZipFile(io.BytesIO(data)) as z:
             z.extractall(cache_dir)
         _logger.info("Extracted MovieLens-100k to %s", cache_dir)
         return udata_path
     except (urllib.error.URLError, TimeoutError, OSError) as exc:
-        _logger.warning("Canonical URL failed (%s); falling back to %s", exc, _UDATA_MIRROR_URL)
+        _logger.warning(
+            "Canonical URL failed (%s); falling back to %s", exc, _UDATA_MIRROR_URL
+        )
 
     os.makedirs(extracted_dir, exist_ok=True)
-    with urllib.request.urlopen(_UDATA_MIRROR_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
+    with urllib.request.urlopen(
+        _UDATA_MIRROR_URL, timeout=_NETWORK_TIMEOUT_SECONDS
+    ) as resp:
         udata_bytes = resp.read()
     with open(udata_path, "wb") as f:
         f.write(udata_bytes)
@@ -91,15 +99,26 @@
         sep="\t",
         header=None,
         names=["user_id", "item_id", "rating", "timestamp"],
-        dtype={"user_id": np.int64, "item_id": np.int64, "rating": np.int64, "timestamp": np.int64},
+        dtype={
+            "user_id": np.int64,
+            "item_id": np.int64,
+            "rating": np.int64,
+            "timestamp": np.int64,
+        },
     )
     _logger.info("Loaded %d ratings", len(df))
 
-    user_id_to_idx = {uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))}
-    item_id_to_idx = {iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))}
+    user_id_to_idx = {
+        uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))
+    }
+    item_id_to_idx = {
+        iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))
+    }
     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
     df["item_idx"] = df["item_id"].map(item_id_to_idx).astype(np.int64)
-    df["rating_norm"] = ((df["rating"].astype(np.float32) - 1.0) / 4.0).astype(np.float32)
+    df["rating_norm"] = ((df["rating"].astype(np.float32) - 1.0) / 4.0).astype(
+        np.float32
+    )
 
     num_users = len(user_id_to_idx)
     num_items = len(item_id_to_idx)

--- examples/movielens/model.py
+++ examples/movielens/model.py
@@ -52,7 +52,14 @@
         preds = self(user_idx, item_idx)
         loss = F.mse_loss(preds, target)
         # sync_dist=True so the metric is averaged across Ray Train workers.
-        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
+        self.log(
+            f"{stage}_loss",
+            loss,
+            prog_bar=True,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
         return loss
 
     def training_step(self, batch, batch_idx):  # noqa: ARG002

--- examples/movielens/train.py
+++ examples/movielens/train.py
@@ -35,7 +35,9 @@
     LightningTrainerParam,
 )
 
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s"
+)
 log = logging.getLogger("examples.movielens.train")
 
 _STORAGE_DIR = "/tmp/movielens_runs"
@@ -60,7 +62,9 @@
         api_key=api_key,
         workspace=workspace,
         project_name=os.environ.get("COMET_PROJECT_NAME", _DEFAULT_COMET_PROJECT),
-        experiment_name=os.environ.get("COMET_EXPERIMENT_NAME", _DEFAULT_COMET_EXPERIMENT),
+        experiment_name=os.environ.get(
+            "COMET_EXPERIMENT_NAME", _DEFAULT_COMET_EXPERIMENT
+        ),
         tags=tags,
     )
 
@@ -95,7 +99,9 @@
     from pytorch_lightning.loggers import MLFlowLogger  # noqa: PLC0415
 
     return MLFlowLogger(
-        experiment_name=os.environ.get("MLFLOW_EXPERIMENT_NAME", _DEFAULT_MLFLOW_EXPERIMENT),
+        experiment_name=os.environ.get(
+            "MLFLOW_EXPERIMENT_NAME", _DEFAULT_MLFLOW_EXPERIMENT
+        ),
         tracking_uri=tracking_uri,
         run_name=os.environ.get("MLFLOW_RUN_NAME"),
         tags=_parse_mlflow_tags(os.environ.get("MLFLOW_TAGS", "")),
@@ -116,7 +122,9 @@
             comet_param.experiment_name,
         )
         if os.environ.get("MLFLOW_TRACKING_URI"):
-            log.info("MLFLOW_TRACKING_URI is also set but Comet takes precedence; MLflow logging skipped.")
+            log.info(
+                "MLFLOW_TRACKING_URI is also set but Comet takes precedence; MLflow logging skipped."
+            )
     else:
         mlflow_logger = _build_mlflow_logger()
         if mlflow_logger is not None:
@@ -126,7 +134,9 @@
                 mlflow_logger.experiment_id,
             )
         else:
-            log.info("Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)")
+            log.info(
+                "Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)"
+            )
 
     lightning_trainer_kwargs = {
         # Don't pass accelerator/devices here: ray.train.lightning.prepare_trainer

📌 Please run poetry run ruff format . locally to fix formatting issues.

🚨 Lint Issues Detected

Show details
examples/movielens/data.py:58:89: E501 Line too long (89 > 88)
   |
56 |     try:
57 |         _logger.info("Downloading MovieLens-100k from %s", _DATA_URL)
58 |         with urllib.request.urlopen(_DATA_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
   |                                                                                         ^ E501
59 |             data = resp.read()
60 |         with zipfile.ZipFile(io.BytesIO(data)) as z:
   |

examples/movielens/data.py:65:89: E501 Line too long (96 > 88)
   |
63 |         return udata_path
64 |     except (urllib.error.URLError, TimeoutError, OSError) as exc:
65 |         _logger.warning("Canonical URL failed (%s); falling back to %s", exc, _UDATA_MIRROR_URL)
   |                                                                                         ^^^^^^^^ E501
66 |
67 |     os.makedirs(extracted_dir, exist_ok=True)
   |

examples/movielens/data.py:68:89: E501 Line too long (93 > 88)
   |
67 |     os.makedirs(extracted_dir, exist_ok=True)
68 |     with urllib.request.urlopen(_UDATA_MIRROR_URL, timeout=_NETWORK_TIMEOUT_SECONDS) as resp:
   |                                                                                         ^^^^^ E501
69 |         udata_bytes = resp.read()
70 |     with open(udata_path, "wb") as f:
   |

examples/movielens/data.py:94:89: E501 Line too long (100 > 88)
   |
92 |         header=None,
93 |         names=["user_id", "item_id", "rating", "timestamp"],
94 |         dtype={"user_id": np.int64, "item_id": np.int64, "rating": np.int64, "timestamp": np.int64},
   |                                                                                         ^^^^^^^^^^^^ E501
95 |     )
96 |     _logger.info("Loaded %d ratings", len(df))
   |

examples/movielens/data.py:98:89: E501 Line too long (89 > 88)
    |
 96 |     _logger.info("Loaded %d ratings", len(df))
 97 |
 98 |     user_id_to_idx = {uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))}
    |                                                                                         ^ E501
 99 |     item_id_to_idx = {iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))}
100 |     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
    |

examples/movielens/data.py:99:89: E501 Line too long (89 > 88)
    |
 98 |     user_id_to_idx = {uid: idx for idx, uid in enumerate(sorted(df["user_id"].unique()))}
 99 |     item_id_to_idx = {iid: idx for idx, iid in enumerate(sorted(df["item_id"].unique()))}
    |                                                                                         ^ E501
100 |     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
101 |     df["item_idx"] = df["item_id"].map(item_id_to_idx).astype(np.int64)
    |

examples/movielens/data.py:102:89: E501 Line too long (90 > 88)
    |
100 |     df["user_idx"] = df["user_id"].map(user_id_to_idx).astype(np.int64)
101 |     df["item_idx"] = df["item_id"].map(item_id_to_idx).astype(np.int64)
102 |     df["rating_norm"] = ((df["rating"].astype(np.float32) - 1.0) / 4.0).astype(np.float32)
    |                                                                                         ^^ E501
103 |
104 |     num_users = len(user_id_to_idx)
    |

examples/movielens/model.py:8:8: N812 Lowercase `functional` imported as non-lowercase `F`
  |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
  |        ^^^^^^^^^^^^^^^^^^^^^^^^ N812
  |

examples/movielens/model.py:19:9: D107 Missing docstring in `__init__`
   |
17 |     """
18 |
19 |     def __init__(
   |         ^^^^^^^^ D107
20 |         self,
21 |         num_users: int,
   |

examples/movielens/model.py:41:9: D102 Missing docstring in public method
   |
39 |         nn.init.normal_(self.item_emb.weight, std=0.01)
40 |
41 |     def forward(self, user_idx: torch.Tensor, item_idx: torch.Tensor) -> torch.Tensor:
   |         ^^^^^^^ D102
42 |         u = self.user_emb(user_idx)
43 |         i = self.item_emb(item_idx)
   |

examples/movielens/model.py:55:89: E501 Line too long (100 > 88)
   |
53 |         loss = F.mse_loss(preds, target)
54 |         # sync_dist=True so the metric is averaged across Ray Train workers.
55 |         self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)
   |                                                                                         ^^^^^^^^^^^^ E501
56 |         return loss
   |

examples/movielens/model.py:58:9: D102 Missing docstring in public method
   |
56 |         return loss
57 |
58 |     def training_step(self, batch, batch_idx):  # noqa: ARG002
   |         ^^^^^^^^^^^^^ D102
59 |         return self._step(batch, "train")
   |

examples/movielens/model.py:58:49: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
56 |         return loss
57 |
58 |     def training_step(self, batch, batch_idx):  # noqa: ARG002
   |                                                 ^^^^^^^^^^^^^^ RUF100
59 |         return self._step(batch, "train")
   |
   = help: Remove unused `noqa` directive

examples/movielens/model.py:61:9: D102 Missing docstring in public method
   |
59 |         return self._step(batch, "train")
60 |
61 |     def validation_step(self, batch, batch_idx):  # noqa: ARG002
   |         ^^^^^^^^^^^^^^^ D102
62 |         return self._step(batch, "val")
   |

examples/movielens/model.py:61:51: RUF100 [*] Unused `noqa` directive (non-enabled: `ARG002`)
   |
59 |         return self._step(batch, "train")
60 |
61 |     def validation_step(self, batch, batch_idx):  # noqa: ARG002
   |                                                   ^^^^^^^^^^^^^^ RUF100
62 |         return self._step(batch, "val")
   |
   = help: Remove unused `noqa` directive

examples/movielens/model.py:64:9: D102 Missing docstring in public method
   |
62 |         return self._step(batch, "val")
63 |
64 |     def configure_optimizers(self):
   |         ^^^^^^^^^^^^^^^^^^^^ D102
65 |         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
   |

examples/movielens/train.py:38:89: E501 Line too long (97 > 88)
   |
36 | )
37 |
38 | logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
   |                                                                                         ^^^^^^^^^ E501
39 | log = logging.getLogger("examples.movielens.train")
   |

examples/movielens/train.py:47:29: UP007 Use `X | Y` for type annotations
   |
47 | def _build_comet_param() -> Optional[CometParam]:
   |                             ^^^^^^^^^^^^^^^^^^^^ UP007
48 |     """Build a CometParam from env vars, or return None to skip Comet logging.
   |
   = help: Convert to `X | Y`

examples/movielens/train.py:63:89: E501 Line too long (91 > 88)
   |
61 |         workspace=workspace,
62 |         project_name=os.environ.get("COMET_PROJECT_NAME", _DEFAULT_COMET_PROJECT),
63 |         experiment_name=os.environ.get("COMET_EXPERIMENT_NAME", _DEFAULT_COMET_EXPERIMENT),
   |                                                                                         ^^^ E501
64 |         tags=tags,
65 |     )
   |

examples/movielens/train.py:68:42: UP007 Use `X | Y` for type annotations
   |
68 | def _parse_mlflow_tags(tags_env: str) -> Optional[dict]:
   |                                          ^^^^^^^^^^^^^^ UP007
69 |     """Parse ``key1=val1,key2=val2`` into a dict; return None if empty/malformed."""
70 |     if not tags_env.strip():
   |
   = help: Convert to `X | Y`

examples/movielens/train.py:95:57: RUF100 [*] Unused `noqa` directive (non-enabled: `PLC0415`)
   |
93 |         return None
94 |     # Import lazily so an unused MLflow path doesn't force the dependency.
95 |     from pytorch_lightning.loggers import MLFlowLogger  # noqa: PLC0415
   |                                                         ^^^^^^^^^^^^^^^ RUF100
96 |
97 |     return MLFlowLogger(
   |
   = help: Remove unused `noqa` directive

examples/movielens/train.py:98:89: E501 Line too long (93 > 88)
    |
 97 |     return MLFlowLogger(
 98 |         experiment_name=os.environ.get("MLFLOW_EXPERIMENT_NAME", _DEFAULT_MLFLOW_EXPERIMENT),
    |                                                                                         ^^^^^ E501
 99 |         tracking_uri=tracking_uri,
100 |         run_name=os.environ.get("MLFLOW_RUN_NAME"),
    |

examples/movielens/train.py:105:5: D103 Missing docstring in public function
    |
105 | def main() -> dict:
    |     ^^^^ D103
106 |     splits = load_movielens_100k()
    |

examples/movielens/train.py:119:89: E501 Line too long (107 > 88)
    |
117 |         )
118 |         if os.environ.get("MLFLOW_TRACKING_URI"):
119 |             log.info("MLFLOW_TRACKING_URI is also set but Comet takes precedence; MLflow logging skipped.")
    |                                                                                         ^^^^^^^^^^^^^^^^^^^ E501
120 |     else:
121 |         mlflow_logger = _build_mlflow_logger()
    |

examples/movielens/train.py:129:89: E501 Line too long (101 > 88)
    |
127 |             )
128 |         else:
129 |             log.info("Experiment tracking disabled (no COMET_* or MLFLOW_TRACKING_URI env vars set)")
    |                                                                                         ^^^^^^^^^^^^^ E501
130 |
131 |     lightning_trainer_kwargs = {
    |

Found 25 errors.
[*] 3 fixable with the `--fix` option (2 hidden fixes can be enabled with the `--unsafe-fixes` option).

🛠 Run poetry install -E dev to set up pre-commit hooks and apply fixes locally. (or run poetry run pre-commit before run git commit)

@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 21, 2026

Coverage report

Click to see where and how coverage changed

FileStatementsMissingCoverageCoverage
(new stmts)
Lines missing
  python/michelangelo/lib/trainer/torch
  _numpy_utils.py 50, 77, 82, 86, 106, 112, 148, 172-173
  data_collate_functions.py 40, 45-47, 91, 101, 105, 148, 176
  python/michelangelo/lib/trainer/torch/pytorch_lightning
  __init__.py
  lightning_trainer.py
  schema.py
Project Total  

This report was generated by python-coverage-comment-action

- examples/movielens: drop Optional[X] for X | None, line-wrap long
  strings/calls, add docstrings, replace `import nn as F` alias.
- pyproject coverage.run: omit lib/trainer/**/_private/* — those modules
  run inside Ray Train worker processes and are exercised by integration
  tests, not unit tests; the _private/ convention already marks them as
  not-for-direct-import.
- Add DDP and DeepSpeed path tests for
  LightningTrainerWithStateDict.update_model_state_dict so the public
  lightning_trainer.py clears the 90% diff-coverage threshold.
@github-actions
Copy link
Copy Markdown

🛠 Ruff Check & Format Results

🚨 Lint Issues Detected

Show details
tests/michelangelo/lib/trainer/torch/pytorch_lightning/test_lightning_trainer.py:355:89: E501 Line too long (97 > 88)
    |
354 |     def test_update_state_dict_ddp_path(self, tmp_path):
355 |         """DDP path: ``torch.load`` is called and its ``state_dict`` is loaded into the model."""
    |                                                                                         ^^^^^^^^^ E501
356 |         trainer = self._build()  # no strategy → DDP path
    |

tests/michelangelo/lib/trainer/torch/pytorch_lightning/test_lightning_trainer.py:359:91: RUF100 [*] Unused `noqa` directive (non-enabled: `PLC0415`)
    |
358 |         # Build a fake checkpoint directory containing a CHECKPOINT_NAME file.
359 |         from michelangelo.lib.trainer.torch.pytorch_lightning.lightning_trainer import (  # noqa: PLC0415
    |                                                                                           ^^^^^^^^^^^^^^^ RUF100
360 |             CHECKPOINT_NAME,
361 |         )
    |
    = help: Remove unused `noqa` directive

tests/michelangelo/lib/trainer/torch/pytorch_lightning/test_lightning_trainer.py:385:89: E501 Line too long (90 > 88)
    |
384 |     def test_update_state_dict_deepspeed_path(self, tmp_path):
385 |         """DeepSpeed path: ZeRO-conversion helper is called inside the env-var context."""
    |                                                                                         ^^ E501
386 |         trainer = self._build(lightning_trainer_kwargs={"strategy": "deepspeed"})
    |

tests/michelangelo/lib/trainer/torch/pytorch_lightning/test_lightning_trainer.py:388:91: RUF100 [*] Unused `noqa` directive (non-enabled: `PLC0415`)
    |
386 |         trainer = self._build(lightning_trainer_kwargs={"strategy": "deepspeed"})
387 |
388 |         from michelangelo.lib.trainer.torch.pytorch_lightning.lightning_trainer import (  # noqa: PLC0415
    |                                                                                           ^^^^^^^^^^^^^^^ RUF100
389 |             CHECKPOINT_NAME,
390 |         )
    |
    = help: Remove unused `noqa` directive

Found 4 errors.
[*] 2 fixable with the `--fix` option.

🛠 Run poetry install -E dev to set up pre-commit hooks and apply fixes locally. (or run poetry run pre-commit before run git commit)

dkurra added 2 commits May 21, 2026 20:27
The per-file-ignore for E501 covers the trainer source but not the
tests directory, so trim the two new docstrings under 88 chars and
drop the now-unneeded `# noqa: PLC0415` markers.
Closes the test-coverage gaps the OSS review team called out:

- tests/.../test_numpy_utils.py — 25 tests for sentinel_for_numpy_dtype
  (floats/ints/unicode/object/bytes/bool + unsupported raises),
  infer_dtype recursion, and pad_ragged_tensor across 1-D/2-D rags
  and float/int sentinel injection.
- tests/.../test_data_collate_functions.py — 33 tests for the
  structural helpers, pad_ragged_lists (1-D / 2-D / 3-D / explicit
  sentinel), the per-column collate functions, the batch path,
  and the LiteralEvalFloat32Collate OO wrapper.
- tests/.../test_resolve_helpers.py — 38 tests for _resolve_strategy
  / _resolve_plugins / _resolve_logger / _resolve_callbacks covering
  string + instance + None inputs, type-check error paths, the
  default Ray report callback append, and DeepSpeed/FSDP per-node
  callback routing. DeepSpeed strategy + Ray report callbacks are
  patched at the resolver-module level so the tests do not need a
  GPU driver or an active Ray Train session.
- examples/movielens/README.md — document the new trainer extras
  split: trainer / trainer-comet / trainer-deepspeed; add the
  trainer-comet install hint to the Comet section.

122 trainer tests now pass locally.
@dkurra dkurra marked this pull request as ready for review May 21, 2026 23:15
@dkurra dkurra merged commit c3845df into main May 22, 2026
7 checks passed
@dkurra dkurra deleted the trainer-snapshot-from-internal branch May 22, 2026 06:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants