1+ // Tests for the WaveNet
2+
3+ #include < Eigen/Dense>
4+ #include < cassert>
5+ #include < iostream>
6+
7+ #include " NAM/wavenet.h"
8+
9+ namespace test_wavenet
10+ {
11+ void test_gated ()
12+ {
13+ // Assert correct nuemrics of the gating activation.
14+ // Issue 101
15+ const int conditionSize = 1 ;
16+ const int channels = 1 ;
17+ const int kernelSize = 1 ;
18+ const int dilation = 1 ;
19+ const std::string activation = " ReLU" ;
20+ const bool gated = true ;
21+ auto layer = nam::wavenet::_Layer (conditionSize, channels, kernelSize, dilation, activation, gated);
22+
23+ // Conv, input mixin, 1x1
24+ std::vector<float > weights{
25+ // Conv (weight, bias) NOTE: 2 channels out bc gated, so shapes are (2,1,1), (2,)
26+ 1 .0f , 1 .0f , 0 .0f , 0 .0f ,
27+ // Input mixin (weight only: (2,1,1))
28+ 1 .0f , -1 .0f ,
29+ // 1x1 (weight (1,1,1), bias (1,))
30+ // NOTE: Weights are (1,1) on conv, (1,-1), so the inputs sum on the upper channel and cancel on the lower.
31+ // This should give us a nice zero if the input & condition are the same, so that'll sigmoid to 0.5 for the
32+ // gate.
33+ 1 .0f , 0 .0f };
34+ auto it = weights.begin ();
35+ layer.set_weights_ (it);
36+ assert (it == weights.end ());
37+
38+ const long numFrames = 4 ;
39+ layer.set_num_frames_ (numFrames);
40+
41+ Eigen::MatrixXf input, condition, headInput, output;
42+ input.resize (channels, numFrames);
43+ condition.resize (channels, numFrames);
44+ headInput.resize (channels, numFrames);
45+ output.resize (channels, numFrames);
46+
47+ const float signalValue = 0 .25f ;
48+ input.fill (signalValue);
49+ condition.fill (signalValue);
50+ // So input & condition will sum to 0.5 on the top channel (-> ReLU), cancel to 0 on bottom (-> sigmoid)
51+
52+ headInput.setZero ();
53+ output.setZero ();
54+
55+ layer.process_ (input, condition, headInput, output, 0 , 0 );
56+
57+ // 0.25 + 0.25 -> 0.5 for conv & input mixin top channel
58+ // (0 on bottom channel)
59+ // Top ReLU -> preseves 0.5
60+ // Bottom sigmoid 0->0.5
61+ // Product is 0.25
62+ // 1x1 is unity
63+ // Skip-connect -> 0.25 (input) + 0.25 (output) -> 0.5 output
64+ // head output gets 0+0.25 = 0.25
65+ const float expectedOutput = 0.5 ;
66+ const float expectedHeadInput = 0.25 ;
67+ for (int i = 0 ; i < numFrames; i++)
68+ {
69+ const float actualOutput = output (0 , i);
70+ const float actualHeadInput = headInput (0 , i);
71+ // std::cout << actualOutput << std::endl;
72+ assert (actualOutput == expectedOutput);
73+ assert (actualHeadInput == expectedHeadInput);
74+ }
75+ }
76+ }; // namespace test_wavenet
0 commit comments