Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 57d9b0e

Browse files
Merge pull request #203 from apaszke/tc_format_pr
Add TC format (a pretty printer for AST)
2 parents 8941200 + 13b0513 commit 57d9b0e

File tree

7 files changed

+251
-0
lines changed

7 files changed

+251
-0
lines changed

include/tc/lang/lexer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ enum TokenKind {
9999
#undef DEFINE_TOKEN
100100
};
101101

102+
// Returns a human-readable description of the token
102103
std::string kindToString(int kind);
104+
// Returns the string used by the lexer to match a given token, or throws
105+
// if it can't be produced by the lexer.
106+
std::string kindToToken(int kind);
103107

104108
// nested hash tables that indicate char-by-char what is a valid token.
105109
struct TokenTrie;

include/tc/lang/tc_format.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "tc/lang/tree.h"
19+
20+
#include <ostream>
21+
22+
namespace lang {
23+
24+
/// \file tc_format.h
25+
/// A pretty printer that turns a Def (TC AST) into a valid TC string that
26+
/// could be e.g. re-parsed.
27+
28+
void tcFormat(std::ostream& s, TreeRef def);
29+
30+
} // namespace lang

include/tc/lang/tree_views.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ struct ListView : public TreeView {
161161
size_t size() const {
162162
return tree_->trees().size();
163163
}
164+
bool empty() const {
165+
return size() == 0;
166+
}
164167
static TreeRef create(const SourceRange& range, TreeList elements) {
165168
return Compound::create(TK_LIST, range, std::move(elements));
166169
}

src/lang/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_library(
77

88
parser.cc
99
lexer.cc
10+
tc_format.cc
1011
)
1112

1213
install(

src/lang/lexer.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ std::string kindToString(int kind) {
3232
}
3333
}
3434

35+
std::string kindToToken(int kind) {
36+
if (kind < 256)
37+
return std::string(1, kind);
38+
switch (kind) {
39+
#define DEFINE_CASE(tok, _, str) \
40+
case tok: \
41+
if (str == "") \
42+
throw std::runtime_error("No token for: " + kindToString(kind)); \
43+
return str;
44+
TC_FORALL_TOKEN_KINDS(DEFINE_CASE)
45+
#undef DEFINE_CASE
46+
default:
47+
throw std::runtime_error("unknown kind: " + std::to_string(kind));
48+
}
49+
}
50+
3551
SharedParserData& sharedParserData() {
3652
static SharedParserData data; // safely handles multi-threaded init
3753
return data;

src/lang/tc_format.cc

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/**
2+
* Copyright (c) 2017-present, Facebook, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "tc/lang/tc_format.h"
17+
#include "tc/lang/tree_views.h"
18+
19+
namespace lang {
20+
21+
namespace {
22+
23+
void showExpr(std::ostream& s, const TreeRef& expr);
24+
25+
template <typename T>
26+
void show(std::ostream& s, T x) {
27+
s << x;
28+
}
29+
30+
template <typename T, typename F>
31+
void showList(std::ostream& s, const ListView<T>& list, F elem_cb) {
32+
bool first = true;
33+
for (const auto& elem : list) {
34+
if (!first) {
35+
s << ", ";
36+
}
37+
elem_cb(s, elem);
38+
first = false;
39+
}
40+
}
41+
42+
template <typename T>
43+
std::ostream& operator<<(std::ostream& s, const ListView<T>& list) {
44+
showList(s, list, show<T>);
45+
return s;
46+
}
47+
48+
std::ostream& operator<<(std::ostream& s, const Ident& id) {
49+
return s << id.name();
50+
}
51+
52+
std::ostream& operator<<(std::ostream& s, const Param& p) {
53+
if (!p.typeIsInferred()) {
54+
TensorType type{p.type()};
55+
s << kindToString(type.scalarType()) << "(";
56+
showList(s, type.dims(), showExpr);
57+
s << ") ";
58+
}
59+
return s << p.ident();
60+
}
61+
62+
std::ostream& operator<<(std::ostream& s, const Comprehension& comp) {
63+
s << comp.ident() << "(" << comp.indices() << ") "
64+
<< kindToToken(comp.assignment()->kind()) << " ";
65+
showExpr(s, comp.rhs());
66+
if (!comp.whereClauses().empty())
67+
throw std::runtime_error("Printing of where clauses is not supported yet");
68+
if (comp.equivalent().present())
69+
throw std::runtime_error(
70+
"Printing of equivalent comprehensions is not supported yet");
71+
return s;
72+
}
73+
74+
void showExpr(std::ostream& s, const TreeRef& expr) {
75+
switch (expr->kind()) {
76+
case TK_IDENT: {
77+
s << Ident(expr);
78+
break;
79+
}
80+
case TK_AND:
81+
case TK_OR:
82+
case '<':
83+
case '>':
84+
case TK_EQ:
85+
case TK_LE:
86+
case TK_GE:
87+
case TK_NE:
88+
case '+':
89+
case '*':
90+
case '/': {
91+
s << "(";
92+
showExpr(s, expr->tree(0));
93+
s << " " << kindToToken(expr->kind()) << " ";
94+
showExpr(s, expr->tree(1));
95+
s << ")";
96+
break;
97+
// '-' is annoying because it can be both unary and binary
98+
}
99+
case '-': {
100+
if (expr->trees().size() == 1) {
101+
s << "-";
102+
showExpr(s, expr->tree(0));
103+
} else {
104+
s << "(";
105+
showExpr(s, expr->tree(0));
106+
s << " - ";
107+
showExpr(s, expr->tree(1));
108+
s << ")";
109+
}
110+
break;
111+
}
112+
case '!': {
113+
s << "!";
114+
showExpr(s, expr->tree(0));
115+
break;
116+
}
117+
case TK_CONST: {
118+
Const con{expr};
119+
int scalarType = con.type()->kind();
120+
switch (con.type()->kind()) {
121+
case TK_FLOAT:
122+
case TK_DOUBLE:
123+
s << con.value();
124+
break;
125+
case TK_UINT8:
126+
case TK_UINT16:
127+
case TK_UINT32:
128+
case TK_UINT64:
129+
s << static_cast<uint64_t>(con.value());
130+
break;
131+
case TK_INT8:
132+
case TK_INT16:
133+
case TK_INT32:
134+
case TK_INT64:
135+
s << static_cast<int64_t>(con.value());
136+
break;
137+
default:
138+
throw std::runtime_error(
139+
"Unknown scalar type in const: " +
140+
kindToString(con.type()->kind()));
141+
}
142+
break;
143+
}
144+
case TK_CAST: {
145+
Cast cast{expr};
146+
s << kindToToken(cast.type()->kind()) << "(";
147+
showExpr(s, cast.value());
148+
s << ")";
149+
break;
150+
}
151+
case '.': {
152+
Select sel{expr};
153+
s << sel.name() << "." << sel.index();
154+
break;
155+
}
156+
case TK_APPLY:
157+
case TK_ACCESS:
158+
case TK_BUILT_IN: {
159+
s << Ident(expr->tree(0)) << "(";
160+
showList(s, ListView<TreeRef>(expr->tree(1)), showExpr);
161+
s << ")";
162+
break;
163+
}
164+
default: {
165+
throw std::runtime_error(
166+
"Unexpected kind in showExpr: " + kindToString(expr->kind()));
167+
}
168+
}
169+
}
170+
171+
} // anonymous namespace
172+
173+
void tcFormat(std::ostream& s, TreeRef _def) {
174+
Def def{_def};
175+
s << "def " << def.name() << "(" << def.params() << ")"
176+
<< " -> (" << def.returns() << ") {\n";
177+
for (const Comprehension& c : def.statements()) {
178+
s << " " << c << "\n";
179+
}
180+
s << "}";
181+
}
182+
183+
} // namespace lang

test/test_lang.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "tc/lang/canonicalize.h"
2626
#include "tc/lang/parser.h"
2727
#include "tc/lang/sema.h"
28+
#include "tc/lang/tc_format.h"
2829

2930
using namespace lang;
3031

@@ -147,6 +148,17 @@ std::string canonicalText(const std::string& text) {
147148
return ss.str();
148149
}
149150

151+
// Round-trip check for tcFormat: parses a TC function from `source`, formats
// the resulting tree into a string stream, and asserts that the formatted
// text is identical to the original source text.
// NOTE(review): the static std::ios_base::Init does not look necessary for
// writing to a std::ostringstream -- confirm why it is needed.
void testTcFormat() {
152+
static std::ios_base::Init initIostreams;
153+
auto source = R"(def fun2(float(B, N, M) X, float(B, M, K) Y) -> (Q) {
154+
Q(b, ii, j) += (((exp(X(b, ii, k)) * int(Y(b, k, j))) * 2.5) + 3)
155+
})";
156+
auto def_tree = Parser(source).parseFunction();
157+
std::ostringstream s;
158+
tcFormat(s, def_tree);
159+
ASSERT(s.str() == source);
160+
}
161+
150162
int main(int argc, char** argv) {
151163
std::vector<std::string> args;
152164
for (int i = 1; i < argc; i++) {
@@ -319,6 +331,8 @@ int main(int argc, char** argv) {
319331
)";
320332
ASSERT(canonicalText(option_one) == canonicalText(option_two));
321333

334+
testTcFormat();
335+
322336
// assertSemaEqual(
323337
// "comments.expected",
324338
// R"(#beginning comment

0 commit comments

Comments
 (0)