Commit 9d5909f

fixing possible issues
1 parent 47d20cd commit 9d5909f

4 files changed: +38, -32 lines

src/jax_flash_attn2/flash_attention.py

Lines changed: 5 additions & 3 deletions
@@ -173,9 +173,6 @@ def __call__(
                 f"Query heads ({num_q_heads}) must be divisible by "
                 f"key/value heads ({num_kv_heads})"
             )
-
-        bias = self._handle_bias(bias, num_q_heads, num_kv_heads)
-
         if self.config.platform == Platform.TRITON:
             return self._compute_triton(query, key, value, bias)
         elif self.config.platform == Platform.PALLAS:
@@ -192,6 +189,7 @@ def _compute_triton(
     ) -> chex.Array:
         """Computes attention using Triton backend."""
         # fmt:off
+        bias = self._handle_bias(bias, query.shape[2], key.shape[2])
         if query.shape[2] == key.shape[2] or os.environ.get("FORCE_MHA", "false") in ["true", "1", "on"]:
             key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])

@@ -223,6 +221,8 @@ def _compute_pallas(
         bias: Optional[chex.Array],
     ) -> chex.Array:
         """Computes attention using Pallas backend."""
+
+        bias = self._handle_bias(bias, query.shape[2], key.shape[2])
         key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])

         if self.config.backend == Backend.GPU:
@@ -272,6 +272,8 @@ def _compute_jax(
         bias: Optional[chex.Array],
     ) -> chex.Array:
         """Computes attention using JAX backend."""
+
+        bias = self._handle_bias(bias, query.shape[2], key.shape[2])
         key, value = self.repeat_kv_heads(key, value, query.shape[2] // key.shape[2])
         return jax_flash_attn_2_mu(
             query_state=query,
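
Note on the change above: `_handle_bias` is now invoked inside each backend path, with the head counts read from `query.shape[2]` and `key.shape[2]` instead of being passed in once from `__call__`. The helper's body is not part of this commit; the following is only a hypothetical sketch of the kind of per-head bias broadcasting such a helper might perform, assuming the head axis sits at position 1 of a (batch, heads, q_len, k_len) bias.

# Hypothetical sketch only -- not the actual jax_flash_attn2._handle_bias implementation.
import jax.numpy as jnp

def handle_bias_sketch(bias, num_q_heads, num_kv_heads):
    """Tile a per-KV-head bias so every query head gets its own slice (assumed layout: b, h, qs, ks)."""
    if bias is None:
        return None
    if bias.shape[1] in (num_q_heads, 1):
        # Already per-query-head, or a single head that broadcasts later.
        return bias
    if bias.shape[1] == num_kv_heads:
        # Repeat each KV-head slice for its group of query heads.
        return jnp.repeat(bias, num_q_heads // num_kv_heads, axis=1)
    raise ValueError(f"Unexpected bias head dimension: {bias.shape[1]}")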

src/jax_flash_attn2/triton_kernels/gqa_kernel.py

Lines changed: 2 additions & 1 deletion
@@ -1115,6 +1115,7 @@ def _fwd_attn_kernel_call_with_residual(


 @functools.partial(custom_vjp, nondiff_argnums=[4, 5, 6])
+@functools.partial(jax.jit, static_argnums=[4, 5, 6])
 def _flash_gqa_attn2(
     query: chex.Array,
     key: chex.Array,
@@ -1240,7 +1241,7 @@ def _test_forward():
 def _test_backward():
     """Tests the backward pass of the attention mechanism."""
     q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
-    B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
+    B, QH, KVH, QS, KS, D = 1, 32, 16, 1024, 1024, 128
     blocksize_k = 16
     blocksize_q = 16
     q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
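
Note on the change above: the new `jax.jit` decorator is stacked under the existing `custom_vjp`, with `static_argnums` matching `nondiff_argnums` so the same trailing arguments are treated as static at compile time and as non-differentiable in the VJP. A minimal toy of this stacking pattern (one static argument instead of three; not the kernel itself):

import functools
import jax
import jax.numpy as jnp
from jax import custom_vjp


@functools.partial(custom_vjp, nondiff_argnums=[1])
@functools.partial(jax.jit, static_argnums=[1])
def scaled_square(x, scale):
    # Toy primal: `scale` is a static, non-differentiable Python float.
    return scale * x * x


def _scaled_square_fwd(x, scale):
    # Forward rule: return the primal output plus residuals for the backward pass.
    return scaled_square(x, scale), x


def _scaled_square_bwd(scale, residual, g):
    # Backward rule: non-diff args come first, then residuals, then the cotangent.
    # Return one cotangent per differentiable argument (only x here).
    return (g * 2.0 * scale * residual,)


scaled_square.defvjp(_scaled_square_fwd, _scaled_square_bwd)

print(jax.grad(lambda x: scaled_square(x, 3.0).sum())(jnp.ones(4)))  # [6. 6. 6. 6.]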

src/jax_flash_attn2/triton_kernels/mha_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -1038,6 +1038,7 @@ def _fwd_attn_kernel_call_with_residual(


 @functools.partial(custom_vjp, nondiff_argnums=[4, 5, 6])
+@functools.partial(jax.jit, static_argnums=[4, 5, 6])
 def _flash_attn2(
     query: chex.Array,
     key: chex.Array,
@@ -1076,7 +1077,6 @@ def _flash_attn2(
     _fwd_attn_kernel_call_with_residual,
     _bwd_attn_kernel_call,
 )
-
 triton_flash_mha_attn_2_gpu = _flash_attn2
 __all__ = ["triton_flash_mha_attn_2_gpu"]

@@ -1130,7 +1130,7 @@ def _test_forward():
     """Tests the forward pass of the attention mechanism."""
     q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
     q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
-    B, QH, KVH, QS, KS, D = 1, 32, 8, 1024, 1024, 128
+    B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
     blocksize_k = 64
     blocksize_q = 128
     q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
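
Note on the change above: the MHA kernel is reached only after `repeat_kv_heads` has expanded grouped KV heads (see the FORCE_MHA branch in flash_attention.py), which is why its test now uses QH == KVH == 32. The helper's implementation is not shown in this commit; a minimal sketch of the usual expansion, assuming a (batch, seq, heads, head_dim) layout:

# Sketch of KV-head repetition for GQA; assumes (batch, seq, heads, head_dim) tensors.
import jax.numpy as jnp

def repeat_kv_heads_sketch(key, value, n_rep: int):
    if n_rep == 1:
        return key, value
    # Each KV head is shared by n_rep consecutive query heads.
    return (
        jnp.repeat(key, n_rep, axis=2),
        jnp.repeat(value, n_rep, axis=2),
    )

k = jnp.zeros((1, 1024, 8, 128))
v = jnp.zeros((1, 1024, 8, 128))
k32, v32 = repeat_kv_heads_sketch(k, v, 32 // 8)
print(k32.shape)  # (1, 1024, 32, 128)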

tests/test_triton.py

Lines changed: 29 additions & 26 deletions
@@ -5,20 +5,19 @@
 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 os.environ["JAX_TRACEBACK_FILTERING"] = "off"
-os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="

 sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src"))

 import jax
 from jax import numpy as jnp
 from jax import random as jrnd

-from jax_flash_attn2 import get_cached_flash_attention
+from jax_flash_attn2 import create_flash_attention

 USE_BIAS = True


-def _attn_refrence(query_states, key_states, value_states, bias):
+def _gqa_attn_refrence(query_states, key_states, value_states, bias):
     b, qs, num_q_heads, d = query_states.shape
     num_kv_heads = value_states.shape[2]
     ks = value_states.shape[1]
@@ -63,13 +62,32 @@ def _attn_refrence(query_states, key_states, value_states, bias):
     )


+def _mha_attn_refrence(query_states, key_states, value_states, bias):
+    d = query_states.shape[-1]
+
+    attention_weight = jnp.einsum("bqhd,bkhd->bhqk", query_states * (d**-0.5), key_states)
+
+    if bias is not None:
+        attention_weight = jnp.add(attention_weight, bias)
+    attention_weight = jax.nn.softmax(attention_weight)
+
+    return jnp.einsum("bhqk,bkhd->bqhd", attention_weight, value_states)
+
+
+flash_attn = create_flash_attention(
+    backend="gpu",
+    platform="triton",
+    blocksize_q=64,
+    blocksize_k=64,
+    softmax_scale=None,
+)
+
+
 def test_forward():
     """Tests the forward pass of the attention mechanism."""
     q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
     q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
     B, QH, KVH, QS, KS, D = 1, 32, 8, 1024, 1024, 128
-    blocksize_k = 64
-    blocksize_q = 128
     q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
     k = jax.nn.initializers.normal(2)(k_key, (B, KS, KVH, D), dtype=jnp.float16)
     v = jax.nn.initializers.normal(2)(v_key, (B, KS, KVH, D), dtype=jnp.float16)
@@ -82,27 +100,19 @@ def test_forward():
         if USE_BIAS
         else None
     )
-    flash_attn = get_cached_flash_attention(
-        backend="gpu",
-        platform="triton",
-        blocksize_q=blocksize_q,
-        blocksize_k=blocksize_k,
-        softmax_scale=None,
-    )
     print("QKV Allocated")
     co = flash_attn(q, k, v, b) # passes 256K on 24G GPU 3090
     print(co[-1, -1, -1, :5])
-    fo = _attn_refrence(q, k, v, b)
+    fo = _gqa_attn_refrence(q, k, v, b)
     print(fo[-1, -1, -1, :5])
     print("Results are Close" if jnp.allclose(co, fo, 0, 0.125) else "Wrong results!")


 def test_backward():
-    """Tests the backward pass of the attention mechanism."""
+    """Tests the backward pass of the attention mechanism."""
+
     q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
     B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
-    blocksize_k = 16
-    blocksize_q = 16
     q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
     k = jax.nn.initializers.normal(2)(k_key, (B, KS, KVH, D), dtype=jnp.float16)
     v = jax.nn.initializers.normal(2)(v_key, (B, KS, KVH, D), dtype=jnp.float16)
@@ -116,25 +126,18 @@ def test_backward():
         else None
     )

-    flash_attn = get_cached_flash_attention(
-        backend="gpu",
-        platform="triton",
-        blocksize_q=blocksize_q,
-        blocksize_k=blocksize_k,
-        softmax_scale=None,
-    )
     try:
         co = jax.grad(lambda *x: flash_attn(*x).sum())(q, k, v, b)
         print("Custom op backward pass gradients:")
-        print(co[-1][-1, -1, :5]) # Print last 5 elements of last head of last batch
+        print(co[-1, -1, -1, :5]) # Print last 5 elements of last head of last batch
     except Exception as er:
         print(f"Custom op backward pass failed: {er}")
         co = None

     try:
-        fo = jax.grad(lambda *x: _attn_refrence(*x).sum())(q, k, v, b)
+        fo = jax.grad(lambda *x: _mha_attn_refrence(*x).sum())(q, k, v, b)

-        print(fo[-1, -1, -1, :5]) # Print last 5 elements of last head of last batch
+        print(fo[-1, -1, -1, :5])
     except Exception as e:
         print(f"Flax backward pass failed : {e}")
         fo = None
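
Note on the change above: the test now keeps two references, `_mha_attn_refrence` (added here) for the matched-heads case and `_gqa_attn_refrence` for grouped heads. The GQA reference's body lies mostly outside this hunk; the following is only an illustrative sketch of how a grouped-head reference can be written, assuming consecutive query heads share a KV head and a (batch, seq, heads, head_dim) layout:

# Illustration only -- not the actual _gqa_attn_refrence body from the test file.
import jax
import jax.numpy as jnp

def gqa_reference_sketch(q, k, v, bias=None):
    b, qs, qh, d = q.shape
    kvh = k.shape[2]
    group = qh // kvh
    # Group query heads as (kv_head, group) blocks and apply the softmax scale.
    qg = q.reshape(b, qs, kvh, group, d) * (d ** -0.5)
    # scores: (batch, kv_head, group, q_len, k_len)
    scores = jnp.einsum("bqhgd,bkhd->bhgqk", qg, k)
    if bias is not None:
        # Assumes a (batch, q_heads, q_len, k_len) bias, regrouped the same way.
        scores = scores + bias.reshape(b, kvh, group, *bias.shape[-2:])
    probs = jax.nn.softmax(scores)
    out = jnp.einsum("bhgqk,bkhd->bqhgd", probs, v)
    return out.reshape(b, qs, qh, d)

q = jnp.ones((1, 16, 8, 4))
k = jnp.ones((1, 16, 2, 4))
v = jnp.ones((1, 16, 2, 4))
print(gqa_reference_sketch(q, k, v).shape)  # (1, 16, 8, 4)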
