@@ -149,6 +149,9 @@ def forward(
149149 seq_img = hidden_states .shape [1 ]
150150 seq_txt = encoder_hidden_states .shape [1 ]
151151
152+ transformer_patches = transformer_options .get ("patches" , {})
153+ extra_options = transformer_options .copy ()
154+
152155 # Project and reshape to BHND format (batch, heads, seq, dim)
153156 img_query = self .to_q (hidden_states ).view (batch_size , seq_img , self .heads , - 1 ).transpose (1 , 2 ).contiguous ()
154157 img_key = self .to_k (hidden_states ).view (batch_size , seq_img , self .heads , - 1 ).transpose (1 , 2 ).contiguous ()
@@ -167,15 +170,22 @@ def forward(
167170 joint_key = torch .cat ([txt_key , img_key ], dim = 2 )
168171 joint_value = torch .cat ([txt_value , img_value ], dim = 2 )
169172
170- joint_query = apply_rope1 (joint_query , image_rotary_emb )
171- joint_key = apply_rope1 (joint_key , image_rotary_emb )
172-
173173 if encoder_hidden_states_mask is not None :
174174 attn_mask = torch .zeros ((batch_size , 1 , seq_txt + seq_img ), dtype = hidden_states .dtype , device = hidden_states .device )
175175 attn_mask [:, 0 , :seq_txt ] = encoder_hidden_states_mask
176176 else :
177177 attn_mask = None
178178
179+ extra_options ["img_slice" ] = [txt_query .shape [2 ], joint_query .shape [2 ]]
180+ if "attn1_patch" in transformer_patches :
181+ patch = transformer_patches ["attn1_patch" ]
182+ for p in patch :
183+ out = p (joint_query , joint_key , joint_value , pe = image_rotary_emb , attn_mask = encoder_hidden_states_mask , extra_options = extra_options )
184+ joint_query , joint_key , joint_value , image_rotary_emb , encoder_hidden_states_mask = out .get ("q" , joint_query ), out .get ("k" , joint_key ), out .get ("v" , joint_value ), out .get ("pe" , image_rotary_emb ), out .get ("attn_mask" , encoder_hidden_states_mask )
185+
186+ joint_query = apply_rope1 (joint_query , image_rotary_emb )
187+ joint_key = apply_rope1 (joint_key , image_rotary_emb )
188+
179189 joint_hidden_states = optimized_attention_masked (joint_query , joint_key , joint_value , self .heads ,
180190 attn_mask , transformer_options = transformer_options ,
181191 skip_reshape = True )
@@ -444,6 +454,7 @@ def _forward(
444454
445455 timestep_zero_index = None
446456 if ref_latents is not None :
457+ ref_num_tokens = []
447458 h = 0
448459 w = 0
449460 index = 0
@@ -474,16 +485,16 @@ def _forward(
474485 kontext , kontext_ids , _ = self .process_img (ref , index = index , h_offset = h_offset , w_offset = w_offset )
475486 hidden_states = torch .cat ([hidden_states , kontext ], dim = 1 )
476487 img_ids = torch .cat ([img_ids , kontext_ids ], dim = 1 )
488+ ref_num_tokens .append (kontext .shape [1 ])
477489 if timestep_zero :
478490 if index > 0 :
479491 timestep = torch .cat ([timestep , timestep * 0 ], dim = 0 )
480492 timestep_zero_index = num_embeds
493+ transformer_options = transformer_options .copy ()
494+ transformer_options ["reference_image_num_tokens" ] = ref_num_tokens
481495
482496 txt_start = round (max (((x .shape [- 1 ] + (self .patch_size // 2 )) // self .patch_size ) // 2 , ((x .shape [- 2 ] + (self .patch_size // 2 )) // self .patch_size ) // 2 ))
483497 txt_ids = torch .arange (txt_start , txt_start + context .shape [1 ], device = x .device ).reshape (1 , - 1 , 1 ).repeat (x .shape [0 ], 1 , 3 )
484- ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
485- image_rotary_emb = self .pe_embedder (ids ).to (x .dtype ).contiguous ()
486- del ids , txt_ids , img_ids
487498
488499 hidden_states = self .img_in (hidden_states )
489500 encoder_hidden_states = self .txt_norm (encoder_hidden_states )
@@ -495,6 +506,18 @@ def _forward(
495506 patches = transformer_options .get ("patches" , {})
496507 blocks_replace = patches_replace .get ("dit" , {})
497508
509+ if "post_input" in patches :
510+ for p in patches ["post_input" ]:
511+ out = p ({"img" : hidden_states , "txt" : encoder_hidden_states , "img_ids" : img_ids , "txt_ids" : txt_ids , "transformer_options" : transformer_options })
512+ hidden_states = out ["img" ]
513+ encoder_hidden_states = out ["txt" ]
514+ img_ids = out ["img_ids" ]
515+ txt_ids = out ["txt_ids" ]
516+
517+ ids = torch .cat ((txt_ids , img_ids ), dim = 1 )
518+ image_rotary_emb = self .pe_embedder (ids ).to (x .dtype ).contiguous ()
519+ del ids , txt_ids , img_ids
520+
498521 transformer_options ["total_blocks" ] = len (self .transformer_blocks )
499522 transformer_options ["block_type" ] = "double"
500523 for i , block in enumerate (self .transformer_blocks ):
0 commit comments