Skip to content

Commit 1d34631

Browse files
csjfwangwang85wang85wang85wang85
authored
Fix the issue - "Empty source still have embedding network" (ecmwf#1114)
* Replace cf.rank==0 with utils.distributed.is_root * fix empty source inputs still have embedding layer * fix lint * fix source empty or source exclude all * fix source empty or source exclude all * fix forecast mode empty source --------- Co-authored-by: wang85 <wang85@jwlogin22.juwels> Co-authored-by: wang85 <wang85@jwlogin24.juwels> Co-authored-by: wang85 <wang85@jwb0149.juwels> Co-authored-by: wang85 <wang85@jwlogin21.juwels>
1 parent 5e9c3ef commit 1d34631

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ def advance(self):
248248
###################################################
249249
def get_sources_size(self):
250250
return [
251-
ds[0].get_source_num_channels()
251+
0
252+
if ds[0].get_source_num_channels() == 0
253+
else ds[0].get_source_num_channels()
252254
+ ds[0].get_geoinfo_size()
253255
+ ds[0].get_coords_size()
254256
+ self.tokenizer.get_size_time_embedding()

src/weathergen/model/engines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def __init__(self, cf: Config, sources_size) -> None:
4747
for i, si in enumerate(self.cf.streams):
4848
stream_name = si.get("name", i)
4949

50-
if "diagnostic" in si and si["diagnostic"]:
50+
if si.get("diagnostic", False) or self.sources_size[i] == 0:
5151
self.embeds.append(torch.nn.Identity())
5252
continue
5353

0 commit comments

Comments
 (0)