Skip to content

Commit 04d678c

Browse files
authored
Merge pull request #68 from mmwillet/reciprocal-stride-hack
Reciprocal stride hack
2 parents 3a81b94 + 6d49e00 commit 04d678c

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

src/util.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,21 @@ float round_to_float(double v) {
7575
return roundf(v * powl(10, 6)) / powl(10, 6);
7676
}
7777

78+
struct ggml_tensor * reciprocal(ggml_context * ctx, struct ggml_tensor * x) {
79+
TTS_ASSERT(x->ne[0] == 1);
80+
static constexpr float one = 1.0f;
81+
ggml_tensor * numerator = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, x->ne[1]);
82+
// stride trick so that the scalar numerator can be divided by x.
83+
numerator->nb[1] = 0;
84+
numerator->data = const_cast<float *>(&one);
85+
return ggml_div(ctx, numerator, x);
86+
}
87+
7888
// Described in https://arxiv.org/abs/2006.08195
7989
// Snake1d is a common tunable activation function used in the DAC model.
8090
struct ggml_tensor * snake_1d(ggml_context * ctx, struct ggml_tensor * alpha, struct ggml_tensor * a) {
8191
assert(a->ne[2] == 1 && a->ne[3] == 1);
82-
return ggml_add(ctx, a, ggml_mul(ctx, ggml_sqr(ctx, ggml_sin(ctx, ggml_mul(ctx, a, alpha))), ggml_reciprocal(ctx, alpha)));
92+
return ggml_add(ctx, a, ggml_mul(ctx, ggml_sqr(ctx, ggml_sin(ctx, ggml_mul(ctx, a, alpha))), reciprocal(ctx, alpha)));
8393
}
8494

8595
bool has_suffix(std::string value, std::string suffix) {

src/util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ void uv_noise_compute(struct ggml_tensor * dst, const struct ggml_tensor * a, co
5050
// This is a custom op for logit correction in the Dia model.
5151
void cfg_scale(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
5252

53+
struct ggml_tensor * reciprocal(ggml_context * ctx, struct ggml_tensor * x);
54+
5355
bool has_suffix(std::string value, std::string suffix);
5456
bool has_prefix(std::string value, std::string prefix);
5557

0 commit comments

Comments
 (0)