Skip to content

Commit 0e97669

Browse files
committed
Try generating in reverse, matching what appears to be done in the original codebase.
1 parent f2264e8 commit 0e97669

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

src/diffusers/pipelines/ltx/pipeline_ltx_condition_infinite.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -974,9 +974,11 @@ def __call__(
974974
latent_tile_num_frames = latent_chunk.shape[2]
975975

976976
if start_index > 0:
977-
last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
977+
# last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
978+
last_latent_chunk = self._select_latents(tile_out_latents, 0, temporal_overlap - 1)
979+
last_latent_chunk = torch.flip(last_latent_chunk, dims=[2])
978980
last_latent_tile_num_frames = last_latent_chunk.shape[2]
979-
latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
981+
latent_chunk = torch.cat([latent_chunk, last_latent_chunk], dim=2)
980982
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
981983
last_latent_chunk = self._pack_latents(
982984
last_latent_chunk,
@@ -993,7 +995,9 @@ def __call__(
993995
device=device,
994996
)
995997
# conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
996-
conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
998+
# conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
999+
conditioning_mask[:, -last_latent_tile_num_frames:] = temporal_overlap_cond_strength
1000+
# conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0
9971001
else:
9981002
total_latent_num_frames = latent_tile_num_frames
9991003

@@ -1051,14 +1055,14 @@ def __call__(
10511055
torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
10521056
)
10531057
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
1054-
# Create timestep tensor that has prod(latent_model_input.shape) elements
1058+
10551059
if start_index == 0:
10561060
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
10571061
else:
10581062
timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
1059-
timestep[:, :last_latent_chunk_num_tokens] = 0.0
1060-
1063+
timestep[:, -last_latent_chunk_num_tokens:] = 0.0
10611064
timestep = timestep.float()
1065+
10621066
# timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
10631067
# if start_index > 0:
10641068
# timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
@@ -1094,7 +1098,8 @@ def __call__(
10941098
latent_chunk = denoised_latent_chunk
10951099
else:
10961100
latent_chunk = torch.cat(
1097-
[last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
1101+
[denoised_latent_chunk[:, :-last_latent_chunk_num_tokens], last_latent_chunk],
1102+
dim=1,
10981103
)
10991104
# tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
11001105
# latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
@@ -1129,7 +1134,7 @@ def __call__(
11291134
if start_index == 0:
11301135
first_tile_out_latents = latent_chunk.clone()
11311136
else:
1132-
latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
1137+
latent_chunk = latent_chunk[:, :, 1:-last_latent_tile_num_frames, :, :]
11331138
latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
11341139
latent_chunk, first_tile_out_latents, adain_factor
11351140
)
@@ -1140,10 +1145,10 @@ def __call__(
11401145
# Combine samples
11411146
t_minus_one = temporal_overlap - 1
11421147
parts = [
1143-
tile_out_latents[:, :, :-t_minus_one],
1144-
alpha * tile_out_latents[:, :, -t_minus_one:]
1145-
+ (1 - alpha) * latent_chunk[:, :, :t_minus_one],
1146-
latent_chunk[:, :, t_minus_one:],
1148+
latent_chunk[:, :, :-t_minus_one],
1149+
(1 - alpha) * latent_chunk[:, :, -t_minus_one:]
1150+
+ alpha * tile_out_latents[:, :, :t_minus_one],
1151+
tile_out_latents[:, :, t_minus_one:],
11471152
]
11481153
latent_chunk = torch.cat(parts, dim=2)
11491154

@@ -1152,7 +1157,7 @@ def __call__(
11521157
tile_weights = self._create_spatial_weights(
11531158
tile_out_latents, v, h, horizontal_tiles, vertical_tiles, spatial_overlap
11541159
)
1155-
final_latents[:, :, :, v_start:v_end, h_start:h_end] += latent_chunk * tile_weights
1160+
final_latents[:, :, :, v_start:v_end, h_start:h_end] += tile_out_latents * tile_weights
11561161
weights[:, :, :, v_start:v_end, h_start:h_end] += tile_weights
11571162

11581163
eps = 1e-8

0 commit comments

Comments (0)