@@ -38,9 +38,12 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
38
38
reader_cost = 0
39
39
batch_cost = 0
40
40
reader_tic = time .perf_counter ()
41
+
42
+ input_dict_list = []
43
+ label_dict_list = []
44
+ weight_dict_list = []
41
45
for _ , _constraint in solver .constraint .items ():
42
46
input_dict , label_dict , weight_dict = next (_constraint .data_iter )
43
-
44
47
# profile code below
45
48
# profiler.add_profiler_step(solver.cfg["profiler_options"])
46
49
if iter_id == 5 :
@@ -50,20 +53,34 @@ def train_epoch_func(solver, epoch_id: int, log_freq: int):
50
53
reader_cost += time .perf_counter () - reader_tic
51
54
total_batch_size .append (next (iter (input_dict .values ())).shape [0 ])
52
55
53
- for v in input_dict .values ():
54
- v .stop_gradient = False
56
+ # gather each constraint's input, label, weight to a list
57
+ input_dict_list .append (input_dict )
58
+ label_dict_list .append (label_dict )
59
+ weight_dict_list .append (weight_dict )
55
60
56
- # forward for every constraint
57
- with solver .autocast_context_manager ():
58
- output_dict = solver .expr_helper (
59
- _constraint .output_expr , input_dict , solver .model
60
- )
61
- constraint_loss = _constraint .loss (output_dict , label_dict , weight_dict )
62
- total_loss += constraint_loss
61
+ reader_tic = time .perf_counter ()
63
62
64
- loss_dict [_constraint .name ] += float (constraint_loss )
63
+ for x in input_dict_list :
64
+ for v in x .values ():
65
+ v .stop_gradient = False
65
66
66
- reader_tic = time .perf_counter ()
67
+ # forward for every constraint, including model and equation expression
68
+ with solver .autocast_context_manager ():
69
+ output_dict_list = solver .expr_helper (
70
+ [_constraint .output_expr for _constraint in solver .constraint .values ()],
71
+ input_dict_list ,
72
+ solver .model ,
73
+ )
74
+
75
+ # compute loss for each constraint according to its' own output, label and weight
76
+ for i , (_ , _constraint ) in enumerate (solver .constraint .items ()):
77
+ constraint_loss = _constraint .loss (
78
+ output_dict_list [i ],
79
+ label_dict_list [i ],
80
+ weight_dict_list [i ],
81
+ )
82
+ total_loss += constraint_loss
83
+ loss_dict [_constraint .name ] += float (constraint_loss )
67
84
68
85
if solver .update_freq > 1 :
69
86
total_loss = total_loss / solver .update_freq
0 commit comments