@@ -974,9 +974,11 @@ def __call__(
974
974
latent_tile_num_frames = latent_chunk .shape [2 ]
975
975
976
976
if start_index > 0 :
977
- last_latent_chunk = self ._select_latents (tile_out_latents , - temporal_overlap , - 1 )
977
+ # last_latent_chunk = self._select_latents(tile_out_latents, -temporal_overlap, -1)
978
+ last_latent_chunk = self ._select_latents (tile_out_latents , 0 , temporal_overlap - 1 )
979
+ last_latent_chunk = torch .flip (last_latent_chunk , dims = [2 ])
978
980
last_latent_tile_num_frames = last_latent_chunk .shape [2 ]
979
- latent_chunk = torch .cat ([last_latent_chunk , latent_chunk ], dim = 2 )
981
+ latent_chunk = torch .cat ([latent_chunk , last_latent_chunk ], dim = 2 )
980
982
total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
981
983
last_latent_chunk = self ._pack_latents (
982
984
last_latent_chunk ,
@@ -993,7 +995,9 @@ def __call__(
993
995
device = device ,
994
996
)
995
997
# conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
996
- conditioning_mask [:, :last_latent_tile_num_frames ] = 1.0
998
+ # conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
999
+ conditioning_mask [:, - last_latent_tile_num_frames :] = temporal_overlap_cond_strength
1000
+ # conditioning_mask[:, -last_latent_tile_num_frames:] = 1.0
997
1001
else :
998
1002
total_latent_num_frames = latent_tile_num_frames
999
1003
@@ -1051,14 +1055,14 @@ def __call__(
1051
1055
torch .cat ([latent_chunk ] * 2 ) if self .do_classifier_free_guidance else latent_chunk
1052
1056
)
1053
1057
latent_model_input = latent_model_input .to (prompt_embeds .dtype )
1054
- # Create timestep tensor that has prod(latent_model_input.shape) elements
1058
+
1055
1059
if start_index == 0 :
1056
1060
timestep = t .expand (latent_model_input .shape [0 ]).unsqueeze (- 1 )
1057
1061
else :
1058
1062
timestep = t .view (1 , 1 ).expand ((latent_model_input .shape [:- 1 ])).clone ()
1059
- timestep [:, :last_latent_chunk_num_tokens ] = 0.0
1060
-
1063
+ timestep [:, - last_latent_chunk_num_tokens :] = 0.0
1061
1064
timestep = timestep .float ()
1065
+
1062
1066
# timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
1063
1067
# if start_index > 0:
1064
1068
# timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
@@ -1094,7 +1098,8 @@ def __call__(
1094
1098
latent_chunk = denoised_latent_chunk
1095
1099
else :
1096
1100
latent_chunk = torch .cat (
1097
- [last_latent_chunk , denoised_latent_chunk [:, last_latent_chunk_num_tokens :]], dim = 1
1101
+ [denoised_latent_chunk [:, :- last_latent_chunk_num_tokens ], last_latent_chunk ],
1102
+ dim = 1 ,
1098
1103
)
1099
1104
# tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
1100
1105
# latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
@@ -1129,7 +1134,7 @@ def __call__(
1129
1134
if start_index == 0 :
1130
1135
first_tile_out_latents = latent_chunk .clone ()
1131
1136
else :
1132
- latent_chunk = latent_chunk [:, :, last_latent_tile_num_frames : - 1 , :, :]
1137
+ latent_chunk = latent_chunk [:, :, 1 : - last_latent_tile_num_frames , :, :]
1133
1138
latent_chunk = LTXLatentUpsamplePipeline .adain_filter_latent (
1134
1139
latent_chunk , first_tile_out_latents , adain_factor
1135
1140
)
@@ -1140,10 +1145,10 @@ def __call__(
1140
1145
# Combine samples
1141
1146
t_minus_one = temporal_overlap - 1
1142
1147
parts = [
1143
- tile_out_latents [:, :, :- t_minus_one ],
1144
- alpha * tile_out_latents [:, :, - t_minus_one :]
1145
- + ( 1 - alpha ) * latent_chunk [:, :, :t_minus_one ],
1146
- latent_chunk [:, :, t_minus_one :],
1148
+ latent_chunk [:, :, :- t_minus_one ],
1149
+ ( 1 - alpha ) * latent_chunk [:, :, - t_minus_one :]
1150
+ + alpha * tile_out_latents [:, :, :t_minus_one ],
1151
+ tile_out_latents [:, :, t_minus_one :],
1147
1152
]
1148
1153
latent_chunk = torch .cat (parts , dim = 2 )
1149
1154
@@ -1152,7 +1157,7 @@ def __call__(
1152
1157
tile_weights = self ._create_spatial_weights (
1153
1158
tile_out_latents , v , h , horizontal_tiles , vertical_tiles , spatial_overlap
1154
1159
)
1155
- final_latents [:, :, :, v_start :v_end , h_start :h_end ] += latent_chunk * tile_weights
1160
+ final_latents [:, :, :, v_start :v_end , h_start :h_end ] += tile_out_latents * tile_weights
1156
1161
weights [:, :, :, v_start :v_end , h_start :h_end ] += tile_weights
1157
1162
1158
1163
eps = 1e-8
0 commit comments