Skip to content

Commit 376225f

Browse files
Fix torch's convert_to_tensor not respecting dtype when input is a Variable (#21452)
* Fix torch's `convert_to_tensor` not respecting `dtype` when input is a `Variable`.
* Fix openvino backend.
* Revert the casting and use variables for the tests.
1 parent 22ab5a3 commit 376225f

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

keras/src/backend/openvino/core.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -600,18 +600,17 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
600600
raise ValueError("`sparse=True` is not supported with openvino backend")
601601
if ragged:
602602
raise ValueError("`ragged=True` is not supported with openvino backend")
603+
if dtype is not None:
604+
dtype = standardize_dtype(dtype)
603605
if isinstance(x, OpenVINOKerasTensor):
604606
return x
605607
elif isinstance(x, np.ndarray):
606608
if dtype is not None:
607-
dtype = standardize_dtype(dtype)
608609
ov_type = OPENVINO_DTYPES[dtype]
609610
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))
610611
return OpenVINOKerasTensor(ov_opset.constant(x).output(0))
611612
elif isinstance(x, (list, tuple)):
612-
if dtype is not None:
613-
dtype = standardize_dtype(dtype)
614-
else:
613+
if dtype is None:
615614
# try to properly deduce element type
616615
elem = _get_first_element(x)
617616
if isinstance(elem, float):
@@ -624,12 +623,11 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
624623
dtype = standardize_dtype(dtype)
625624
ov_type = OPENVINO_DTYPES[dtype]
626625
return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)
627-
if dtype is not None:
628-
dtype = standardize_dtype(dtype)
629626
if isinstance(x, Variable):
627+
x = x.value
630628
if dtype and dtype != x.dtype:
631-
return x.value.astype(dtype)
632-
return x.value
629+
x = cast(x, dtype)
630+
return x
633631
if not is_tensor(x) and standardize_dtype(dtype) == "bfloat16":
634632
return ov.Tensor(np.asarray(x).astype(dtype))
635633
if dtype is None:

keras/src/backend/torch/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,10 +192,10 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
192192
if ragged:
193193
raise ValueError("`ragged=True` is not supported with torch backend")
194194
if isinstance(x, Variable):
195-
# TorchDynamo has bugs supporting nn.Parameter type check.
196-
# Return it directly instead of pass it to the rest of the logic in the
197-
# function.
198-
return x.value
195+
if dtype is None:
196+
return x.value
197+
x = x.value
198+
return x.to(to_torch_dtype(dtype))
199199
if is_tensor(x):
200200
device = get_device()
201201
if x.device != device:

keras/src/ops/core_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,11 @@ def test_convert_to_tensor(self, x, dtype, expected_dtype):
11501150
ops.convert_to_tensor(x, dtype=dtype), expected_dtype
11511151
)
11521152

1153+
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
1154+
def test_convert_to_tensor_with_variable(self, dtype):
1155+
x = backend.Variable(np.array([1.0, 0.0, 1.0], dtype=np.float32))
1156+
self.assertDType(ops.convert_to_tensor(x, dtype=dtype), dtype)
1157+
11531158
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
11541159
def test_saturate_cast(self, dtype):
11551160
x = np.ones((1,))

keras/src/ops/nn_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3137,7 +3137,7 @@ def test_dot_product_attention(self, dtype):
31373137
def test_rms_normalization(self, dtypes):
31383138
input_dtype, weight_dtype = dtypes
31393139
inputs = knp.ones((2, 8), dtype=input_dtype)
3140-
scale = knp.ones((8,), dtype=weight_dtype)
3140+
scale = backend.Variable(knp.ones((8,), dtype=weight_dtype))
31413141
expected_dtype = input_dtype
31423142

31433143
self.assertDType(knn.rms_normalization(inputs, scale), expected_dtype)
@@ -3151,8 +3151,8 @@ def test_rms_normalization(self, dtypes):
31513151
def test_layer_normalization(self, dtypes):
31523152
input_dtype, weight_dtype = dtypes
31533153
inputs = knp.ones((2, 8), dtype=input_dtype)
3154-
gamma = knp.ones((8,), dtype=weight_dtype)
3155-
beta = knp.ones((8,), dtype=weight_dtype)
3154+
gamma = backend.Variable(knp.ones((8,), dtype=weight_dtype))
3155+
beta = backend.Variable(knp.ones((8,), dtype=weight_dtype))
31563156
expected_dtype = input_dtype
31573157

31583158
self.assertDType(

0 commit comments

Comments
 (0)