Skip to content

Commit 61d7877

Browse files
authored
Fix view scalar bug segfault (#1603)
* fix view scalar bug * fix view scalar bug * one more fix
1 parent 5e89aac commit 61d7877

File tree

4 files changed

+7
-3
lines changed

4 files changed

+7
-3
lines changed

mlx/backend/common/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ void View::eval_cpu(const std::vector<array>& inputs, array& out) {
606606
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
607607
in.flags().row_contiguous) {
608608
auto strides = in.strides();
609-
for (int i = 0; i < strides.size() - 1; ++i) {
609+
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
610610
strides[i] *= ibytes;
611611
strides[i] /= obytes;
612612
}

mlx/backend/metal/primitives.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ void View::eval_gpu(const std::vector<array>& inputs, array& out) {
417417
if (ibytes == obytes || obytes < ibytes && in.strides().back() == 1 ||
418418
in.flags().row_contiguous) {
419419
auto strides = in.strides();
420-
for (int i = 0; i < strides.size() - 1; ++i) {
420+
for (int i = 0; i < static_cast<int>(strides.size()) - 1; ++i) {
421421
strides[i] *= ibytes;
422422
strides[i] /= obytes;
423423
}

mlx/ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4587,7 +4587,7 @@ array view(const array& a, const Dtype& dtype, StreamOrDevice s /* = {} */) {
45874587
" axis must be a multiple of the requested type size.");
45884588
}
45894589
out_shape.back() /= (obytes / ibytes);
4590-
} else {
4590+
} else if (ibytes > obytes) {
45914591
// Type size ratios are always integers
45924592
out_shape.back() *= (ibytes / obytes);
45934593
}

python/tests/test_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2532,6 +2532,10 @@ def test_conjugate(self):
25322532
self.assertTrue(np.array_equal(np.array(out_mlx), out_np))
25332533

25342534
def test_view(self):
2535+
# Check scalar
2536+
out = mx.array(1, mx.int8).view(mx.uint8).item()
2537+
self.assertEqual(out, 1)
2538+
25352539
a = mx.random.randint(shape=(4, 2, 4), low=-100, high=100)
25362540
a_np = np.array(a)
25372541

0 commit comments

Comments
 (0)