@@ -29,6 +29,64 @@ function merge_stats(us)
29
29
reduce (merge, st)
30
30
end
31
31
32
+ mutable struct AggregateLogger{T<: Logging.AbstractLogger } <: Logging.AbstractLogger
33
+ progress:: Dict{Symbol, Float64}
34
+ done_counter:: Int
35
+ total:: Float64
36
+ print_time:: Float64
37
+ lock:: ReentrantLock
38
+ logger:: T
39
+ end
40
+ AggregateLogger (logger:: Logging.AbstractLogger ) = AggregateLogger (Dict {Symbol, Float64} (),0 , 0.0 , 0.0 , ReentrantLock (), logger)
41
+
42
+ function Logging. handle_message (l:: AggregateLogger , level, message, _module, group, id, file, line; kwargs... )
43
+ if convert (Logging. LogLevel, level) == Logging. LogLevel (- 1 ) && haskey (kwargs, :progress )
44
+ pr = kwargs[:progress ]
45
+ if trylock (l. lock) || (pr == " done" && lock (l. lock)=== nothing )
46
+ try
47
+ if pr == " done"
48
+ pr = 1.0
49
+ l. done_counter += 1
50
+ end
51
+ len = length (l. progress)
52
+ if haskey (l. progress, id)
53
+ l. total += (pr- l. progress[id])/ len
54
+ else
55
+ l. total = l. total* (len/ (len+ 1 )) + pr/ (len+ 1 )
56
+ len += 1
57
+ end
58
+ l. progress[id] = pr
59
+ # validation check (slow)
60
+ # tot = sum(values(l.progress))/length(l.progress)
61
+ # @show tot l.total l.total ≈ tot
62
+ curr_time = time ()
63
+ if l. done_counter >= len
64
+ tot= " done"
65
+ empty! (l. progress)
66
+ l. done_counter = 0
67
+ l. print_time = 0.0
68
+ elseif curr_time- l. print_time > 0.1
69
+ tot = l. total
70
+ l. print_time = curr_time
71
+ else
72
+ return
73
+ end
74
+ id= :total
75
+ message= " Total"
76
+ kwargs= merge (values (kwargs), (progress= tot,))
77
+ finally
78
+ unlock (l. lock)
79
+ end
80
+ else
81
+ return
82
+ end
83
+ end
84
+ Logging. handle_message (l. logger, level, message, _module, group, id, file, line; kwargs... )
85
+ end
86
+ Logging. shouldlog (l:: AggregateLogger , args... ) = Logging. shouldlog (l. logger, args... )
87
+ Logging. min_enabled_level (l:: AggregateLogger ) = Logging. min_enabled_level (l. logger)
88
+ Logging. catch_exceptions (l:: AggregateLogger ) = Logging. catch_exceptions (l. logger)
89
+
32
90
function __solve (prob:: AbstractEnsembleProblem ,
33
91
alg:: Union{AbstractDEAlgorithm, Nothing} ;
34
92
kwargs... )
59
117
function __solve (prob:: AbstractEnsembleProblem ,
60
118
alg:: A ,
61
119
ensemblealg:: BasicEnsembleAlgorithm ;
62
- trajectories, batch_size = trajectories,
120
+ trajectories, batch_size = trajectories, progress_aggregate = true ,
63
121
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1 , kwargs... ) where {A}
64
- num_batches = trajectories ÷ batch_size
65
- num_batches < 1 &&
66
- error (" trajectories ÷ batch_size cannot be less than 1, got $num_batches " )
67
- num_batches * batch_size != trajectories && (num_batches += 1 )
68
-
69
- if num_batches == 1 && prob. reduction === DEFAULT_REDUCTION
70
- elapsed_time = @elapsed u = solve_batch (prob, alg, ensemblealg, 1 : trajectories,
71
- pmap_batch_size; kwargs... )
72
- _u = tighten_container_eltype (u)
73
- stats = merge_stats (_u)
74
- return EnsembleSolution (_u, elapsed_time, true , stats)
75
- end
122
+ logger = progress_aggregate ? AggregateLogger (Logging. current_logger ()) : Logging. current_logger ()
123
+
124
+ Logging. with_logger (logger) do
125
+ num_batches = trajectories ÷ batch_size
126
+ num_batches < 1 &&
127
+ error (" trajectories ÷ batch_size cannot be less than 1, got $num_batches " )
128
+ num_batches * batch_size != trajectories && (num_batches += 1 )
76
129
77
- converged:: Bool = false
78
- elapsed_time = @elapsed begin
79
- i = 1
80
- II = (batch_size * (i - 1 ) + 1 ): (batch_size * i)
130
+ if get (kwargs, :progress , false )
131
+ name = get (kwargs, :progress_name , " Ensemble" )
132
+ for i in 1 : trajectories
133
+ msg = " $name #$i "
134
+ Logging. @logmsg (Logging. LogLevel (- 1 ), msg, _id= Symbol (" SciMLBase_$i " ), progress= 0 )
135
+ end
136
+ end
137
+
81
138
82
- batch_data = solve_batch (prob, alg, ensemblealg, II, pmap_batch_size; kwargs... )
139
+ if num_batches == 1 && prob. reduction === DEFAULT_REDUCTION
140
+ elapsed_time = @elapsed u = solve_batch (prob, alg, ensemblealg, 1 : trajectories,
141
+ pmap_batch_size; kwargs... )
142
+ _u = tighten_container_eltype (u)
143
+ stats = merge_stats (_u)
144
+ return EnsembleSolution (_u, elapsed_time, true , stats)
145
+ end
146
+
147
+ converged:: Bool = false
148
+ elapsed_time = @elapsed begin
149
+ i = 1
150
+ II = (batch_size * (i - 1 ) + 1 ): (batch_size * i)
83
151
84
- u = prob. u_init === nothing ? similar (batch_data, 0 ) : prob. u_init
85
- u, converged = prob. reduction (u, batch_data, II)
86
- for i in 2 : num_batches
87
- converged && break
88
- if i == num_batches
89
- II = (batch_size * (i - 1 ) + 1 ): trajectories
90
- else
91
- II = (batch_size * (i - 1 ) + 1 ): (batch_size * i)
92
- end
93
152
batch_data = solve_batch (prob, alg, ensemblealg, II, pmap_batch_size; kwargs... )
153
+
154
+ u = prob. u_init === nothing ? similar (batch_data, 0 ) : prob. u_init
94
155
u, converged = prob. reduction (u, batch_data, II)
156
+ for i in 2 : num_batches
157
+ converged && break
158
+ if i == num_batches
159
+ II = (batch_size * (i - 1 ) + 1 ): trajectories
160
+ else
161
+ II = (batch_size * (i - 1 ) + 1 ): (batch_size * i)
162
+ end
163
+ batch_data = solve_batch (prob, alg, ensemblealg, II, pmap_batch_size; kwargs... )
164
+ u, converged = prob. reduction (u, batch_data, II)
165
+ end
95
166
end
167
+ _u = tighten_container_eltype (u)
168
+ stats = merge_stats (_u)
169
+ return EnsembleSolution (_u, elapsed_time, converged, stats)
96
170
end
97
- _u = tighten_container_eltype (u)
98
- stats = merge_stats (_u)
99
- return EnsembleSolution (_u, elapsed_time, converged, stats)
100
171
end
101
172
102
173
function batch_func (i, prob, alg; kwargs... )
103
174
iter = 1
104
175
_prob = prob. safetycopy ? deepcopy (prob. prob) : prob. prob
105
176
new_prob = prob. prob_func (_prob, i, iter)
106
177
rerun = true
178
+
179
+ progress = get (kwargs, :progress , false )
180
+ if progress
181
+ name = get (kwargs, :progress_name , " Ensemble" )
182
+ progress_name = " $name #$i "
183
+ progress_id = Symbol (" SciMLBase_$i " )
184
+ kwargs = (kwargs... , progress_name= progress_name, progress_id= progress_id)
185
+ end
107
186
x = prob. output_func (solve (new_prob, alg; kwargs... ), i)
108
187
if ! (x isa Tuple)
109
188
rerun_warn ()
0 commit comments