1
1
#include " siren_nerf.h"
2
2
3
3
SirenLayer::SirenLayer (int64_t dim_in, int64_t dim_out,
4
- bool is_first, bool use_bias, float c)
5
- : dim_in_(dim_in), is_first_(is_first), w0_(is_first ? 120 . 0f : 1 . 0f ) {
4
+ bool is_first, float w0, bool use_bias, float c)
5
+ : dim_in_(dim_in), is_first_(is_first), w0_(w0 ) {
6
6
// Initialize weight and bias
7
7
weight_ = register_parameter (" weight" , torch::zeros ({dim_out, dim_in}));
8
8
float w_std = is_first_ ? (1 .0f / dim_in_) : (std::sqrt (c / dim_in_) / w0_);
@@ -28,12 +28,12 @@ SirenNeRF::SirenNeRF(torch::Device device, int W, int D): device_(device) {
28
28
D = std::max (D, 2 );
29
29
30
30
// Create position encoder SIREN layers
31
- pos_siren_ = std::make_shared<SirenLayer>(1 , 64 , true );
31
+ pos_siren_ = std::make_shared<SirenLayer>(1 , 64 , true , 120 );
32
32
register_module (" pos_siren" , pos_siren_);
33
33
pos_siren_->to (device_);
34
34
35
35
// Create view direction encoder SIREN layers
36
- view_siren_ = std::make_shared<SirenLayer>(1 , 32 , true );
36
+ view_siren_ = std::make_shared<SirenLayer>(1 , 32 , true , 20 );
37
37
register_module (" view_siren" , view_siren_);
38
38
view_siren_->to (device_);
39
39
0 commit comments