File tree 2 files changed +7
-10
lines changed
graph_weather/models/aurora
2 files changed +7
-10
lines changed Original file line number Diff line number Diff line change @@ -188,12 +188,11 @@ def __init__(
188
188
self ,
189
189
input_features : int ,
190
190
output_features : int ,
191
- embed_dim : int = 256 ,
192
191
latent_dim : int = 256 ,
193
192
num_layers : int = 4 ,
194
193
max_points : int = 10000 ,
195
194
max_seq_len : int = 1024 ,
196
- use_checkPointing : bool = False ,
195
+ use_checkpointing : bool = False ,
197
196
):
198
197
super ().__init__ ()
199
198
@@ -203,12 +202,12 @@ def __init__(
203
202
self .output_features = output_features
204
203
205
204
# Model components
206
- self .encoder = PointEncoder (input_features , embed_dim , max_seq_len )
207
- self .processor = PointCloudProcessor (embed_dim , num_layers )
208
- self .decoder = PointDecoder (embed_dim , output_features )
205
+ self .encoder = PointEncoder (input_features , latent_dim , max_seq_len )
206
+ self .processor = PointCloudProcessor (latent_dim , num_layers )
207
+ self .decoder = PointDecoder (latent_dim , output_features )
209
208
210
209
# Add gradient checkpointing
211
- self .use_checkpointing = use_checkPointing
210
+ self .use_checkpointing = use_checkpointing
212
211
213
212
# Initialize weights properly
214
213
self ._init_weights ()
Original file line number Diff line number Diff line change @@ -44,8 +44,7 @@ def model_config():
44
44
return {
45
45
"input_features" : 2 ,
46
46
"output_features" : 2 ,
47
- "embed_dim" : 8 , # Adjusted to avoid overfitting
48
- "latent_dim" : 16 , # Adjusted
47
+ "latent_dim" : 8 , # Adjusted
49
48
"max_points" : 50 , # Adjusted
50
49
"max_seq_len" : 128 , # Adjusted
51
50
}
@@ -74,8 +73,7 @@ def test_gradient_checkpointing_config():
74
73
config_no_checkpoint = {
75
74
"input_features" : 2 ,
76
75
"output_features" : 2 ,
77
- "embed_dim" : 8 ,
78
- "latent_dim" : 16 ,
76
+ "latent_dim" : 8 ,
79
77
"max_points" : 50 ,
80
78
"max_seq_len" : 128 ,
81
79
"use_checkpointing" : False ,
You can’t perform that action at this time.
0 commit comments