Commit e00149c

pclove1 authored and jax authors committed
Fix unnecessary memory copies between GPU and CPU when jax2tf.call_tf() is used.

- The root cause of the bug is that dtype lookups are incorrect because hashes behave differently between dtype instances and their types. Added comments to `jax.dlpack.SUPPORTED_DTYPES` about this.
- Added unit test coverage.
- Fixing this bug revealed a limitation that causes a "host-to-device" copy in the following two situations (see the details in the unit test comments):
  - When the dtype is 'int32'.
  - When using the PJRT C API runtime.

PiperOrigin-RevId: 610799558
1 parent c42a035 · commit e00149c

File tree

3 files changed: +63 -6 lines changed

jax/_src/dlpack.py
jax/experimental/jax2tf/call_tf.py
jax/experimental/jax2tf/tests/call_tf_test.py

jax/_src/dlpack.py

Lines changed: 7 additions & 1 deletion
@@ -26,6 +26,13 @@
 from jax._src.typing import Array
 
 
+# A set of dtypes that dlpack supports.
+# Note: Make sure to use a "type", not a dtype instance, when looking up this set
+# because their hashes are different.
+# For example,
+# hash(jnp.float32) != hash(jnp.dtype(jnp.float32))
+# hash(jnp.float32) == hash(jnp.dtype(jnp.float32).type)
+# TODO(phawkins): Migrate to using dtypes instead of the scalar type objects.
 SUPPORTED_DTYPES = frozenset({
     jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16,
     jnp.uint32, jnp.uint64, jnp.float16, jnp.bfloat16, jnp.float32,
@@ -76,7 +83,6 @@ def to_dlpack(x: Array, take_ownership: bool = False,
   ) # type: ignore
 
 
-
 def from_dlpack(external_array):
   """Returns a :class:`~jax.Array` representation of a DLPack tensor.
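
The hash mismatch described in the new comment is easy to reproduce. A minimal standalone sketch (not part of the commit; `SUPPORTED` is a stand-in for `jax.dlpack.SUPPORTED_DTYPES`):

import jax.numpy as jnp

# jnp.float32 is a scalar *type*; jnp.dtype(jnp.float32) is a dtype *instance*.
# The two compare equal with ==, but hash differently, so frozenset membership
# (which is hash-based) misses when the wrong kind of key is used.
assert hash(jnp.float32) != hash(jnp.dtype(jnp.float32))
assert hash(jnp.float32) == hash(jnp.dtype(jnp.float32).type)

SUPPORTED = frozenset({jnp.float32})  # keyed by scalar types, like SUPPORTED_DTYPES

x = jnp.zeros([1], jnp.float32)
print(x.dtype in SUPPORTED)       # False: the dtype instance hashes differently
print(x.dtype.type in SUPPORTED)  # True: .type recovers the scalar type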

jax/experimental/jax2tf/call_tf.py

Lines changed: 8 additions & 4 deletions
@@ -40,6 +40,7 @@
 from jax._src import core
 from jax._src import effects
 from jax._src import util
+from jax._src import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func as func_dialect
@@ -332,7 +333,7 @@ def _call_tf_impl(*args_jax_flat, callable_flat_tf, **_):
 def _arg_jax_to_tf(arg_jax):
   if (isinstance(arg_jax, jax.Array) and
       list(arg_jax.devices())[0].platform in _DLPACK_PLATFORMS and
-      arg_jax.dtype in dlpack.SUPPORTED_DTYPES):
+      arg_jax.dtype.type in dlpack.SUPPORTED_DTYPES):
     arg_dlpack = jax.dlpack.to_dlpack(arg_jax, take_ownership=False)
     return tf.experimental.dlpack.from_dlpack(arg_dlpack)
   # The following avoids copies to the host on CPU, always for Array
@@ -349,11 +350,14 @@ def _arg_jax_to_tf(arg_jax):
   res_tf_flat = callable_flat_tf(*args_tf_flat)
 
   def _res_tf_to_jax(res_tf: TfVal):
-    res_tf, _ = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
-    if isinstance(res_tf, tf.Tensor) and res_tf.dtype in dlpack.SUPPORTED_DTYPES:
+    res_tf, jax_dtype = jax2tf_internal._tfval_to_tensor_jax_dtype(res_tf)
+    if isinstance(res_tf, tf.Tensor) and jax_dtype.type in dlpack.SUPPORTED_DTYPES:
       res_tf_platform = tf.DeviceSpec.from_string(res_tf.backing_device).device_type
       res_jax_platform = res_tf_platform.lower()
-      if res_jax_platform in _DLPACK_PLATFORMS:
+      # Skip using dlpack in PJRT C API runtime, because it currently fails
+      # with "PJRT C API does not support GetDefaultLayout".
+      # https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/pjrt/pjrt_c_api_client.h#L285-L289
+      if res_jax_platform in _DLPACK_PLATFORMS and not xla_bridge.using_pjrt_c_api():
         res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
         return jax.dlpack.from_dlpack(res_dlpack)
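
For context, a minimal `call_tf` round trip that exercises the fixed path (a sketch, assuming a DLPack-supported dtype; `tf.math.sin` is just an example TF op):

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# call_tf lets JAX call a TF function. With the fix, the float32 argument's
# dtype now passes the SUPPORTED_DTYPES check (via .type), so on a GPU the
# buffer is exchanged with TF through DLPack instead of being copied via host.
f = jax2tf.call_tf(lambda x: tf.math.sin(x))
x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
print(f(x))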

jax/experimental/jax2tf/tests/call_tf_test.py

Lines changed: 48 additions & 1 deletion
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for call_tf."""
+
+import contextlib
 from functools import partial
 import os
 from typing import Callable
@@ -22,14 +24,16 @@
 from absl.testing import parameterized
 import jax
 from jax import config
+from jax import dlpack
 from jax import dtypes
 from jax import lax
 from jax import numpy as jnp
 from jax._src import test_util as jtu
+from jax._src import xla_bridge
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax.experimental import jax2tf
 from jax.experimental import export
+from jax.experimental import jax2tf
 from jax.experimental.jax2tf.tests import tf_test_util
 import numpy as np
 
@@ -814,6 +818,49 @@ def f_jax(x):
     res = f_tf(x)
     self.assertAllClose(res, f_jax(x))
 
+  @parameterized.named_parameters(
+      {"testcase_name": f"_type={type_.__name__}", "type_": type_}
+      for type_ in dlpack.SUPPORTED_DTYPES
+  )
+  def test_avoid_copy_between_gpu_and_cpu(self, type_):
+    try:
+      gpu_devices = jax.devices("gpu")
+    except RuntimeError:
+      gpu_devices = []
+    if not gpu_devices:
+      raise unittest.SkipTest("Test requires a GPU device.")
+
+    def tf_fun(x):
+      if type_ == jnp.bool_:
+        return tf.math.logical_or(x, True)
+      else:
+        return x + 1
+
+    jax_array_on_gpu = jnp.zeros([1], type_, device=gpu_devices[0])
+
+    # Since the input array is already on a GPU device, we expect that no memory
+    # copy occurs between GPU and CPU. Thus, we expect no errors raised by the
+    # transfer guard.
+    # There are two exceptions:
+    # First, when the dtype is "int32". This is because almost all TensorFlow
+    # kernels for GPU devices keep int32 tensors in host memory.
+    # (https://github.com/tensorflow/tensorflow/blob/4eb3e36d1b0cd511e1677e740bd093f42365cf9f/tensorflow/python/eager/pywrap_tensor.cc#L352-L354)
+    # Hence, for "int32", we do expect a "host-to-device" copy.
+    # Second, when using the PJRT C API runtime. This is because it currently
+    # skips dlpack to work around the "PJRT C API does not support
+    # GetDefaultLayout" runtime error.
+    # https://github.com/openxla/xla/blob/762bde36adf22792e91c38fe87cabe5af05bfadc/xla/pjrt/pjrt_c_api_client.h#L285-L289
+    @contextlib.contextmanager
+    def _transfer_guard(guard_level):
+      with contextlib.ExitStack() as stack:
+        stack.enter_context(jax.transfer_guard_device_to_device(guard_level))
+        stack.enter_context(jax.transfer_guard_device_to_host(guard_level))
+        if not (type_ == jnp.int32 or xla_bridge.using_pjrt_c_api()):
+          stack.enter_context(jax.transfer_guard_host_to_device(guard_level))
+        yield
+
+    with _transfer_guard("disallow_explicit"):
+      jax2tf.call_tf(tf_fun)(jax_array_on_gpu)
+
 
 class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase):
   "Reloading output of jax2tf into JAX with call_tf"
