MultiHeadAttention's use_causal_mask is broken #21284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
pfekin opened this issue May 14, 2025 · 4 comments

@pfekin

pfekin commented May 14, 2025

There is leakage of forward (future-token) information when the MultiHeadAttention layer is called with use_causal_mask=True instead of an explicit attention mask.
I get over 0.99 accuracy on a randomly generated validation dataset on Colab.
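
To make the expectation concrete, here is a minimal leakage check (a sketch, separate from the full script posted below): with use_causal_mask=True, the outputs at the first positions should not change when only later positions of the input are perturbed.

import numpy as np
import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)

x = tf.random.normal((1, 6, 16))
# Same first 3 positions, different last 3 positions
x_perturbed = tf.concat([x[:, :3, :], tf.random.normal((1, 3, 16))], axis=1)

out_a = mha(x, x, use_causal_mask=True)
out_b = mha(x_perturbed, x_perturbed, use_causal_mask=True)

# With a working causal mask, the first 3 output positions must be identical
print(np.allclose(out_a[:, :3, :], out_b[:, :3, :], atol=1e-5))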

@dhantule
Contributor

Hi @pfekin, thanks for reporting this.
Could you provide some reproducible code?

@pfekin
Author

pfekin commented May 15, 2025

Sorry for the size - given the nature of the problem it's hard to make the script much shorter.
The validation dataset is made of random token indices - over time the model still fits it with over 0.99 accuracy.
When given an explicit causal mask, MultiHeadAttention works as it should.
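
For reference, the explicit-mask variant that behaves correctly for me looks roughly like this (a sketch; tf.linalg.band_part is just one common way to build the lower-triangular mask):

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=64)
x = tf.random.normal((2, 64, 256))

seq_len = tf.shape(x)[1]
# Lower-triangular (causal) boolean mask, broadcast over batch and heads
causal_mask = tf.cast(
    tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0), tf.bool
)[tf.newaxis, :, :]

out = mha(x, x, attention_mask=causal_mask)

The full reproduction script follows.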


import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow_datasets as tfds
import math
import numpy as np
from datasets import load_dataset

# Configuration
SEQ_LENGTH = 64
BATCH_SIZE = 64
VOCAB_SIZE = 10000
EMBED_DIM = 256
EPOCHS = 10

# Data Preparation 
def prepare_wikitext(split, tokenizer):
    raw_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split=split)
    texts = [x['text'] for x in raw_dataset if x['text'].strip() != '']
    ds = tf.data.Dataset.from_tensor_slices(texts)
    tokenizer.adapt(ds)
             
    def process(text):
        tokens = tokenizer(text)
        return tokens[:-1], tokens[1:]
    
    def filter_empty(x, y):
        return tf.shape(x)[0] > 0
        
    return ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)\
             .filter(filter_empty)    

def prepare_random_validation_ds(vocab_size, seq_length, num_samples, tokenizer=None):
    # Generate random tokens in the vocabulary range
    random_sequences = np.random.randint(1, vocab_size, size=(num_samples, seq_length + 1))
    ds = tf.data.Dataset.from_tensor_slices(random_sequences.astype(np.int64))

    def process(seq):
        return seq[:-1], seq[1:]

    ds = ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)
    return ds
    
tokenizer = tf.keras.layers.TextVectorization(
        max_tokens=VOCAB_SIZE,
        output_sequence_length=SEQ_LENGTH+1,
        standardize='lower_and_strip_punctuation' 
    )
    
train_ds = prepare_wikitext('train', tokenizer).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
#val_ds = prepare_wikitext('validation', tokenizer).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
val_ds = prepare_random_validation_ds(VOCAB_SIZE, SEQ_LENGTH, 10000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

class AdaptivePositionEncoding(layers.Layer):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model

    def build(self, input_shape):
        position = tf.range(self.max_len, dtype=tf.float32)[:, tf.newaxis]  # (max_len, 1)
        div_term = tf.exp(
            tf.range(0, self.d_model, 2, dtype=tf.float32)
            * (-math.log(10000.0) / self.d_model)
        )  # (d_model//2,)

        # Compute sinusoidal embeddings
        sin_vals = tf.sin(position * div_term)  # (max_len, d_model//2)
        cos_vals = tf.cos(position * div_term)  # (max_len, d_model//2)

        # Interleave sin and cos values -> (max_len, d_model)
        pe = tf.reshape(
            tf.stack([sin_vals, cos_vals], axis=2),
            [self.max_len, self.d_model]
        )
        self.pe = tf.Variable(pe[tf.newaxis, :, :], trainable=True)

    def call(self, x):
        return self.pe[:, :tf.shape(x)[1], :] 

#Attention Mechanism
class DotProductAttention(layers.Layer):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = layers.MultiHeadAttention(num_heads, d_model//num_heads)

    def call(self, x, mask=None, use_causal_mask=False):
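        # Note: the use_causal_mask argument received by this wrapper is not
        # forwarded; the MHA call below hardcodes use_causal_mask=False.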
        return self.mha(x, x, attention_mask=mask, use_causal_mask=False)

# Autoregressive Model
class LanguageModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.embed = layers.Embedding(VOCAB_SIZE, EMBED_DIM)
        self.pos_enc = AdaptivePositionEncoding(SEQ_LENGTH, EMBED_DIM)
        self.attention = DotProductAttention(EMBED_DIM, num_heads=4)
        self.out = layers.Dense(VOCAB_SIZE)

    def call(self, inputs):
        x = self.embed(inputs)
        x = x + self.pos_enc(x)
        attn_output = self.attention(x, use_causal_mask=True)  
        return self.out(attn_output)

# Training & Evaluation
model = LanguageModel()
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3, clipvalue=1.0),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS)
model.summary()


@dhantule
Contributor

Hi @pfekin, I've run your code and I'm unable to reproduce the issue. Please refer to this gist.

@pfekin
Author

pfekin commented May 19, 2025

You need to install the necessary libraries beforehand:
!pip install --upgrade datasets fsspec huggingface_hub

Also, you might need to set HF_TOKEN (the Hugging Face API token environment variable) as a Colab secret.
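
For example (a sketch; this assumes the token is saved under the name HF_TOKEN in Colab's secrets panel):

import os
from google.colab import userdata  # Colab-only helper for reading secrets

os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")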

If you want, you can replace prepare_wikitext with:

def prepare_imdb(split, tokenizer):
    ds = tfds.load('imdb_reviews', split=split)  # Using IMDB reviews dataset
    tokenizer.adapt(ds.map(lambda x: x['text']))

    def process(example):
        tokens = tokenizer(example['text'])
        return tokens[:-1], tokens[1:]

    return ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)

This replaces the WikiText-2 dataset (hosted on Hugging Face) with the IMDB reviews dataset (hosted on TensorFlow Datasets, which does not require an API key), but it is larger and will take longer to return results.
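
Then swap the training dataset line accordingly (sketch; the random validation dataset stays the same):

train_ds = prepare_imdb('train', tokenizer).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)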

Training on WikiText-2, I'm getting val_accuracy: 0.6750 after 10 epochs on the randomly generated validation dataset.
