The code implementation of Causal Attention does not match the paper description #74

@bugmakerwww

Thanks for the excellent work.

In the paper, given condition frames C0, C1, C2, the video and action frames can attend to the previous condition frame and to the current video and action frames, as described in the figure below:
[Image: attention figure from the paper]

However, in the released code, I found that not every block has its own conditional frame; only the first frame acts as a conditional frame:

  • For teacher-forcing training, video block i can attend to clean_blocks[0:i] + current_noisy_block + action[i] + state[i], but cannot attend to the first frame or to a condition frame of the current block:
            q_block = noisy_image_q[:, noisy_start:noisy_end]
            
            # Build context: first_clean_frame + clean_blocks[0:i] + current_noisy_block + action[i] + state[i]
            k_context = torch.cat([
                clean_image_k[:, :clean_end],
                noisy_image_k[:, noisy_start:noisy_end],
                noisy_action_k[:, action_start:action_end],
                noisy_state_k[:, state_start:state_end]
            ], dim=1)
            v_context = torch.cat([
                clean_image_v[:, :clean_end],
                noisy_image_v[:, noisy_start:noisy_end],
                noisy_action_v[:, action_start:action_end],
                noisy_state_v[:, state_start:state_end]
            ], dim=1)
  • For normal training, video block i can attend to the previous blocks (limited by the local_attn_size window) and to the current block:
            block_starts = [frame_seqlen + i * block_size for i in range(num_blocks)]
            block_ends = [min(start + block_size, total_len) for start in block_starts]
            kv_starts = [max(0, end - self.local_attn_size * frame_seqlen) for end in block_ends]
            
            for block_idx in range(num_blocks):
                block_start = block_starts[block_idx]
                block_end = block_ends[block_idx]
                kv_start = kv_starts[block_idx]
                
                output[:, block_start:block_end] = self.attn(
                    q[:, block_start:block_end],
                    k[:, kv_start:block_end],
                    v[:, kv_start:block_end]
                )
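
To make the index arithmetic of the normal-training case concrete, here is a minimal sketch that reproduces only the `block_starts` / `block_ends` / `kv_starts` computation from the snippet above. `kv_ranges` is a hypothetical helper, not a function from the repository, and all sizes are made up for illustration. It also assumes that the first `frame_seqlen` tokens are the first (conditional) frame, which is my reading of the code rather than something the repository states.

```python
def kv_ranges(num_blocks, frame_seqlen, block_size, total_len, local_attn_size):
    """Return (block_start, block_end, kv_start) for each block.

    Mirrors the index arithmetic quoted in the issue; hypothetical helper.
    Assumption: tokens [0, frame_seqlen) belong to the first frame.
    """
    block_starts = [frame_seqlen + i * block_size for i in range(num_blocks)]
    block_ends = [min(start + block_size, total_len) for start in block_starts]
    # The key/value window reaches back local_attn_size frames from the block end.
    kv_starts = [max(0, end - local_attn_size * frame_seqlen) for end in block_ends]
    return list(zip(block_starts, block_ends, kv_starts))


if __name__ == "__main__":
    # Made-up sizes: 4 tokens per frame, 2 frames per block, 3 blocks.
    frame_seqlen, block_size, num_blocks = 4, 8, 3
    total_len = frame_seqlen + num_blocks * block_size
    for start, end, kv_start in kv_ranges(
        num_blocks, frame_seqlen, block_size, total_len, local_attn_size=3
    ):
        sees_first_frame = kv_start < frame_seqlen
        print(f"q[{start}:{end}] attends k[{kv_start}:{end}], "
              f"first frame visible: {sees_first_frame}")
```

With these made-up sizes, only block 0 has a window that reaches back to the first frame. Later blocks see the current block plus part of the preceding one, and none of them has a separate per-block condition frame.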

It seems that in neither training mode does each block have its own independent conditional frame. I would like to know:

  • What do C0, C1, and C2 in the figure correspond to in the code? Does the description in the paper match the code implementation?
  • If C1 is the conditional frame of block 1, should it be a denoised frame or the clean encoded ground truth?
