diff --git a/include/tensor.cuh b/include/tensor.cuh index 84a7e3f..9e1ace7 100644 --- a/include/tensor.cuh +++ b/include/tensor.cuh @@ -515,7 +515,7 @@ DTensor::DTensor(const DTensor &other, size_t axis, size_t from, size_t to m_numCols = other.m_numCols; m_numMats = len; } else if (axis == 1) { - offset = other.m_numCols * from; + offset = other.m_numRows * from; m_numRows = other.m_numRows; m_numCols = len; m_numMats = 1; diff --git a/test/testTensor.cu b/test/testTensor.cu index 8d3c58e..3d1a7b4 100644 --- a/test/testTensor.cu +++ b/test/testTensor.cu @@ -215,7 +215,7 @@ void tensorSlicingConstructorAxis1() { EXPECT_EQ(2, tenzSlice.numRows()); EXPECT_EQ(2, tenzSlice.numCols()); EXPECT_EQ(1, tenzSlice.numMats()); - std::vector expected = {4, 5, 6, 7}; + std::vector expected = {3, 4, 5, 6}; std::vector tenzSliceDown(4); tenzSlice.download(tenzSliceDown); EXPECT_EQ(expected, tenzSliceDown);