
Commit 4d045e1

Examples of training bias
stack-info: PR: #84, branch: drisspg/stack/2
1 parent 36f8bd5 commit 4d045e1

File tree

1 file changed: +247 −0 lines changed


examples/learnable_bias.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
import functools
import logging
import torch
import torch.nn.functional as F
import json
import argparse
from torch.nn.attention.flex_attention import flex_attention
from typing import Callable, Dict, List, Tuple, Optional
from enum import Enum, auto
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class BiasType(Enum):
    RELATIVE_1D = "relative_1d"
    ABSOLUTE_2D = "absolute_2d"
    HEAD_SPECIFIC = "head_specific"
    BATCH_HEAD = "batch_head"
    MULTIPLICATIVE = "multiplicative"
    LOCAL_WINDOW = "local_window"
    GLOBAL_TOKENS = "global_tokens"
    WEIRD = "weird"
    OFFSET = "offset"


class AttentionTrainer:

    def __init__(
        self,
        batch_size: int = 8,
        num_heads: int = 4,
        seq_length: int = 256,
        head_dim: int = 64,
        device: str = "cuda",
        dtype: torch.dtype = torch.float32,
        window_size: int = 16,
        learning_rate: float = 1e-1,
    ):
        self.B = batch_size
        self.H = num_heads
        self.S = seq_length
        self.D = head_dim
        self.W = window_size
        self.device = device
        self.dtype = dtype
        self.lr = learning_rate
        self.which_bias = torch.tensor(0, device=device)
        self.offset = None

        # Initialize bias generators and functions like in the original
        self.bias_generators = {
            BiasType.RELATIVE_1D: self._generate_relative_1d_bias,
            BiasType.ABSOLUTE_2D: self._generate_absolute_2d_bias,
            BiasType.HEAD_SPECIFIC: self._generate_head_specific_bias,
            BiasType.BATCH_HEAD: self._generate_batch_head_bias,
            BiasType.MULTIPLICATIVE: self._generate_multiplicative_bias,
            BiasType.LOCAL_WINDOW: self._generate_local_window_bias,
            BiasType.GLOBAL_TOKENS: self._generate_global_tokens_bias,
            BiasType.WEIRD: self._generate_weird_bias,
            BiasType.OFFSET: self._generate_offset_bias,
        }

        # Copy the bias application functions from the original
        self.bias_functions = {
            BiasType.RELATIVE_1D: self._apply_relative_1d_bias,
            BiasType.ABSOLUTE_2D: self._apply_absolute_2d_bias,
            BiasType.HEAD_SPECIFIC: self._apply_head_specific_bias,
            BiasType.BATCH_HEAD: self._apply_batch_head_bias,
            BiasType.MULTIPLICATIVE: self._apply_multiplicative_bias,
            BiasType.LOCAL_WINDOW: self._apply_local_window_bias,
            BiasType.GLOBAL_TOKENS: self._apply_global_tokens_bias,
            BiasType.WEIRD: self._apply_weird_bias,
            BiasType.OFFSET: self._apply_offset_bias,
        }

    def _generate_tensor(self, *size):
        return torch.randn(
            *size, device=self.device, dtype=self.dtype, requires_grad=True
        )

    # Bias Generators

    def _generate_relative_1d_bias(self):
        return self._generate_tensor(2 * self.S)

    def _generate_absolute_2d_bias(self):
        return self._generate_tensor(self.S, self.S)

    def _generate_head_specific_bias(self):
        return self._generate_tensor(self.H, self.S, self.S)

    def _generate_batch_head_bias(self):
        return self._generate_tensor(self.B, self.H, self.S, self.S)

    def _generate_multiplicative_bias(self):
        return self._generate_tensor(self.S)

    def _generate_local_window_bias(self):
        return self._generate_tensor(2 * self.W + 1)

    def _generate_learned_pattern_bias(self):
        return self._generate_tensor(self.H, self.D)

    def _generate_global_tokens_bias(self):
        return self._generate_tensor(self.S)

    def _generate_weird_bias(self):
        return self._generate_tensor(self.B, self.H, 4, self.S)

    def _generate_offset_bias(self):
        # Generate both the bias and offset tensors
        bias = self._generate_tensor(self.S)
        self.offset = torch.randint(0, self.S, (self.S,), device=self.device)
        return bias

    # Bias Application Functions
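    # Each _apply_* method below follows flex_attention's score_mod calling
    # convention: score_mod(score, b, h, q_idx, kv_idx) -> modified score, where
    # b/h are the batch and head indices and q_idx/kv_idx the query/key positions.
    # The extra `bias` argument is bound by the bias_func closure created in train().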
    def _apply_relative_1d_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[torch.abs(q_idx - kv_idx)]

    def _apply_absolute_2d_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[q_idx, kv_idx]

    def _apply_head_specific_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[h, q_idx, kv_idx]

    def _apply_batch_head_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[b, h, q_idx, kv_idx]

    def _apply_multiplicative_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score * bias[q_idx]

    def _apply_local_window_bias(self, score, b, h, q_idx, kv_idx, bias):
        window_idx = torch.clamp(q_idx - kv_idx + self.W, 0, 2 * self.W)
        return score + bias[window_idx]

    def _apply_global_tokens_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[kv_idx]

    def _apply_weird_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[b, h, self.which_bias, q_idx]

    def _apply_offset_bias(self, score, b, h, q_idx, kv_idx, bias):
        return score + bias[self.offset[q_idx]]

    # Copy all the bias generator and application methods from the original class
    # [Previous methods remain the same as in the original code]

    def generate_dummy_data(self, num_samples: int) -> TensorDataset:
        """Generate dummy training data."""
        queries = torch.randn(
            num_samples, self.B, self.H, self.S, self.D, device=self.device
        )
        keys = torch.randn(
            num_samples, self.B, self.H, self.S, self.D, device=self.device
        )
        values = torch.randn(
            num_samples, self.B, self.H, self.S, self.D, device=self.device
        )

        # Generate dummy targets (for this example, we'll try to predict specific patterns)
        targets = torch.randn(
            num_samples, self.B, self.H, self.S, self.D, device=self.device
        )

        return TensorDataset(queries, keys, values, targets)

    def train(
        self,
        bias_type: BiasType = BiasType.RELATIVE_1D,
        num_epochs: int = 10,
        num_samples: int = 2,
        batch_size: int = 4,
    ):
        """Train the attention model with the specified bias type."""
        # Generate bias parameters
        bias = self.bias_generators[bias_type]()
        optimizer = Adam([bias], lr=self.lr)
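        # Only the bias tensor is learnable here; the queries, keys, and values
        # come from the random dummy dataset and are not optimized.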

        # Generate dummy dataset
        dataset = self.generate_dummy_data(num_samples)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Create bias function closure
        def bias_func(score, b, h, q_idx, kv_idx):
            return self.bias_functions[bias_type](score, b, h, q_idx, kv_idx, bias)

        # Compile the attention function
        flex_compiled = torch.compile(
            flex_attention, backend="eager", fullgraph=True, dynamic=False
        )

        # Training loop
        for epoch in range(num_epochs):
            total_loss = 0.0
            with tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
                for batch_idx, (q_batch, k_batch, v_batch, targets) in enumerate(pbar):
                    q_batch.requires_grad_()
                    optimizer.zero_grad()

                    # Forward pass
                    outputs = flex_compiled(
                        q_batch[0], k_batch[0], v_batch[0], score_mod=bias_func
                    )

                    # Compute loss (MSE for this example)
                    loss = F.mse_loss(outputs, targets[0])

                    # Backward pass
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    pbar.set_postfix({"loss": f"{loss.item():.6f}"})

            avg_loss = total_loss / len(dataloader)
            logger.info(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.6f}")

        return bias, avg_loss


def main(
    bias_type: BiasType = BiasType.RELATIVE_1D,
    num_epochs: int = 10,
    num_samples: int = 2,
    batch_size: int = 4,
):
    trainer = AttentionTrainer()
    trained_bias, final_loss = trainer.train(
        bias_type=bias_type,
        num_epochs=num_epochs,
        num_samples=num_samples,
        batch_size=batch_size,
    )

    logger.info(f"Final loss: {final_loss:.6f}")


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
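
A minimal usage sketch, not part of the commit: the snippet below assumes the file is importable as examples.learnable_bias and simply drives the trainer defined above from Python, using only the constructor and train() arguments the file already exposes (a CUDA device is assumed, as in the file's defaults; the smaller sizes are illustrative).

    # Hypothetical driver script; the import path and shapes are assumptions.
    from examples.learnable_bias import AttentionTrainer, BiasType

    trainer = AttentionTrainer(seq_length=64, head_dim=16)  # smaller shapes for a quick run
    # Learn a (2 * window_size + 1)-entry local-window bias on random data.
    trained_bias, final_loss = trainer.train(
        bias_type=BiasType.LOCAL_WINDOW, num_epochs=3, num_samples=2, batch_size=1
    )
    print(trained_bias.shape, final_loss)

The same four train() parameters are what jsonargparse's CLI(main) entry point exposes on the command line when the file is run directly.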

0 commit comments
