|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | """Tests for call_tf."""
|
| 15 | + |
| 16 | +import contextlib |
15 | 17 | from functools import partial
|
16 | 18 | import os
|
17 | 19 | from typing import Callable
|
|
22 | 24 | from absl.testing import parameterized
|
23 | 25 | import jax
|
24 | 26 | from jax import config
|
| 27 | +from jax import dlpack |
25 | 28 | from jax import dtypes
|
26 | 29 | from jax import lax
|
27 | 30 | from jax import numpy as jnp
|
28 | 31 | from jax._src import test_util as jtu
|
| 32 | +from jax._src import xla_bridge |
29 | 33 | from jax._src.lib.mlir import ir
|
30 | 34 | from jax._src.lib.mlir.dialects import hlo
|
31 |    | -from jax.experimental import jax2tf |
32 | 35 | from jax.experimental import export
|
| 36 | +from jax.experimental import jax2tf |
33 | 37 | from jax.experimental.jax2tf.tests import tf_test_util
|
34 | 38 | import numpy as np
|
35 | 39 |
|
@@ -814,6 +818,49 @@ def f_jax(x):
|
814 | 818 | res = f_tf(x)
|
815 | 819 | self.assertAllClose(res, f_jax(x))
|
816 | 820 |
|
| 821 | + @parameterized.named_parameters( |
| 822 | + {"testcase_name": f"_type={type_.__name__}", "type_": type_} |
| 823 | + for type_ in dlpack.SUPPORTED_DTYPES |
| 824 | + ) |
| 825 | + def test_avoid_copy_between_gpu_and_cpu(self, type_): |
| 826 | + try: |
| 827 | + gpu_devices = jax.devices("gpu") |
| 828 | + except RuntimeError: |
| 829 | + gpu_devices = [] |
| 830 | + if not gpu_devices: |
| 831 | + raise unittest.SkipTest("Test requires a GPU device.") |
| 832 | + |
| 833 | + def tf_fun(x): |
| 834 | + if type_ == jnp.bool_: |
| 835 | + return tf.math.logical_or(x, True) |
| 836 | + else: |
| 837 | + return x + 1 |
| 838 | + |
| 839 | + jax_array_on_gpu = jnp.zeros([1], type_, device=gpu_devices[0]) |
| 840 | + |
| 841 | + # Since the input array is already on a GPU device, we expect that no memory |
| 842 | + # copy occurs between GPU and CPU. Thus, we expect no errors raised by the |
| 843 | + # transfer guard. |
| 844 | + # There are two exceptions: |
| 845 | + # First, when dtype is "int32". This is because almost all TensorFlow |
| 846 | + # kernels for GPU devices keep int32 tensors in host memory. |
| 847 | + # (https://github.com/tensorflow/tensorflow/blob/4eb3e36d1b0cd511e1677e740bd093f42365cf9f/tensorflow/python/eager/pywrap_tensor.cc#L352-L354) |
| 848 | + # Hence, for "int32", we do expect a "host-to-device" copy. |
| 849 | + # Second, when using the PJRT C API runtime. This is because it currently skips dlpack |
| 850 | + # to work around the "PJRT C API does not support GetDefaultLayout" runtime error. |
| 851 | + # https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/pjrt/pjrt_c_api_client.h#L285-L289 |
| 852 | + @contextlib.contextmanager |
| 853 | + def _transfer_guard(guard_level): |
| 854 | + with contextlib.ExitStack() as stack: |
| 855 | + stack.enter_context(jax.transfer_guard_device_to_device(guard_level)) |
| 856 | + stack.enter_context(jax.transfer_guard_device_to_host(guard_level)) |
| 857 | + if not (type_ == jnp.int32 or xla_bridge.using_pjrt_c_api()): |
| 858 | + stack.enter_context(jax.transfer_guard_host_to_device(guard_level)) |
| 859 | + yield |
| 860 | + |
| 861 | + with _transfer_guard("disallow_explicit"): |
| 862 | + jax2tf.call_tf(tf_fun)(jax_array_on_gpu) |
| 863 | + |
817 | 864 |
|
818 | 865 | class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
|
819 | 866 | "Reloading output of jax2tf into JAX with call_tf"
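
The new test combines two mechanisms that the diff itself does not spell out: jax2tf.call_tf, which wraps a TensorFlow function so that JAX code can call it, and JAX's transfer guards, which turn data movement between host and devices into errors. Below is a minimal sketch of both, not part of the commit; it assumes only a working TensorFlow and JAX install on CPU, and the helper name tf_increment and the values are made up for illustration.

    import jax
    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    # A toy TF function wrapped so it can be called from JAX code;
    # jax2tf.call_tf is the API exercised throughout this test file.
    def tf_increment(x):
      return tf.math.add(x, 1)

    x = jnp.zeros([1], jnp.float32)
    print(jax2tf.call_tf(tf_increment)(x))  # [1.]

    # Transfer guards turn copies into errors. "disallow" rejects only implicit
    # transfers; "disallow_explicit", the level used by the new test, also
    # rejects explicit ones such as jax.device_put / jax.device_get.
    with jax.transfer_guard_device_to_device("disallow_explicit"):
      _ = x + 1  # everything stays on one device, so no error is raised

This is also why the test enables the host-to-device guard only conditionally: as its comments note, an int32 input or the PJRT C API runtime makes a host-to-device copy expected, so guarding that direction would fail for reasons unrelated to call_tf.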
|
|