Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxdiffusion/configs/ltx2_3_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ profiler_steps: 5

replicate_vae: False

run_text_encoder_on_tpu: False
# Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes.
enable_dynamic_vae_sharding: True

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
Expand Down
2 changes: 1 addition & 1 deletion src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ profiler_steps: 5
replicate_vae: False
use_bwe: False

run_text_encoder_on_tpu: True
run_text_encoder_on_tpu: False
# Dynamically disables VAE slicing and distributes the batch dimension to avoid HBM OOM for larger batch sizes.
enable_dynamic_vae_sharding: True
allow_split_physical_axes: False
Expand Down
9 changes: 5 additions & 4 deletions src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from torchax import default_env
from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder
from maxdiffusion.maxdiffusion_utils import get_dummy_ltx2_inputs
import contextlib
import flax
Expand Down Expand Up @@ -352,7 +350,10 @@ def load_text_encoder(cls, config: HyperParameters):
)
text_encoder.eval()

if getattr(config, "run_text_encoder_on_tpu", True):
if getattr(config, "run_text_encoder_on_tpu", False):
from torchax import default_env
from maxdiffusion.models.ltx2.text_encoders.torchax_text_encoder import TorchaxGemma3TextEncoder

with default_env():
text_encoder = text_encoder.to("jax")
text_encoder = TorchaxGemma3TextEncoder(text_encoder)
Expand Down Expand Up @@ -855,7 +856,7 @@ def _get_gemma_prompt_embeds(
prompt = [p.strip() for p in prompt]

if self.text_encoder is not None:
run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", True) if hasattr(self, "config") else True
run_text_encoder_on_tpu = getattr(self.config, "run_text_encoder_on_tpu", False) if hasattr(self, "config") else False
if run_text_encoder_on_tpu:
# Torchax Text Encoder
text_inputs = self.tokenizer(
Expand Down
Loading