Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit dd24d4d

Browse files
committed
Add docs and address comments
1 parent df50976 commit dd24d4d

File tree

3 files changed

+42
-25
lines changed

3 files changed

+42
-25
lines changed

docs/source/inference.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,20 @@ Strided indexing with dynamic stride
271271
The value of :code:`S(0)` is not fixed until runtime, so we can't resolve the
272272
size of :code:`A` or the range of the loop. This case throws a compile-time
273273
error. A :code:`where` clause that defines the range of :code:`i` is required.
274+
275+
Constant fill using an exists clause
276+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
277+
278+
.. code::
279+
280+
def constant_fill(float(N) A, float c) -> B {
281+
B(i) = c where A(i) exists
282+
}
283+
284+
An :code:`exists` clause allows you to add additional expressions to the range
285+
inference process without having the expressions affect the actual computation.
286+
In this example, it allows you to say that :code:`B(i)` should have the same size as
287+
:code:`A(i)`, but be filled with a constant value :code:`c`. That is, you should infer the
288+
range of :code:`B(i)` to exist at all the places where :code:`A(i)` exists.
289+
It is equivalent to writing the expression :code:`true ? c : A(i)`, but with
290+
clearer intentions.

include/tc/lang/sema.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ struct Sema {
369369
TreeRef boolType(TreeRef anchor) {
370370
return c(TK_BOOL, anchor->range(), {});
371371
}
372-
void checkDim(const Ident& dim) {
372+
void checkDim(Ident dim) {
373373
insert(env, dim, dimType(dim), false);
374374
}
375375
TreeRef checkTensorType(TreeRef type) {
@@ -451,7 +451,7 @@ struct Sema {
451451
// where clauses are checked _before_ the rhs because they
452452
// introduce let bindings that are in scope for the rhs
453453
auto where_clauses_ = stmt.whereClauses().map(
454-
[&](const TreeRef& rc) { return checkWhereClause(rc); });
454+
[&](TreeRef rc) { return checkWhereClause(rc); });
455455

456456
TreeRef rhs_ = checkExp(stmt.rhs(), true);
457457
TreeRef scalar_type = typeOfExpr(rhs_);
@@ -488,7 +488,7 @@ struct Sema {
488488
live_input_names.erase(stmt.ident().name());
489489

490490
auto equivalent_statement_ =
491-
stmt.equivalent().map([&](const Equivalent& eq) {
491+
stmt.equivalent().map([&](Equivalent eq) {
492492
auto indices_ = eq.accesses().map(
493493
[&](TreeRef index) { return checkExp(index, true); });
494494
return Equivalent::create(eq.range(), eq.name(), indices_);
@@ -526,7 +526,7 @@ struct Sema {
526526

527527
return result;
528528
}
529-
bool isNotInplace(const TreeRef& assignment) {
529+
bool isNotInplace(TreeRef assignment) {
530530
switch (assignment->kind()) {
531531
case TK_PLUS_EQ_B:
532532
case TK_TIMES_EQ_B:
@@ -567,15 +567,15 @@ struct Sema {
567567
throw ErrorReport(ident) << name << " already defined";
568568
}
569569
}
570-
TreeRef lookup(const Ident& ident, bool required) {
570+
TreeRef lookup(Ident ident, bool required) {
571571
TreeRef v = lookup(index_env, ident, false);
572572
if (!v)
573573
v = lookup(let_env, ident, false);
574574
if (!v)
575575
v = lookup(env, ident, required);
576576
return v;
577577
}
578-
TreeRef lookup(Env& the_env, const Ident& ident, bool required) {
578+
TreeRef lookup(Env& the_env, Ident ident, bool required) {
579579
std::string name = ident.name();
580580
auto it = the_env.find(name);
581581
if (required && it == the_env.end()) {

test/test_corner_cases.cc

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -79,73 +79,73 @@ static void Fail(
7979
}
8080
}
8181

82-
TEST(FailTest, E1) {
82+
TEST(TestCornerCases, E1) {
8383
Fail("expected (", " def f{} {}", {}, {});
8484
}
85-
TEST(FailTest, E2) {
85+
TEST(TestCornerCases, E2) {
8686
Succeed("def f(float(1) a) -> (b) { b(i) = a(i) }", {F(1)}, {F(1)});
8787
}
8888

8989
// free(): invalid next size (fast): 0x000000003b2d6230 ***
90-
TEST(FailTest, DISABLED_E4) {
90+
TEST(TestCornerCases, DISABLED_E4) {
9191
Succeed("def f(float a) -> (b) { b = a }", {F()}, {F()});
9292
}
9393

9494
// main conflicts with program main in nvcc
95-
TEST(FailTest, DISABLED_E3) {
95+
TEST(TestCornerCases, DISABLED_E3) {
9696
Succeed(
9797
"def main(float(1) a) -> (b) { b(i) = a(i) }", {F(1)}, {F(1)}, "main");
9898
}
9999

100100
// segfaults on line:
101101
// src/aten/aten_compiler.cc:123
102102
// 123 at::Backend backend = inputs[0].type().backend();
103-
TEST(FailTest, DISABLED_E5) {
103+
TEST(TestCornerCases, DISABLED_E5) {
104104
Succeed("def f() -> (b) { b(i) = 4 where i in 0:10 }", {}, {F(0)});
105105
}
106106

107-
TEST(FailTest, E6) {
107+
TEST(TestCornerCases, E6) {
108108
Succeed("def f(float a) -> (b) { b(i) = a where i in 0:10 }", {F()}, {F(10)});
109109
}
110110

111-
TEST(FailTest, E7) {
111+
TEST(TestCornerCases, E7) {
112112
Fail(
113113
"expected 2 inputs",
114114
"def f(float a, float c) -> (b) { b(i) = a where i in 0:10 }",
115115
{F()},
116116
{F(10)});
117117
}
118118

119-
TEST(FailTest, E8) {
119+
TEST(TestCornerCases, E8) {
120120
Fail(
121121
"expected type int32",
122122
"def f(int32 a) -> (b) { b(i) = a where i in 0:10 }",
123123
{F()},
124124
{F(10)});
125125
}
126126

127-
TEST(FailTest, E9) {
127+
TEST(TestCornerCases, E9) {
128128
Fail(
129129
"expected a tensor with 0",
130130
"def f(int32 a) -> (b) { b(i) = a where i in 0:10 }",
131131
{I(1, 2)},
132132
{F(10)});
133133
}
134134

135-
TEST(FailTest, E10) {
135+
TEST(TestCornerCases, E10) {
136136
Succeed(
137137
"def f(int32 a) -> (b) { b(i) = a where i in 0:10 }", {I()}, {I(10, 10)});
138138
}
139139

140-
TEST(FailTest, E11) {
140+
TEST(TestCornerCases, E11) {
141141
Fail(
142142
"expected integral type",
143143
"def f(int32(N) a) -> (b) { b(i) = a(i + .5) }",
144144
{I()},
145145
{I(10, 10)});
146146
}
147147

148-
TEST(FailTest, E12) {
148+
TEST(TestCornerCases, E12) {
149149
// this test should eventually work when we can handle non-trivial
150150
// expressions in where clauses
151151
Fail(
@@ -155,7 +155,7 @@ TEST(FailTest, E12) {
155155
{I(10)});
156156
}
157157

158-
TEST(FailTest, E13) {
158+
TEST(TestCornerCases, E13) {
159159
// this test is harder still, because the bounds of the output
160160
// depend on the non-trivial expression
161161
Fail(
@@ -165,7 +165,7 @@ TEST(FailTest, E13) {
165165
{I(10)});
166166
}
167167

168-
TEST(FailTest, DISABLED_E14) {
168+
TEST(TestCornerCases, DISABLED_E14) {
169169
// Currently expressions in where clauses are assumed to be
170170
// affine. Needs fixing.
171171
Fail(
@@ -175,7 +175,7 @@ TEST(FailTest, DISABLED_E14) {
175175
{I(10)});
176176
}
177177

178-
TEST(FailTest, E15){
178+
TEST(TestCornerCases, E15){
179179
#define GEN_COMPARATOR(op) \
180180
{ \
181181
auto a = F(); \
@@ -195,7 +195,7 @@ TEST(FailTest, E15){
195195

196196
}
197197

198-
TEST(FailTest, E16){
198+
TEST(TestCornerCases, E16){
199199
#define GEN_BOOLS(op) \
200200
{ \
201201
auto a = F(); \
@@ -213,21 +213,21 @@ TEST(FailTest, E16){
213213

214214
GEN_BOOLS(||) GEN_BOOLS(&&)}
215215

216-
TEST(FailTest, E17) {
216+
TEST(TestCornerCases, E17) {
217217
auto r = F(1);
218218
Succeed(
219219
"def f(float(1) a) -> (b) { b(i) = 4.0 where a(i) exists }", {F(1)}, {r});
220220
CHECK_EQ(at::Scalar(r[0]).toFloat(), 4);
221221
}
222222

223-
TEST(FailTest, E18) {
223+
TEST(TestCornerCases, E18) {
224224
auto a = F(1);
225225
auto r = F(1);
226226
Succeed(
227227
"def f(float(1) a) -> (b) { b(i) = 2*foo where foo = a(i) }", {a}, {r});
228228
CHECK_EQ(at::Scalar(r[0]).toFloat(), at::Scalar(a[0]).toFloat() * 2);
229229
}
230-
TEST(FailTest, E19) {
230+
TEST(TestCornerCases, E19) {
231231
Fail(
232232
"undefined variable",
233233
"def f(float(1) a) -> (b) { b(i) = 2*foo where foo = a(i), foo in 1:2 }",

0 commit comments

Comments
 (0)