Skip to content

Commit c99f1ed

Browse files
committed
Rename and remove unused argument
1 parent c2460fe commit c99f1ed

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

graph_weather/models/aurora/model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -188,12 +188,11 @@ def __init__(
188188
self,
189189
input_features: int,
190190
output_features: int,
191-
embed_dim: int = 256,
192191
latent_dim: int = 256,
193192
num_layers: int = 4,
194193
max_points: int = 10000,
195194
max_seq_len: int = 1024,
196-
use_checkPointing: bool = False,
195+
use_checkpointing: bool = False,
197196
):
198197
super().__init__()
199198

@@ -203,12 +202,12 @@ def __init__(
203202
self.output_features = output_features
204203

205204
# Model components
206-
self.encoder = PointEncoder(input_features, embed_dim, max_seq_len)
207-
self.processor = PointCloudProcessor(embed_dim, num_layers)
208-
self.decoder = PointDecoder(embed_dim, output_features)
205+
self.encoder = PointEncoder(input_features, latent_dim, max_seq_len)
206+
self.processor = PointCloudProcessor(latent_dim, num_layers)
207+
self.decoder = PointDecoder(latent_dim, output_features)
209208

210209
# Add gradient checkpointing
211-
self.use_checkpointing = use_checkPointing
210+
self.use_checkpointing = use_checkpointing
212211

213212
# Initialize weights properly
214213
self._init_weights()

tests/test_aurora.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def model_config():
4444
return {
4545
"input_features": 2,
4646
"output_features": 2,
47-
"embed_dim": 8, # Adjusted to avoid overfitting
48-
"latent_dim": 16, # Adjusted
47+
"latent_dim": 8, # Adjusted
4948
"max_points": 50, # Adjusted
5049
"max_seq_len": 128, # Adjusted
5150
}
@@ -74,8 +73,7 @@ def test_gradient_checkpointing_config():
7473
config_no_checkpoint = {
7574
"input_features": 2,
7675
"output_features": 2,
77-
"embed_dim": 8,
78-
"latent_dim": 16,
76+
"latent_dim": 8,
7977
"max_points": 50,
8078
"max_seq_len": 128,
8179
"use_checkpointing": False,

0 commit comments

Comments
 (0)