@@ -15,19 +15,19 @@ Tensor Comprehension Notation
-----------------------------

TC borrows three ideas from Einstein notation that make expressions concise:

- 1. loop index variables are defined implicitly by using them in an expression and their range is aggressively inferred based on what they index,
- 2. indices that appear on the right of an expression but not on the left are assumed to be reduction dimensions,
- 3. the evaluation order of points in the iteration space does not affect the output.
+ 1. Loop index variables are defined implicitly by using them in an expression and their range is aggressively inferred based on what they index.
+ 2. Indices that appear on the right of an expression but not on the left are assumed to be reduction dimensions.
+ 3. The evaluation order of points in the iteration space does not affect the output.

Let's start with a simple example: a matrix-vector product:

- def mv(float(R,C) A, float(C) x) -> (o) {
-   o(i) +=! A(i,j) * b(j)
+ def mv(float(R,C) A, float(C) B) -> (o) {
+   o(i) +=! A(i,j) * B(j)
}

`A` and `B` are input tensors. `o` is an output tensor.
- The statement `o(i) += A(i,j) * b(j)` introduces two index variables `i` and `j`.
- Their range is inferred by their use indexing `A` and `b`. `i = [0,R)`, `j = [0,C)`.
+ The statement `o(i) += A(i,j) * B(j)` introduces two index variables `i` and `j`.
+ Their ranges are inferred from their use in indexing `A` and `B`: `i = [0,R)`, `j = [0,C)`.
Because `j` only appears on the right side,
stores into `o` will reduce over `j` with the reduction specified for the loop.
Reductions can occur across multiple variables, but they all share the same kind of associative reduction (e.g. `+=`)
@@ -36,7 +36,7 @@ to maintain invariant (3). `mv` computes the same thing as this C++ loop:
for(int i = 0; i < R; i++) {
  o(i) = 0.0f;
  for(int j = 0; j < C; j++) {
-     o(i) += A(i,j) * b(j);
+     o(i) += A(i,j) * B(j);
  }
}

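For readers who want to check the reduction semantics interactively, here is a minimal plain-Python sketch of the same `mv` computation (the function name and the list-of-lists tensor representation are illustrative assumptions, not part of TC):

```python
# Plain-Python sketch of mv: o(i) +=! A(i,j) * B(j).
# j appears only on the right-hand side, so it is reduced (summed) away;
# the ranges of i and j are taken from the shapes of A, mirroring TC's inference.
def mv(A, x):
    R, C = len(A), len(A[0])
    return [sum(A[i][j] * x[j] for j in range(C)) for i in range(R)]

result = mv([[1.0, 2.0], [3.0, 4.0]], [10.0, 100.0])
print(result)  # [210.0, 430.0]
```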
@@ -47,30 +47,33 @@ Examples of TC
We provide a few basic examples.

- Simple matrix-vector:
+ **Simple matrix-vector**:

- def mv(float(R,C) A, float(C) x) -> (o) {
-   o(i) += A(i,j) * b(j)
+ def mv(float(R,C) A, float(C) B) -> (o) {
+   o(i) += A(i,j) * B(j)
}

- Simple matrix-multiply (note the layout for B is transposed and matches the
- traditional layout of the weight matrix in a linear layer):
+ **Simple matrix-multiply:**
+
+ Note the layout for B is transposed and matches the
+ traditional layout of the weight matrix in a linear layer:

def mm(float(X,Y) A, float(Z,Y) B) -> (R) {
  R(i,k) += A(i,j) * B(k,j)
}
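Assuming the intended semantics is a matrix product with `B` stored transposed (the weight-matrix layout the note describes), a plain-Python reference might look like this (function name and list representation are illustrative):

```python
# Plain-Python sketch of mm with B stored transposed: B has shape (Z, Y)
# and j is the reduction index, i.e. R(i,k) += A(i,j) * B(k,j).
def mm(A, B):
    X, Y, Z = len(A), len(A[0]), len(B)
    return [[sum(A[i][j] * B[k][j] for j in range(Y)) for k in range(Z)]
            for i in range(X)]

out = mm([[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]])
print(out)  # [[11.0, 17.0]]
```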

- Simple 2-D convolution (no stride, no padding):
+ **Simple 2-D convolution (no stride, no padding):**

def conv(float(B,IP,H,W) input, float(OP,IP,KH,KW) weight) -> (output) {
  output(b, op, h, w) += input(b, ip, h + kh, w + kw) * weight(op, ip, kh, kw)
}
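A naive loop nest makes the inferred ranges explicit: `h` and `w` run over the valid output positions, while `ip`, `kh`, and `kw` appear only on the right and are therefore reduced. A plain-Python sketch (nested lists standing in for 4-D tensors; illustrative, not the TC implementation):

```python
# Naive reference loops for the conv comprehension above (no stride, no padding).
def conv(inp, weight):
    B, IP = len(inp), len(inp[0])
    H, W = len(inp[0][0]), len(inp[0][0][0])
    OP, KH, KW = len(weight), len(weight[0][0]), len(weight[0][0][0])
    OH, OW = H - KH + 1, W - KW + 1  # inferred ranges of h and w
    out = [[[[0.0] * OW for _ in range(OH)] for _ in range(OP)] for _ in range(B)]
    for b in range(B):
        for op in range(OP):
            for h in range(OH):
                for w in range(OW):
                    # reduction over ip, kh, kw (right-hand-side-only indices)
                    out[b][op][h][w] = sum(
                        inp[b][ip][h + kh][w + kw] * weight[op][ip][kh][kw]
                        for ip in range(IP)
                        for kh in range(KH)
                        for kw in range(KW))
    return out

res = conv([[[[1.0, 2.0], [3.0, 4.0]]]], [[[[1.0, 0.0], [0.0, 1.0]]]])
print(res)  # [[[[5.0]]]]
```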

- Simple 2D max pooling (note the similarity with a convolution with a
- "select"-style kernel):
+ **Simple 2D max pooling:**
+
+ Note the similarity with a convolution with a
+ "select"-style kernel.

def maxpool2x2(float(B,C,H,W) input) -> (output) {
  output(b,c,i,j) max= input(b,c,2*i + kh, 2*j + kw)
  where kh = [0, 2[, kw = [0, 2[
}
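The `max=` form reduces with `max` instead of `+`, and the `where` clause pins the otherwise-uninferable ranges of `kh` and `kw` to `[0, 2)`. A plain-Python sketch of the same computation (illustrative names and list representation):

```python
# Plain-Python sketch of maxpool2x2: each 2x2 window of the H,W plane is
# reduced with max; kh and kw are explicitly bounded to range(2), matching
# the where clause in the comprehension.
def maxpool2x2(inp):
    B, C = len(inp), len(inp[0])
    H, W = len(inp[0][0]), len(inp[0][0][0])
    return [[[[max(inp[b][c][2 * i + kh][2 * j + kw]
                   for kh in range(2) for kw in range(2))
               for j in range(W // 2)]
              for i in range(H // 2)]
             for c in range(C)]
            for b in range(B)]

pooled = maxpool2x2([[[[1.0, 2.0], [3.0, 4.0]]]])
print(pooled)  # [[[[4.0]]]]
```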
-