Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 731c536

Browse files
author
Protonu Basu
committed
Add checks to validate strides before compilation
This commit is motivated by the need to implement checks to verify correct strides. Adding a new unit test with incorrect strides which looks for a new exception. Added overloaded functions to print out information from DLTensor, DLConstTensor, TensorInfo. Added a helper function to throw an invalid stride exception.
1 parent f084cfe commit 731c536

File tree

6 files changed

+134
-3
lines changed

6 files changed

+134
-3
lines changed

tc/core/compiler.cc

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <sstream>
1919
#include <string>
2020

21+
#include "tc/core/exceptions.h"
2122
#include "tc/core/flags.h"
2223
#include "tc/core/halide_utils.h"
2324
#include "tc/core/tensor.h"
@@ -35,6 +36,19 @@ std::vector<TensorInfo> inferOutputTensorInfo(
3536
}
3637

3738
namespace detail {
39+
40+
inline void helpThrowInvalidStride(
41+
const lang::TreeRef param,
42+
const DLConstTensor* inputsInfo) {
43+
lang::ErrorReport ep = lang::ErrorReport(param);
44+
std::ostringstream err_c;
45+
err_c << ep.what() << "compilation aborted: invalid strides in tensor "
46+
<< std::endl;
47+
err_c << *(inputsInfo);
48+
err_c << std::endl;
49+
throw InvalidStrideException(err_c.str());
50+
}
51+
3852
void checkInputsCompliant(
3953
const tc2halide::HalideComponents& halideComponents,
4054
const std::vector<const DLConstTensor*>& inputsInfo) {
@@ -67,6 +81,20 @@ void checkInputsCompliant(
6781
<< "expected a tensor with " << hdim << " dimensions but found "
6882
<< dldim << " dimensions.";
6983
}
84+
auto dlstrides = inputsInfo[i]->strides;
85+
auto dlsizes = inputsInfo[i]->shape;
86+
if (dldim) {
87+
if (dlstrides[dldim - 1] < 1) {
88+
helpThrowInvalidStride(
89+
halideComponents.getDef().params()[i], inputsInfo[i]);
90+
}
91+
for (size_t j = 0; j < dldim - 1; ++j) {
92+
if (dlstrides[j] < dlstrides[j + 1] * dlsizes[j + 1]) {
93+
helpThrowInvalidStride(
94+
halideComponents.getDef().params()[i], inputsInfo[i]);
95+
}
96+
}
97+
}
7098
}
7199
}
72100

tc/core/compiler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
117117
std::vector<TensorInfo> inferOutputTensorInfo(
118118
lang::TreeRef tcDefinition,
119119
const std::vector<const DLConstTensor*> inputs);
120+
120121
} // namespace detail
121122
} // namespace tc
122123

tc/core/exceptions.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include <stdexcept>
19+
#include <string>
20+
21+
namespace tc {
22+
23+
struct InvalidStrideException : public std::runtime_error {
24+
explicit InvalidStrideException(const std::string& s)
25+
: std::runtime_error(s) {}
26+
};
27+
28+
} // namespace tc

tc/core/tensor.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,48 @@ bool TensorInfo::operator!=(const TensorInfo& t) const {
107107
return !(*this == t);
108108
}
109109

110+
std::ostream& operator<<(std::ostream& os, const TensorInfo& t) {
111+
auto ndim = t.shape.size();
112+
if (ndim == 0) {
113+
return os;
114+
}
115+
os << "Dimensions: [";
116+
if (ndim) {
117+
for (size_t i = 0; i < ndim - 1; ++i) {
118+
os << t.shape[i] << ", ";
119+
}
120+
os << t.shape[ndim - 1] << "]";
121+
os << " Strides: [";
122+
if (t.strides.size()) {
123+
for (size_t i = 0; i < ndim - 1; ++i) {
124+
os << t.strides[i] << ", ";
125+
}
126+
os << t.strides[ndim - 1] << "] ";
127+
}
128+
}
129+
os << "Type: ";
130+
os << toString(t.dtype) << " ";
131+
return os;
132+
}
133+
134+
std::ostream& operator<<(std::ostream& os, const DLTensor& t) {
135+
auto t_info = TensorInfo(&t);
136+
os << t_info;
137+
os << "Byte Offset: ";
138+
os << t.byte_offset << " ";
139+
os << "Base Pointer: " << t.data << " ";
140+
return os;
141+
}
142+
143+
std::ostream& operator<<(std::ostream& os, const DLConstTensor& t) {
144+
auto t_info = TensorInfo(&t);
145+
os << t_info;
146+
os << "Byte Offset: ";
147+
os << t.byte_offset << " ";
148+
os << "Base Pointer: " << t.data << " ";
149+
return os;
150+
}
151+
110152
std::vector<TensorInfo> makeTensorInfoVector(
111153
const google::protobuf::RepeatedPtrField<TensorInfoProto>& buf) {
112154
std::vector<TensorInfo> iis;

tc/core/tensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ bool operator==(const DLDataType& t1, const DLDataType& t2);
127127
std::string toString(const DLDataType& t);
128128
std::ostream& operator<<(std::ostream& os, const DLDataType& t);
129129
std::ostream& operator<<(std::ostream& os, const DLTensor& t);
130+
std::ostream& operator<<(std::ostream& os, const TensorInfo& t);
131+
std::ostream& operator<<(std::ostream& os, const DLConstTensor& t);
130132

131133
// Basic metadata-owning DLTensor, only copies the underlying raw pointer.
132134
DLTensorUPtr makeDLTensor(const DLTensor* ptr);

test/cuda/test_tc_mapper.cc

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tc/aten/aten_compiler.h"
2323
#include "tc/core/cuda/cuda.h"
2424
#include "tc/core/cuda/cuda_tc_executor.h"
25+
#include "tc/core/exceptions.h"
2526
#include "tc/core/scope_guard.h"
2627
#include "tc/lang/canonicalize.h"
2728
#include "tc/lang/sema.h"
@@ -306,10 +307,10 @@ TEST_F(TcCudaMapperTest, TensorAddStrided) {
306307
M = 64;
307308
at::Tensor I0 = at::CUDA(at::kFloat).rand({N, M});
308309
at::Tensor I0_view =
309-
I0.type().tensor().set_(*I0.storage(), 0, {N, M}, {1, 16});
310+
I0.type().tensor().set_(*I0.storage(), 0, {N, M}, {128, 1});
310311
at::Tensor I1 = at::CUDA(at::kFloat).rand({N, M});
311312
at::Tensor I1_view =
312-
I1.type().tensor().set_(*I1.storage(), 0, {N, M}, {1, 16});
313+
I1.type().tensor().set_(*I1.storage(), 0, {N, M}, {128, 1});
313314
std::vector<at::Tensor> inputs = {I0_view, I1_view};
314315

315316
static constexpr auto TC = R"TC(
@@ -327,12 +328,41 @@ def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
327328
std::string expected =
328329
"const float32 (*I0_view)[64] = "
329330
"reinterpret_cast<const float32 (*)[64]>(pI0_view)";
330-
331331
ASSERT_NE(std::string::npos, res.second.find(expected))
332332
<< "In resulting code:\n"
333333
<< res.second << "\nfound unexpected: " << expected;
334334
}
335335

336+
///////////////////////////////////////////////////////////////////////////////
337+
// TensorAddInvalidStrides
338+
// O(n, m) += I0_view(n, m) * I1_view(n, m)
339+
///////////////////////////////////////////////////////////////////////////////
340+
TEST_F(TcCudaMapperTest, TensorAddInvalidStrides) {
341+
N = 64;
342+
M = 64;
343+
at::Tensor I0 = at::CUDA(at::kFloat).rand({N, M});
344+
at::Tensor I0_view =
345+
I0.type().tensor().set_(*I0.storage(), 0, {N, M}, {16, 1});
346+
at::Tensor I1 = at::CUDA(at::kFloat).rand({N, M});
347+
at::Tensor I1_view =
348+
I1.type().tensor().set_(*I1.storage(), 0, {N, M}, {16, 1});
349+
std::vector<at::Tensor> inputs = {I0_view, I1_view};
350+
351+
static constexpr auto TC = R"TC(
352+
def tensoraddstrided(float(N, M) I0_view, float(N, M) I1_view) -> (O) {
353+
O(n, m) += I0_view(n, m) + I1_view(n, m)
354+
}
355+
)TC";
356+
357+
auto checkFun = [](const std::vector<at::Tensor>& ins,
358+
std::vector<at::Tensor>& outs) { return true; };
359+
auto options = tc::CudaMappingOptions::makeNaiveMappingOptions();
360+
auto name = "tensoraddstrided";
361+
362+
EXPECT_THROW(
363+
Check(TC, name, options, inputs, checkFun), tc::InvalidStrideException);
364+
}
365+
336366
///////////////////////////////////////////////////////////////////////////////
337367
// Lookup Table
338368
// O(b, n) +=! LUT(I(b, n), r_r)

0 commit comments

Comments
 (0)