@@ -42,103 +42,86 @@ def normalise_number(number):
42
42
43
43
44
44
def ensembles_perturbations(ensembles, center, mean, remapping=None, patches=None):
    """Recombine ensemble members around a new center field.

    For every ensemble member field ``e``, compute ``x = center + mean - e``
    and write it to a temporary GRIB file, which is then loaded back and
    returned as a dataset.

    Parameters
    ----------
    ensembles : dict
        Request passed to ``load_source`` for the ensemble member fields.
        Must contain a ``"number"`` entry understood by ``normalise_number``.
    center : dict
        Request for the center (control) fields, one per field group.
    mean : dict
        Request for the ensemble-mean fields, one per field group.
    remapping, patches : dict, optional
        Currently unused in this body; accepted for interface compatibility.
        Defaults changed from shared mutable ``{}`` to ``None`` sentinels.

    Returns
    -------
    dataset
        The recombined fields; holds a reference to the backing temp file so
        the file is deleted only when the dataset is garbage-collected.
    """
    # Avoid the shared-mutable-default pitfall; both are unused below, so
    # this is behavior-preserving for all callers.
    remapping = {} if remapping is None else remapping
    patches = {} if patches is None else patches

    number_list = normalise_number(ensembles["number"])
    n_numbers = len(number_list)

    # "number" MUST be the last sort key: members of one (param, level, ...)
    # group are then contiguous, so ensemble field i pairs with center/mean
    # field i // n_numbers. The asserts in the loop verify this pairing.
    keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]

    print(f"Retrieving ensemble data with {ensembles}")
    ensembles = load_source(**ensembles).order_by(*keys)

    print(f"Retrieving center data with {center}")
    center = load_source(**center).order_by(*keys)

    print(f"Retrieving mean data with {mean}")
    mean = load_source(**mean).order_by(*keys)

    # Sanity check: one center/mean field per group of n_numbers members.
    assert len(mean) * n_numbers == len(ensembles), (
        len(mean),
        n_numbers,
        len(ensembles),
    )
    assert len(center) * n_numbers == len(ensembles), (
        len(center),
        n_numbers,
        len(ensembles),
    )

    # prepare output tmp file so we can read it back
    tmp = temp_file()
    path = tmp.path
    out = new_grib_output(path)

    # total= gives tqdm a proper progress bar (enumerate has no __len__).
    for i, field in tqdm.tqdm(enumerate(ensembles), total=len(ensembles)):
        param = field.metadata("param")
        number = field.metadata("number")
        ii = i // n_numbers

        # Confirm the ordering invariant: within a group, members appear in
        # the same order as number_list.
        i_number = number_list.index(number)
        assert i == ii * n_numbers + i_number, (i, ii, n_numbers, i_number, number_list)

        center_field = center[ii]
        mean_field = mean[ii]

        # All coordinates except "number" (plus grid/shape) must match
        # across the member, center and mean fields being combined.
        for k in keys + ["grid", "shape"]:
            if k == "number":
                continue
            assert center_field.metadata(k) == field.metadata(k), (
                k,
                center_field.metadata(k),
                field.metadata(k),
            )
            assert mean_field.metadata(k) == field.metadata(k), (
                k,
                mean_field.metadata(k),
                field.metadata(k),
            )

        e = field.to_numpy()
        m = mean_field.to_numpy()
        c = center_field.to_numpy()
        assert m.shape == c.shape, (m.shape, c.shape)

        #################################
        # Actual computation happens here
        x = c + m - e
        if param == "q":
            # Specific humidity must stay non-negative after recombination.
            warnings.warn("Clipping q")
            x = np.maximum(x, 0)
        #################################

        assert x.shape == e.shape, (x.shape, e.shape)

        check_data_values(x, name=param)
        # Use the member field as GRIB template so the output keeps the
        # member's metadata (including its ensemble number).
        out.write(x, template=field)

    out.close()

    ds = load_source("file", path)
    # save a reference to the tmp file so it is deleted
    # only when the dataset is not used anymore
    ds._tmp = tmp

    assert len(ds) == len(ensembles), (len(ds), len(ensembles))

    return ds
0 commit comments