Skip to content

Commit e181f61

Browse files
authored
[BUGFIX] Fix gated activations in WaveNet (#131)
* Possible fix to gating bug. Haven't tried, needs tests * Unit test * Clean up comments
1 parent cd92997 commit e181f61

File tree

3 files changed

+86
-8
lines changed

3 files changed

+86
-8
lines changed

NAM/wavenet.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,14 @@ void nam::wavenet::_Layer::process_(const Eigen::MatrixXf& input, const Eigen::M
3737
}
3838
else
3939
{
40-
this->_activation->apply(this->_z.topRows(channels));
41-
activations::Activation::get_activation("Sigmoid")->apply(this->_z.bottomRows(channels));
42-
// activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, 0, channels,
43-
// this->_z.cols()));
44-
40+
// CAREFUL: .topRows() and .bottomRows() won't be memory-contiguous for a column-major matrix (Issue 125). Need to
41+
// do this column-wise:
42+
for (long i = 0; i < _z.cols(); i++)
43+
{
44+
this->_activation->apply(this->_z.block(0, i, channels, 1));
45+
activations::Activation::get_activation("Sigmoid")->apply(this->_z.block(channels, i, channels, 1));
46+
}
4547
this->_z.topRows(channels).array() *= this->_z.bottomRows(channels).array();
46-
// this->_z.topRows(channels) = this->_z.topRows(channels).cwiseProduct(
47-
// this->_z.bottomRows(channels)
48-
// );
4948
}
5049

5150
head_input += this->_z.topRows(channels);

tools/run_tests.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "test/test_activations.cpp"
66
#include "test/test_dsp.cpp"
77
#include "test/test_get_dsp.cpp"
8+
#include "test/test_wavenet.cpp"
89

910
int main()
1011
{
@@ -32,6 +33,8 @@ int main()
3233
test_get_dsp::test_null_input_level();
3334
test_get_dsp::test_null_output_level();
3435

36+
test_wavenet::test_gated();
37+
3538
std::cout << "Success!" << std::endl;
3639
return 0;
3740
}

tools/test/test_wavenet.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Tests for the WaveNet
2+
3+
#include <Eigen/Dense>
4+
#include <cassert>
5+
#include <iostream>
6+
7+
#include "NAM/wavenet.h"
8+
9+
namespace test_wavenet
10+
{
11+
void test_gated()
12+
{
13+
// Assert correct numerics of the gating activation.
14+
// Issue 101
15+
const int conditionSize = 1;
16+
const int channels = 1;
17+
const int kernelSize = 1;
18+
const int dilation = 1;
19+
const std::string activation = "ReLU";
20+
const bool gated = true;
21+
auto layer = nam::wavenet::_Layer(conditionSize, channels, kernelSize, dilation, activation, gated);
22+
23+
// Conv, input mixin, 1x1
24+
std::vector<float> weights{
25+
// Conv (weight, bias) NOTE: 2 channels out bc gated, so shapes are (2,1,1), (2,)
26+
1.0f, 1.0f, 0.0f, 0.0f,
27+
// Input mixin (weight only: (2,1,1))
28+
1.0f, -1.0f,
29+
// 1x1 (weight (1,1,1), bias (1,))
30+
// NOTE: Weights are (1,1) on the conv and (1,-1) on the input mixin, so the inputs sum on the upper channel and cancel on the lower.
31+
// This should give us a nice zero if the input & condition are the same, so that'll sigmoid to 0.5 for the
32+
// gate.
33+
1.0f, 0.0f};
34+
auto it = weights.begin();
35+
layer.set_weights_(it);
36+
assert(it == weights.end());
37+
38+
const long numFrames = 4;
39+
layer.set_num_frames_(numFrames);
40+
41+
Eigen::MatrixXf input, condition, headInput, output;
42+
input.resize(channels, numFrames);
43+
condition.resize(channels, numFrames);
44+
headInput.resize(channels, numFrames);
45+
output.resize(channels, numFrames);
46+
47+
const float signalValue = 0.25f;
48+
input.fill(signalValue);
49+
condition.fill(signalValue);
50+
// So input & condition will sum to 0.5 on the top channel (-> ReLU), cancel to 0 on bottom (-> sigmoid)
51+
52+
headInput.setZero();
53+
output.setZero();
54+
55+
layer.process_(input, condition, headInput, output, 0, 0);
56+
57+
// 0.25 + 0.25 -> 0.5 for conv & input mixin top channel
58+
// (0 on bottom channel)
59+
// Top ReLU -> preserves 0.5
60+
// Bottom sigmoid 0->0.5
61+
// Product is 0.25
62+
// 1x1 is unity
63+
// Skip-connect -> 0.25 (input) + 0.25 (output) -> 0.5 output
64+
// head input gets 0+0.25 = 0.25
65+
const float expectedOutput = 0.5;
66+
const float expectedHeadInput = 0.25;
67+
for (int i = 0; i < numFrames; i++)
68+
{
69+
const float actualOutput = output(0, i);
70+
const float actualHeadInput = headInput(0, i);
71+
// std::cout << actualOutput << std::endl;
72+
assert(actualOutput == expectedOutput);
73+
assert(actualHeadInput == expectedHeadInput);
74+
}
75+
}
76+
}; // namespace test_wavenet

0 commit comments

Comments
 (0)