os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["JAX_TRACEBACK_FILTERING"] = "off"
- os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../src"))

import jax
from jax import numpy as jnp
from jax import random as jrnd

- from jax_flash_attn2 import get_cached_flash_attention
+ from jax_flash_attn2 import create_flash_attention

USE_BIAS = True


- def _attn_refrence(query_states, key_states, value_states, bias):
+ def _gqa_attn_refrence(query_states, key_states, value_states, bias):
    b, qs, num_q_heads, d = query_states.shape
    num_kv_heads = value_states.shape[2]
    ks = value_states.shape[1]
@@ -63,13 +62,32 @@ def _attn_refrence(query_states, key_states, value_states, bias):
    )


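+ # Plain multi-head attention reference: unlike the GQA reference above, it assumes
+ # the query and key/value tensors carry the same number of heads.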
+ def _mha_attn_refrence(query_states, key_states, value_states, bias):
+     d = query_states.shape[-1]
+
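+     # Scale the queries by 1/sqrt(head_dim) before the dot product (the standard softmax scaling).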
+     attention_weight = jnp.einsum("bqhd,bkhd->bhqk", query_states * (d**-0.5), key_states)
+
+     if bias is not None:
+         attention_weight = jnp.add(attention_weight, bias)
+     attention_weight = jax.nn.softmax(attention_weight)
+
+     return jnp.einsum("bhqk,bkhd->bqhd", attention_weight, value_states)
+
+
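+ # Build the flash-attention callable once at module level with a single 64x64 block
+ # configuration; both tests below reuse it instead of constructing their own instance.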
+ flash_attn = create_flash_attention(
+     backend="gpu",
+     platform="triton",
+     blocksize_q=64,
+     blocksize_k=64,
+     softmax_scale=None,
+ )
+
+
def test_forward():
    """Tests the forward pass of the attention mechanism."""
    q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
    B, QH, KVH, QS, KS, D = 1, 32, 8, 1024, 1024, 128
-     blocksize_k = 64
-     blocksize_q = 128
    q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
    k = jax.nn.initializers.normal(2)(k_key, (B, KS, KVH, D), dtype=jnp.float16)
    v = jax.nn.initializers.normal(2)(v_key, (B, KS, KVH, D), dtype=jnp.float16)
@@ -82,27 +100,19 @@ def test_forward():
        if USE_BIAS
        else None
    )
-     flash_attn = get_cached_flash_attention(
-         backend="gpu",
-         platform="triton",
-         blocksize_q=blocksize_q,
-         blocksize_k=blocksize_k,
-         softmax_scale=None,
-     )
    print("QKV Allocated")
    co = flash_attn(q, k, v, b)  # passes 256K on 24G GPU 3090
    print(co[-1, -1, -1, :5])
-     fo = _attn_refrence(q, k, v, b)
+     fo = _gqa_attn_refrence(q, k, v, b)
    print(fo[-1, -1, -1, :5])
    print("Results are Close" if jnp.allclose(co, fo, 0, 0.125) else "Wrong results!")


def test_backward():
-     """Tests the backward pass of the attention mechanism."""
+     """Tests the backward pass of the attention mechanism."""
+
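+     # QH == KVH in this test, so the plain MHA reference serves as ground truth for the gradients.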
    q_key, k_key, v_key = jrnd.split(jrnd.PRNGKey(8), 3)
    B, QH, KVH, QS, KS, D = 1, 32, 32, 1024, 1024, 128
-     blocksize_k = 16
-     blocksize_q = 16
    q = jax.nn.initializers.normal(2)(q_key, (B, QS, QH, D), dtype=jnp.float16)
    k = jax.nn.initializers.normal(2)(k_key, (B, KS, KVH, D), dtype=jnp.float16)
    v = jax.nn.initializers.normal(2)(v_key, (B, KS, KVH, D), dtype=jnp.float16)
@@ -116,25 +126,18 @@ def test_backward():
        else None
    )

-     flash_attn = get_cached_flash_attention(
-         backend="gpu",
-         platform="triton",
-         blocksize_q=blocksize_q,
-         blocksize_k=blocksize_k,
-         softmax_scale=None,
-     )
    try:
        co = jax.grad(lambda *x: flash_attn(*x).sum())(q, k, v, b)
        print("Custom op backward pass gradients:")
-         print(co[-1][-1, -1, :5])  # Print last 5 elements of last head of last batch
+         print(co[-1, -1, -1, :5])  # Print last 5 elements of last head of last batch
    except Exception as er:
        print(f"Custom op backward pass failed: {er}")
        co = None

    try:
-         fo = jax.grad(lambda *x: _attn_refrence(*x).sum())(q, k, v, b)
+         fo = jax.grad(lambda *x: _mha_attn_refrence(*x).sum())(q, k, v, b)

-         print(fo[-1, -1, -1, :5])  # Print last 5 elements of last head of last batch
+         print(fo[-1, -1, -1, :5])
    except Exception as e:
        print(f"Flax backward pass failed : {e}")
        fo = None