@@ -749,6 +749,7 @@ def __call__(
         max_sequence_length: int = 256,
         temporal_tile_size: int = 80,
         temporal_overlap: int = 24,
+        temporal_overlap_cond_strength: float = 0.5,
         horizontal_tiles: int = 1,
         vertical_tiles: int = 1,
         spatial_overlap: int = 1,
@@ -977,12 +978,21 @@ def __call__(
                 last_latent_tile_num_frames = last_latent_chunk.shape[2]
                 latent_chunk = torch.cat([last_latent_chunk, latent_chunk], dim=2)
                 total_latent_num_frames = last_latent_tile_num_frames + latent_tile_num_frames
+                last_latent_chunk = self._pack_latents(
+                    last_latent_chunk,
+                    self.transformer_spatial_patch_size,
+                    self.transformer_temporal_patch_size,
+                )
+                last_latent_chunk_num_tokens = last_latent_chunk.shape[1]
+                if self.do_classifier_free_guidance:
+                    last_latent_chunk = torch.cat([last_latent_chunk, last_latent_chunk], dim=0)

                 conditioning_mask = torch.zeros(
                     (batch_size, total_latent_num_frames),
                     dtype=torch.float32,
                     device=device,
                 )
+                # conditioning_mask[:, :last_latent_tile_num_frames] = temporal_overlap_cond_strength
                 conditioning_mask[:, :last_latent_tile_num_frames] = 1.0
             else:
                 total_latent_num_frames = latent_tile_num_frames
@@ -1041,9 +1051,17 @@ def __call__(
                         torch.cat([latent_chunk] * 2) if self.do_classifier_free_guidance else latent_chunk
                     )
                     latent_model_input = latent_model_input.to(prompt_embeds.dtype)
-                    timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
-                    if start_index > 0:
-                        timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
+                    # Create timestep tensor that has prod(latent_model_input.shape) elements
+                    if start_index == 0:
+                        timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1)
+                    else:
+                        timestep = t.view(1, 1).expand((latent_model_input.shape[:-1])).clone()
+                        timestep[:, :last_latent_chunk_num_tokens] = 0.0
+
+                    timestep = timestep.float()
+                    # timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
+                    # if start_index > 0:
+                    #     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

                     with self.transformer.cache_context("cond_uncond"):
                         noise_pred = self.transformer(
@@ -1075,8 +1093,11 @@ def __call__(
                     if start_index == 0:
                         latent_chunk = denoised_latent_chunk
                     else:
-                        tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
-                        latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)
+                        latent_chunk = torch.cat(
+                            [last_latent_chunk, denoised_latent_chunk[:, last_latent_chunk_num_tokens:]], dim=1
+                        )
+                        # tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
+                        # latent_chunk = torch.where(tokens_to_denoise_mask, denoised_latent_chunk, latent_chunk)

                     if callback_on_step_end is not None:
                         callback_kwargs = {}
@@ -1108,8 +1129,7 @@ def __call__(
             if start_index == 0:
                 first_tile_out_latents = latent_chunk.clone()
             else:
-                # We drop the first latent frame as it's a reinterpreted 8-frame latent understood as 1-frame latent
-                latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames + 1 :, :, :]
+                latent_chunk = latent_chunk[:, :, last_latent_tile_num_frames:-1, :, :]
                 latent_chunk = LTXLatentUpsamplePipeline.adain_filter_latent(
                     latent_chunk, first_tile_out_latents, adain_factor
                 )
0 commit comments