Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 8c345a5

Browse files
authored
Merge pull request #118 from facebookresearch/fcrelu-where
Add unit test for broadcasted fcrelu
2 parents 157372c + 3909efd commit 8c345a5

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2017-present, Facebook, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
##############################################################################
15+
16+
import tensor_comprehensions as tc
17+
18+
import torch
19+
import torch.cuda
20+
import unittest
21+
22+
23+
class TestBroadcastFCRelu(unittest.TestCase):
24+
25+
def test_broadcast_fcrelu(self):
26+
LANG = """
27+
def fcrelu(float(B, M) I, float(N, M) W1, float(N) B1) -> (O1) {
28+
O1(b, n) = B1(n) where b in 0:B
29+
O1(b, n) += I(b, m) * W1(n, m)
30+
O1(b, n) = fmax(O1(b, n), 0)
31+
}
32+
"""
33+
B, M, N = 100, 128, 100
34+
fcrelu = tc.define(LANG, name="fcrelu")
35+
I, W1, B1 = torch.randn(B, M).cuda(), torch.randn(N, M).cuda(), torch.randn(N).cuda()
36+
out = fcrelu(I, W1, B1)
37+
38+
39+
if __name__ == '__main__':
40+
unittest.main()

test_python/run_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ $PYTHON layers/test_dump_cuda.py -v
5151
$PYTHON layers/test_external_cuda_injection.py -v
5252
$PYTHON layers/test_fc.py -v
5353
$PYTHON layers/test_fusion_fcrelu.py -v
54+
$PYTHON layers/test_broadcast_fcrelu.py -v
5455
$PYTHON layers/test_group_convolution.py -v
5556
$PYTHON layers/test_group_convolution_strided.py -v
5657
$PYTHON layers/test_indexing.py -v

0 commit comments

Comments
 (0)