Skip to content

Commit 383ae41

Browse files
gneculajax authors
authored andcommitted
Attempt to eliminate flakiness for jax2tf test.
PiperOrigin-RevId: 617878818
1 parent c945766 commit 383ae41

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

jax/experimental/jax2tf/tests/jax2tf_limitations.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,12 @@ def dot_general(cls, harness: test_harnesses.Harness):
547547
# may be more precise.
548548
custom_numeric(dtypes=[np.float16], devices=["cpu"], tol=1e-2,
549549
modes=("eager", "graph", "compiled")),
550+
# Flakiness on different_dtypes_lhs_int16_4_3_rhs_float16_3_6_dimensionnumbers_1_0_enable_xla_True
551+
# Strangely, we only see the flakiness in primitives_graph_serialization_test_gpu_pjrt_c_api
552+
custom_numeric(dtypes=[np.int16], devices=["gpu"], tol=1e-2,
553+
modes=("eager", "graph", "compiled"),
554+
enabled=(harness.params["enable_xla"] and
555+
harness.dtype != harness.params["rhs_dtype"])),
550556
]
551557

552558
@classmethod

0 commit comments

Comments
 (0)