Skip to content

Commit e937cde

Browse files
committed
..
1 parent d92de18 commit e937cde

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

include/siren_nerf.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class SirenLayer : public torch::nn::Module {
88
explicit SirenLayer(int64_t dim_in,
99
int64_t dim_out,
1010
bool is_first = false,
11+
float w0,
1112
bool use_bias = true,
1213
float c = 6.0f);
1314

src/siren_nerf.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "siren_nerf.h"
22

33
SirenLayer::SirenLayer(int64_t dim_in, int64_t dim_out,
4-
bool is_first, bool use_bias, float c)
5-
: dim_in_(dim_in), is_first_(is_first), w0_(is_first ? 120.0f : 1.0f) {
4+
bool is_first, float w0, bool use_bias, float c)
5+
: dim_in_(dim_in), is_first_(is_first), w0_(w0) {
66
// Initialize weight and bias
77
weight_ = register_parameter("weight", torch::zeros({dim_out, dim_in}));
88
float w_std = is_first_ ? (1.0f / dim_in_) : (std::sqrt(c / dim_in_) / w0_);
@@ -28,12 +28,12 @@ SirenNeRF::SirenNeRF(torch::Device device, int W, int D): device_(device) {
2828
D = std::max(D, 2);
2929

3030
// Create position encoder SIREN layers
31-
pos_siren_ = std::make_shared<SirenLayer>(1, 64, true);
31+
pos_siren_ = std::make_shared<SirenLayer>(1, 64, true, 120);
3232
register_module("pos_siren", pos_siren_);
3333
pos_siren_->to(device_);
3434

3535
// Create view direction encoder SIREN layers
36-
view_siren_ = std::make_shared<SirenLayer>(1, 32, true);
36+
view_siren_ = std::make_shared<SirenLayer>(1, 32, true, 20);
3737
register_module("view_siren", view_siren_);
3838
view_siren_->to(device_);
3939

0 commit comments

Comments
 (0)