Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 063597c

Browse files
abadamsftynse
authored andcommitted
Allow computed expressions on the left-hand-side
1 parent 278f174 commit 063597c

File tree

7 files changed

+155
-41
lines changed

7 files changed

+155
-41
lines changed

tc/core/tc2halide.cc

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ Expr translateExpr(
216216
}
217217
}
218218

219-
vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
219+
vector<const Variable*> unboundVariables(const vector<Expr>& lhs, Expr rhs) {
220220
class FindUnboundVariables : public IRVisitor {
221221
using IRVisitor::visit;
222222

@@ -241,14 +241,19 @@ vector<const Variable*> unboundVariables(const vector<Var>& lhs, Expr rhs) {
241241
set<string> visited;
242242

243243
public:
244-
FindUnboundVariables(const vector<Var>& lhs) {
245-
for (auto v : lhs) {
246-
bound.push(v.name());
244+
FindUnboundVariables(const vector<Expr>& lhs) {
245+
for (auto e : lhs) {
246+
if (const Variable* v = e.as<Variable>()) {
247+
bound.push(v->name);
248+
}
247249
}
248250
}
249251
vector<const Variable*> result;
250252
} finder(lhs);
251253
rhs.accept(&finder);
254+
for (auto e : lhs) {
255+
e.accept(&finder);
256+
}
252257
return finder.result;
253258
}
254259

@@ -507,22 +512,31 @@ void translateComprehension(
507512
f = Function(c.ident().name());
508513
(*funcs)[c.ident().name()] = f;
509514
}
515+
516+
// we currently inline all of the let bindings generated in where clauses
517+
// in the future we may consider using Halide Let bindings when they
518+
// are supported later
519+
map<string, Expr> lets;
520+
510521
// Function is the internal Halide IR type for a pipeline
511522
// stage. Func is the front-end class that wraps it. Here it's
512523
// convenient to use both.
513524
Func func(f);
514525

515-
vector<Var> lhs;
516-
vector<Expr> lhs_as_exprs;
517-
for (lang::Ident id : c.indices()) {
518-
lhs.push_back(Var(id.name()));
519-
lhs_as_exprs.push_back(lhs.back());
526+
vector<Expr> lhs;
527+
vector<Var> lhs_vars;
528+
bool total_definition = true;
529+
for (lang::TreeRef idx : c.indices()) {
530+
Expr e = translateExpr(idx, params, *funcs, lets);
531+
if (const Variable* op = e.as<Variable>()) {
532+
lhs_vars.push_back(Var(op->name));
533+
} else {
534+
total_definition = false;
535+
lhs_vars.push_back(Var());
536+
}
537+
lhs.push_back(e);
520538
}
521539

522-
// we currently inline all of the let bindings generated in where clauses
523-
// in the future we may consider using Halide Let bindings when they
524-
// are supported later
525-
map<string, Expr> lets;
526540
for (auto wc : c.whereClauses()) {
527541
if (wc->kind() == lang::TK_LET) {
528542
auto let = lang::Let(wc);
@@ -546,9 +560,8 @@ void translateComprehension(
546560
auto setupIdentity = [&](const Expr& identity, bool zero) {
547561
if (!f.has_pure_definition()) {
548562
added_implicit_initialization = true;
549-
func(lhs) = (zero) ? identity
550-
: undef(rhs.type()); // undef causes the original value
551-
// to remain in input arrays
563+
// undef causes the original value to remain in input arrays
564+
func(lhs_vars) = (zero) ? identity : undef(rhs.type());
552565
}
553566
};
554567

@@ -587,6 +600,9 @@ void translateComprehension(
587600
break;
588601

589602
case '=':
603+
if (!total_definition) {
604+
setupIdentity(rhs, false);
605+
}
590606
break;
591607
default:
592608
throw lang::ErrorReport(c) << "Unimplemented reduction "
@@ -618,9 +634,10 @@ void translateComprehension(
618634
for (auto& exp : all_exprs) {
619635
exp = bindParams.mutate(exp);
620636
}
621-
622-
// TODO: When the LHS incorporates general expressions we'll need to
623-
// bind params there too.
637+
for (auto& e : lhs) {
638+
e = bindParams.mutate(e);
639+
all_exprs.push_back(e);
640+
}
624641

625642
// Do forward bounds inference -- construct an expression that says
626643
// this expression never reads out of bounds on its inputs, and
@@ -660,19 +677,34 @@ void translateComprehension(
660677
// (e.g. an in-place stencil)?. The .bound directive will use the
661678
// bounds of the last stage for all stages.
662679

663-
// Does a tensor have a single bound, or can its bounds shrink over
664-
// time? Solve for a single bound for now.
680+
// Set the bounds to be the union of the boxes written to by every
681+
// comprehension touching the tensor.
682+
for (size_t i = 0; i < lhs.size(); i++) {
683+
Expr e = lhs[i];
684+
if (const Variable* v = e.as<Variable>()) {
685+
if (!solution.contains(v->name)) {
686+
throw lang::ErrorReport(c)
687+
<< "Free variable " << v
688+
<< " was not solved in range inference. May not be used right-hand side";
689+
}
690+
}
665691

666-
for (Var v : lhs) {
667-
if (!solution.contains(v.name())) {
668-
throw lang::ErrorReport(c)
669-
<< "Free variable " << v
670-
<< " was not solved in range inference. May not be used right-hand side";
692+
Interval in = bounds_of_expr_in_scope(e, solution);
693+
if (!in.is_bounded()) {
694+
throw lang::ErrorReport(c.indices()[i])
695+
<< "Left-hand side expression is unbounded";
671696
}
672-
// TODO: We're enforcing a single bound across all comprehensions
673-
// for now. We should really check later ones are equal to earlier
674-
// ones instead of just clobbering.
675-
(*bounds)[f][v.name()] = solution.get(v.name());
697+
in.min = cast<int>(in.min);
698+
in.max = cast<int>(in.max);
699+
700+
map<string, Interval>& b = (*bounds)[f];
701+
string dim_name = f.dimensions() ? f.args()[i] : lhs_vars[i].name();
702+
auto old = b.find(dim_name);
703+
if (old != b.end()) {
704+
// Take the union with any existing bounds
705+
in.include(old->second);
706+
}
707+
b[dim_name] = in;
676708
}
677709

678710
// Free variables that appear on the rhs but not the lhs are
@@ -703,6 +735,9 @@ void translateComprehension(
703735
for (auto v : unbound) {
704736
Expr rv = Variable::make(Int(32), v->name, domain);
705737
rhs = substitute(v->name, rv, rhs);
738+
for (Expr& e : lhs) {
739+
e = substitute(v->name, rv, e);
740+
}
706741
}
707742
rdom = RDom(domain);
708743
}
@@ -718,9 +753,12 @@ void translateComprehension(
718753
}
719754
}
720755
while (!lhs.empty()) {
721-
loop_nest.push_back(lhs.back());
756+
if (const Variable* v = lhs.back().as<Variable>()) {
757+
loop_nest.push_back(Var(v->name));
758+
}
722759
lhs.pop_back();
723760
}
761+
stage.reorder(loop_nest);
724762

725763
if (added_implicit_initialization) {
726764
// Also reorder reduction initializations to the TC convention
@@ -734,7 +772,6 @@ void translateComprehension(
734772
}
735773

736774
func.compute_root();
737-
stage.reorder(loop_nest);
738775
}
739776

740777
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {

tc/lang/parser.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,15 @@ struct Parser {
151151
TreeRef parseExpList() {
152152
return parseList('(', ',', ')', [&](int i) { return parseExp(); });
153153
}
154+
TreeRef parseOptionalExpList() {
155+
TreeRef list = nullptr;
156+
if (L.cur().kind == '(') {
157+
list = parseExpList();
158+
} else {
159+
list = List::create(L.cur().range, {});
160+
}
161+
return list;
162+
}
154163
TreeRef parseIdentList() {
155164
return parseList('(', ',', ')', [&](int i) { return parseIdent(); });
156165
}
@@ -226,7 +235,7 @@ struct Parser {
226235
}
227236
TreeRef parseStmt() {
228237
auto ident = parseIdent();
229-
TreeRef list = parseOptionalIdentList();
238+
TreeRef list = parseOptionalExpList();
230239
auto assign = parseAssignment();
231240
auto rhs = parseExp();
232241
TreeRef equivalent_statement = parseEquivalent();

tc/lang/sema.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -442,9 +442,11 @@ struct Sema {
442442

443443
// register index variables (non-reductions)
444444
for (const auto& index : stmt.indices()) {
445-
std::string idx = index.name();
446-
auto typ = indexType(index);
447-
insert(index_env, index, typ, true);
445+
if (index->kind() == TK_IDENT) {
446+
std::string idx = Ident(index).name();
447+
auto typ = indexType(index);
448+
insert(index_env, Ident(index), typ, true);
449+
}
448450
}
449451

450452
// make dimension variables for each dimension of the output tensor
@@ -465,6 +467,9 @@ struct Sema {
465467
auto where_clauses_ = stmt.whereClauses().map(
466468
[&](TreeRef rc) { return checkWhereClause(rc); });
467469

470+
auto indices_ =
471+
stmt.indices().map([&](TreeRef idx) { return checkExp(idx, true); });
472+
468473
TreeRef rhs_ = checkExp(stmt.rhs(), true);
469474
TreeRef scalar_type = typeOfExpr(rhs_);
470475

@@ -525,7 +530,7 @@ struct Sema {
525530
TreeRef result = Comprehension::create(
526531
stmt.range(),
527532
stmt.ident(),
528-
stmt.indices(),
533+
indices_,
529534
stmt.assignment(),
530535
rhs_,
531536
where_clauses_,

tc/lang/tc_format.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ std::ostream& operator<<(std::ostream& s, const Param& p) {
6060
}
6161

6262
std::ostream& operator<<(std::ostream& s, const Comprehension& comp) {
63-
s << comp.ident() << "(" << comp.indices() << ") "
64-
<< kindToToken(comp.assignment()->kind()) << " ";
63+
s << comp.ident() << "(";
64+
showList(s, comp.indices(), showExpr);
65+
s << ") " << kindToToken(comp.assignment()->kind()) << " ";
6566
showExpr(s, comp.rhs());
6667
if (!comp.whereClauses().empty())
6768
throw std::runtime_error("Printing of where clauses is not supported yet");

tc/lang/tree_views.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,8 @@ struct Comprehension : public TreeView {
386386
Ident ident() const {
387387
return Ident(subtree(0));
388388
}
389-
ListView<Ident> indices() const {
390-
return ListView<Ident>(subtree(1));
389+
ListView<TreeRef> indices() const {
390+
return ListView<TreeRef>(subtree(1));
391391
}
392392
// kind == '=', TK_PLUS_EQ, TK_PLUS_EQ_B, etc.
393393
TreeRef assignment() const {

test/cuda/test_execution_engine.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,25 @@ def concat(float(M, N) A, float(M, N) B) -> (O1) {
145145
outputs);
146146
}
147147

148+
TEST_F(ATenCompilationUnitTest, Concat2) {
149+
at::Tensor a = at::CUDA(at::kFloat).rand({32, 16});
150+
at::Tensor b = at::CUDA(at::kFloat).rand({32, 16});
151+
std::vector<at::Tensor> inputs = {a, b};
152+
std::vector<at::Tensor> outputs;
153+
154+
Check(
155+
R"(
156+
def concat(float(M, N) A, float(M, N) B) -> (O1) {
157+
O1(n, 0, m) = A(m, n)
158+
O1(n, 1, m) = B(m, n)
159+
}
160+
)",
161+
"concat",
162+
tc::CudaMappingOptions::makeNaiveCudaMappingOptions(),
163+
inputs,
164+
outputs);
165+
}
166+
148167
TEST_F(ATenCompilationUnitTest, Indexing) {
149168
at::Tensor a = at::CUDA(at::kFloat).rand({3, 4});
150169
at::Tensor b = at::CUDA(at::kInt).ones({2});

test/cuda/test_tc_mapper.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,49 @@ def fun(float(B, R) LUT, int32(B, N) I) -> (O) {
352352
checkFun);
353353
}
354354

355+
TEST_F(TcCudaMapperTest, Histogram) {
356+
const int N = 17, M = 82;
357+
at::Tensor I =
358+
at::CUDA(at::kFloat).rand({N, M}).mul_(256).floor_().toType(at::kByte);
359+
std::vector<at::Tensor> inputs = {I};
360+
std::vector<at::Tensor> outputs;
361+
362+
static constexpr auto TC = R"TC(
363+
def fun(uint8(N, M) I) -> (O) {
364+
O(I(i, j)) +=! 1
365+
}
366+
)TC";
367+
368+
auto checkFun = [=](const std::vector<at::Tensor>& inputs,
369+
std::vector<at::Tensor>& outputs) {
370+
at::Tensor I = inputs[0].toBackend(at::kCPU);
371+
at::Tensor O = outputs[0].toBackend(at::kCPU);
372+
auto IAccessor = I.accessor<uint8_t, 2>();
373+
auto OAccessor = O.accessor<int, 1>();
374+
int sum = 0;
375+
for (int i = 0; i < 256; i++) {
376+
sum += OAccessor[i];
377+
}
378+
CHECK_EQ(sum, N * M);
379+
380+
for (int i = 0; i < N; i++) {
381+
for (int j = 0; j < M; j++) {
382+
OAccessor[IAccessor[i][j]]--;
383+
}
384+
}
385+
386+
for (int i = 0; i < 256; i++) {
387+
CHECK_EQ(OAccessor[i], 0);
388+
}
389+
};
390+
Check(
391+
TC,
392+
"fun",
393+
tc::CudaMappingOptions::makeNaiveCudaMappingOptions(),
394+
inputs,
395+
checkFun);
396+
}
397+
355398
///////////////////////////////////////////////////////////////////////////////
356399
// SpatialBatchNormalization
357400
///////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)