Skip to content

Commit d967c33

Browse files
rchen152jax authors
authored andcommitted
Silence some pytype errors.
PiperOrigin-RevId: 623308993
1 parent f1ae623 commit d967c33

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,9 @@ def _wrapped(
418418
if is_grouped:
419419

420420
def reshape_activations(activations):
421-
if activations.ndim == 4:
422-
kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape
423-
return activations.reshape(
421+
if activations.ndim == 4: # pytype: disable=attribute-error
422+
kv_heads, q_heads_per_kv_head, q_seq_len, head_dim = activations.shape # pytype: disable=attribute-error
423+
return activations.reshape( # pytype: disable=attribute-error
424424
kv_heads * q_heads_per_kv_head, q_seq_len, head_dim
425425
)
426426
return activations
@@ -755,7 +755,7 @@ def body(kv_compute_index, _):
755755

756756
qk = apply_mask_and_soft_cap()
757757

758-
m_curr = qk.max(axis=-1)[:, None]
758+
m_curr = qk.max(axis=-1)[:, None] # pytype: disable=attribute-error
759759
assert m_curr.shape == (bq, 1)
760760
m_next = jnp.maximum(m_prev, m_curr)
761761
assert m_next.shape == (bq, NUM_LANES)

0 commit comments

Comments
 (0)