Trainable tokens overwriting prompt tokens #75

@jweiss-mathworks

Description

In Appendix D of the paper, which describes the refinement of the textual space, it is stated that the trainable tokens are concatenated with the original tokens, discarding the trainable tokens from the previous layer. This is not what the code does; instead, the trainable tokens overwrite the original tokens from index 1 to 1 + (length of trainable tokens).

To illustrate this, first, look at the construction of the token embeddings:
prompts, tokenized_prompts, compound_prompts_text = prompt_learner(cls_id = None)

When constructing the prompts in AnomalyCLIP_PromptLearner, the positive prompt is of length 16:
[startToken, X, X, X, X, X, X, X, X, X, X, X, X, object, ., endToken]
where X is a placeholder for the learnable prompt token embeddings. This is padded with 0's to match CLIP's input size of 77. After embedding using CLIP, this is of size [1, 1, 77, 768].

The negative prompt follows a similar pattern except for the inclusion of the word "damaged". It is of length 17:
[startToken, X, X, X, X, X, X, X, X, X, X, X, X, damaged, object, ., endToken].
This is also padded with 0's to match CLIP's input size of 77, and the embedding is also [1, 1, 77, 768].

The positive and negative prompt embeddings are concatenated into a tensor of size [2, 77, 768]. This is all good so far.
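To make these shapes concrete, here is a minimal sketch of that construction with dummy embeddings (the names and random values are illustrative, not taken from AnomalyCLIP_PromptLearner itself):

    import torch

    n_ctx, ctx_dim, clip_seq_len = 12, 768, 77

    # Positive prompt: [startToken] + 12 learnable X tokens + "object" + "." + [endToken] -> 16 tokens
    pos_len = 1 + n_ctx + 1 + 1 + 1
    # Negative prompt additionally contains "damaged" -> 17 tokens
    neg_len = pos_len + 1

    # Dummy embedded prompts, zero-padded to CLIP's context length of 77
    pos_embed = torch.zeros(1, 1, clip_seq_len, ctx_dim)
    pos_embed[..., :pos_len, :] = torch.randn(1, 1, pos_len, ctx_dim)
    neg_embed = torch.zeros(1, 1, clip_seq_len, ctx_dim)
    neg_embed[..., :neg_len, :] = torch.randn(1, 1, neg_len, ctx_dim)

    # Concatenated into the [2, 77, 768] tensor that is passed to encode_text_learn
    prompts = torch.cat([pos_embed, neg_embed], dim=1).squeeze(0)
    print(prompts.shape)  # torch.Size([2, 77, 768])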

Now let's look at how the model handles these token embeddings:
text_features = model.encode_text_learn(prompts, tokenized_prompts, compound_prompts_text).float()

The compound_prompts_text variable is a list of 8 parameters, each of size [4, 768], which represent the trainable tokens t'm referenced in Appendix D. (Note that there is a difference between the paper and the code: the paper says these are applied "into the text encoder from its bottom to the top layer"; however, the code only adds them in layers 2-9. That is not the main issue here, though.)
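For context, these deep prompt parameters can be pictured roughly as follows (a sketch with assumed names and initialization, not the repository's exact definition):

    import torch
    import torch.nn as nn

    # 8 sets of trainable tokens t'm, one per prompted transformer layer, each of shape [4, 768]
    depth, n_deep_ctx, ctx_dim = 8, 4, 768
    compound_prompts_text = nn.ParameterList(
        [nn.Parameter(torch.empty(n_deep_ctx, ctx_dim)) for _ in range(depth)]
    )
    for p in compound_prompts_text:
        nn.init.normal_(p, std=0.02)  # assumed initialization

    print(len(compound_prompts_text), compound_prompts_text[0].shape)
    # 8 torch.Size([4, 768])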

The following line will be called in the AnomalyCLIP "encode_text_learn" method:
x = self.transformer([x, deep_compound_prompts_text, 0])

and this will call the "forward" method of all of the "ResidualAttentionBlock_learnable_token" layers sequentially. This is where the issue occurs. Looking at the "forward" method:

    def forward(self, inputs):
        # if/else statement removed for brevity
        x = inputs[0]
        compound_prompts_deeper = inputs[1]
        counter = inputs[2]
        if not self.first_layer:
            # First check if the ith layer needs compound prompts or not
            if not (counter > len(compound_prompts_deeper) - 1):
                # Appending the learnable tokens in different way
                # x -> [77, NCLS, DIM]
                # First remove the learnable tokens from previous layer
    >>>>>>>>>>> prefix = x[:1, :, :]
                suffix = x[1 + self.compound_prompt_nctx:, :, :]
                textual_context = compound_prompts_deeper[counter]
                textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
                # Add the learnable tokens of this layer with the input, replaced by previous
                # layer learnable tokens
                x = torch.cat([prefix, textual_context, suffix], dim=0)
            >>>> # to show the bug
            >>>> print(torch.all(torch.all(torch.eq(inputs[0], x), dim=1), dim=1))
                # Once done, update the counter, so that the next time, it does not use same learnable tokens
                counter += 1
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return [x, compound_prompts_deeper, counter]

We can see that the first layer won't have any issues because the learnable tokens t'm are not applied, but layers 2-9 will. Let's look at layer 2. Specifically, in the line prepended by >>>>>>>>>>>, we take the first token embedding, corresponding to the startToken embedding. In the next line, since self.compound_prompt_nctx = 4, we take the last 72 token embeddings, giving suffix.shape = [72, 2, 768]. For the positive prompt, this corresponds to the embeddings of
[X, X, X, X, X, X, X, X, object, ., endToken] (length 11)
plus the 61 padded 0s.

Finally, in the line x = torch.cat([prefix, textual_context, suffix], dim=0), these are put together, giving the token embeddings of:
[startToken, t'm1, t'm2, t'm3, t'm4, X, X, X, X, X, X, X, X, object, ., endToken]. I have verified this by debugging.
As you can see, the learnable tokens t'm1 ... t'm4 are overwriting 4 of the learnable prompt token embeddings, not being concatenated as the paper describes. If they were concatenated, I would expect something like this:
[startToken, X, X, X, X, X, X, X, X, X, X, X, X, object, ., t'm1, t'm2, t'm3, t'm4, endToken].
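The overwriting can be reproduced in isolation by tagging each of the 77 positions with its own index and applying the same slicing as the forward method above (a toy reproduction, not the repository code):

    import torch

    seq_len, n_cls, dim, nctx = 77, 2, 768, 4

    # Tag every position with its index so we can see which ones survive
    x = torch.arange(seq_len, dtype=torch.float32).view(seq_len, 1, 1).expand(seq_len, n_cls, dim).clone()
    textual_context = torch.full((nctx, n_cls, dim), -1.0)  # stand-in for t'm of this layer

    prefix = x[:1, :, :]         # startToken only
    suffix = x[1 + nctx:, :, :]  # positions 5..76
    out = torch.cat([prefix, textual_context, suffix], dim=0)

    surviving = [int(v) for v in out[:, 0, 0].tolist() if v >= 0]
    print(out.shape)       # torch.Size([77, 2, 768]) -- same length as the input
    print(surviving[:6])   # [0, 5, 6, 7, 8, 9] -- positions 1-4 (four of the X tokens) were dropped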

To fix this, the ResidualAttentionBlock_learnable_token layer could use the design_details variable passed in at construction to store the prompt lengths, and use those to construct the tokens for each layer, noting that the negative prompt is one token longer than the positive prompt.
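One possible shape for such a fix is sketched below. It assumes the per-prompt lengths (16 for the positive prompt, 17 for the negative prompt) are made available to the block, e.g. via a hypothetical prompt_lengths entry in design_details, and inserts the layer's tokens just before each prompt's endToken instead of overwriting positions 1-4. This is only an illustration of the indexing, not a drop-in patch for the repository:

    import torch

    def insert_before_eot(x, textual_context, prompt_lengths, nctx):
        # x:               [77, n_cls, 768] token embeddings (positive and negative prompt)
        # textual_context: [nctx, 768] trainable tokens t'm of this layer
        # prompt_lengths:  real (unpadded) length of each prompt, e.g. [16, 17],
        #                  assumed to be carried in from design_details
        out = []
        for i, plen in enumerate(prompt_lengths):
            eot = plen - 1                          # index of the endToken
            seq = x[:, i, :]                        # [77, 768]
            tokens = textual_context.to(seq.dtype)  # [nctx, 768]
            # Keep everything up to the endToken, insert t'm, then endToken plus padding,
            # dropping nctx trailing pad positions so the total length stays 77.
            new_seq = torch.cat([seq[:eot], tokens, seq[eot:-nctx]], dim=0)
            out.append(new_seq)
        return torch.stack(out, dim=1)              # back to [77, n_cls, 768]

    # Toy shape check
    x = torch.randn(77, 2, 768)
    t_m = torch.randn(4, 768)
    print(insert_before_eot(x, t_m, prompt_lengths=[16, 17], nctx=4).shape)  # torch.Size([77, 2, 768])

Because the positive and negative prompts have different endToken positions, the insertion has to be done per sequence rather than with a single slice across the batch.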

The implications of this are not severe: the paper still reports good results despite the bug, and a correction acknowledging it could be added.
My hypothesis is that the bug hinders the model's ability to learn, and that the results might improve if concatenation were used instead of overwriting.
