
Commit f2264e8

try manually writing logic that kinda makes sense:
1 parent 322d03d commit f2264e8

File tree

1 file changed: +27, -7 lines

src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 27 additions & 7 deletions
@@ -749,6 +749,7 @@ def __call__(
         max_sequence_length: int = 256,
         temporal_tile_size: int = 80,
         temporal_overlap: int = 24,
+        temporal_overlap_cond_strength: float = 0.5,
         horizontal_tiles: int = 1,
         vertical_tiles: int = 1,
         spatial_overlap: int = 1,
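
The only signature change is a new temporal_overlap_cond_strength knob on __call__. In this commit it is referenced only from a commented-out assignment in the next hunk, so the sketch below is a hypothetical call site, assuming the usual diffusers pipeline interface; the pipeline instance, prompt, and output handling are illustrative, not taken from this commit.

    # Hypothetical usage; `pipe` is an assumed instance of the pipeline in this file.
    video = pipe(
        prompt="a drone shot over a coastline",  # illustrative prompt
        temporal_tile_size=80,                   # default tile length, per the signature above
        temporal_overlap=24,                     # latent frames shared between consecutive tiles
        temporal_overlap_cond_strength=0.5,      # new knob; only referenced in a commented-out line
    ).frames[0]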
@@ -977,12 +978,21 @@ def __call__(
                 last_latent_tile_num_frames = last_latent_chunk.shape[2]
                 latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
                 total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+                last_latent_chunk = self._pack_latents(
+                    last_latent_chunk,
+                    self.transformer_spatial_patch_size,
+                    self.transformer_temporal_patch_size,
+                )
+                last_latent_chunk_num_tokens = last_latent_chunk.shape[1]
+                if self.do_classifier_free_guidance:
+                    last_latent_chunk = torch.cat([last_latent_chunk, last_latent_chunk], dim=0)

                 conditioning_mask = torch.zeros(
                     (batch_size, total_latent_num_frames),
                     dtype=torch.float32,
                     device=device,
                 )
+                # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
                 conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
             else:
                 total_latent_num_frames = latent_tile_num_frames
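
The overlap chunk is now packed into the transformer's token layout up front, so that last_latent_chunk_num_tokens can later address the conditioning region of the packed sequence. A minimal sketch of the patchify step, modeled on _pack_latents from the LTX pipelines (the shapes below are assumptions for illustration):

    import torch

    def pack_latents(latents: torch.Tensor, patch_size: int, patch_size_t: int) -> torch.Tensor:
        # [B, C, F, H, W] -> [B, (F/pt)*(H/p)*(W/p), C*pt*p*p]: one token per spatio-temporal patch.
        b, c, f, h, w = latents.shape
        latents = latents.reshape(
            b, c, f // patch_size_t, patch_size_t, h // patch_size, patch_size, w // patch_size, patch_size
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7)  # token dims first, patch contents last
        return latents.flatten(4, 7).flatten(1, 3)

    # Hypothetical overlap chunk: 128 latent channels, 3 latent frames, 16x16 spatial latents.
    overlap = torch.randn(1, 128, 3, 16, 16)
    packed = pack_latents(overlap, patch_size=1, patch_size_t=1)  # LTX patch sizes are 1
    print(packed.shape)  # torch.Size([1, 768, 128]); 3*16*16 = 768 -> last_latent_chunk_num_tokens

Under classifier-free guidance the packed chunk is duplicated along the batch dimension so it lines up with the cond/uncond model input.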
@@ -1041,9 +1051,17 @@ def __call__(
                     torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
                 )
                 latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                if start_index > 0:
-                    timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                # Create timestep tensor that has prod(latent_model_input.shape) elements
+                if start_index == 0:
+                    timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
+                else:
+                    timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
+                    timestep[:, :last_latent_chunk_num_tokens] = 0.0
+
+                timestep = timestep.float()
+                # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                # if start_index > 0:
+                #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

                 with self.transformer.cache_context("cond_uncond"):
                     noise_pred = self.transformer(
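
The timestep handling is rewritten: the old code clamped a per-sample scalar timestep against the soft conditioning mask, while the new code builds an explicit per-token timestep grid and pins the overlap tokens to 0, declaring them fully denoised from the transformer's point of view. A self-contained sketch of the new else branch, with assumed shapes carried over from the packing example:

    import torch

    t = torch.tensor(800.0)                         # current scheduler timestep
    latent_model_input = torch.randn(2, 2048, 128)  # cond + uncond batch of packed tokens
    last_latent_chunk_num_tokens = 768              # overlap tokens at the front of the sequence

    # One timestep per token: every token sees t ...
    timestep = t.view(1, 1).expand(latent_model_input.shape[:-1]).clone()  # [2, 2048]
    # ... except the overlap tokens, which are treated as already clean.
    timestep[:, :last_latent_chunk_num_tokens] = 0.0
    timestep = timestep.float()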
@@ -1075,8 +1093,11 @@ def __call__(
                 if start_index == 0:
                     latent_chunk = denoised_latent_chunk
                 else:
-                    tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                    latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
+                    latent_chunk = torch.cat(
+                        [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
+                    )
+                    # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                    # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)

                 if callback_on_step_end is not None:
                     callback_kwargs = {}
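
Latent recomposition changes accordingly: instead of blending through a threshold mask derived from the conditioning mask, the clean packed overlap tokens simply overwrite the front of the scheduler output on every step, so no noise can leak into the conditioning region. A sketch with the same assumed shapes as above:

    import torch

    n_cond = 768
    last_latent_chunk = torch.randn(1, n_cond, 128)    # packed overlap tokens, never denoised
    denoised_latent_chunk = torch.randn(1, 2048, 128)  # scheduler output for the whole tile

    # Keep the conditioning tokens verbatim; keep only the freshly denoised tail.
    latent_chunk = torch.cat(
        [last_latent_chunk, denoised_latent_chunk[:, n_cond:]], dim=1
    )
    assert latent_chunk.shape == denoised_latent_chunk.shape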
@@ -1108,8 +1129,7 @@ def __call__(
                 if start_index == 0:
                     first_tile_out_latents = latent_chunk.clone()
                 else:
-                    # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
-                    latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]
+                    latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
                     latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
                         latent_chunk, first_tile_out_latents, adain_factor
                     )
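
Finally, the stitching slice shifts by one latent frame: the old code dropped the overlap frames plus one extra leading frame, while the new code keeps the frame immediately after the overlap and trims the trailing frame instead. Both slices keep the same number of frames, just a window moved one frame earlier; a worked example with assumed sizes:

    import torch

    last_latent_tile_num_frames = 3
    latent_chunk = torch.randn(1, 128, 13, 16, 16)  # 13 latent frames in this tile

    old = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]  # keeps frames 4..12
    new = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]     # keeps frames 3..11
    assert old.shape == new.shape == (1, 128, 9, 16, 16)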
