Skip to content

Commit 42bc38d

Browse files
authored
Merge pull request #93 from rafbiels/dl-mnist-print-compute-time
[dl-mnist] Print both total and compute-only times
2 parents 890d3f0 + fab1278 commit 42bc38d

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

dl-mnist/CUDA/main.cudnn.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ int main(int argc, const char** argv) {
6464
cout << "=======================================================================" << endl;
6565
cout << endl;
6666

67+
Time wallClockCompute = get_time_now();
68+
6769
WorkloadParams workload_params(argc, argv);
6870

6971

@@ -151,13 +153,21 @@ int main(int argc, const char** argv) {
151153
#ifdef DEVICE_TIMER
152154
cout << "Final time across all networks: " << timer->getTotalOpTime() << " s" << std::endl;
153155
#endif
156+
157+
double computeTime = calculate_op_time_taken(wallClockCompute);
158+
154159
delete dlNetworkMgr;
155160
delete timer;
156161

157162
if (handle) cudnnDestroy(handle);
158163

159-
std::cout << "dl-mnist - total time for whole calculation: " << calculate_op_time_taken(wallClockStart) - dataFileReadTimer->getTotalOpTime()<< " s" << std::endl;
160-
164+
double totalTime = calculate_op_time_taken(wallClockStart);
165+
double ioTime = dataFileReadTimer->getTotalOpTime();
166+
167+
std::cout << "dl-mnist - I/O time: " << ioTime << " s" << std::endl;
168+
std::cout << "dl-mnist - compute time: " << computeTime - ioTime << " s" << std::endl;
169+
std::cout << "dl-mnist - total time for whole calculation: " << totalTime - ioTime << " s" << std::endl;
170+
161171
} catch (DlInfraException& e) {
162172
std::cout << e.what() << "\n";
163173
std::cout << "Workload execution failed" << std::endl;

dl-mnist/HIP/main.miopen.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ int main(int argc, const char** argv) {
6363
cout << "=======================================================================" << endl;
6464
cout << endl;
6565

66+
Time wallClockCompute = get_time_now();
67+
6668
WorkloadParams workload_params(argc, argv);
6769

6870

@@ -151,14 +153,20 @@ int main(int argc, const char** argv) {
151153
#ifdef DEVICE_TIMER
152154
cout << "Final time across all networks: " << timer->getTotalOpTime() << " s" << std::endl;
153155
#endif
156+
157+
double computeTime = calculate_op_time_taken(wallClockCompute);
158+
154159
delete dlNetworkMgr;
155160
delete timer;
156161

157162
if (handle) miopenDestroy(handle);
158-
159163

160-
std::cout << "dl-mnist - total time for whole calculation: " << calculate_op_time_taken(wallClockStart) - dataFileReadTimer->getTotalOpTime()<< " s" << std::endl;
164+
double totalTime = calculate_op_time_taken(wallClockStart);
165+
double ioTime = dataFileReadTimer->getTotalOpTime();
161166

167+
std::cout << "dl-mnist - I/O time: " << ioTime << " s" << std::endl;
168+
std::cout << "dl-mnist - compute time: " << computeTime - ioTime << " s" << std::endl;
169+
std::cout << "dl-mnist - total time for whole calculation: " << totalTime - ioTime << " s" << std::endl;
162170

163171
};
164172

dl-mnist/SYCL/main.onednn.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ int main(int argc, const char** argv) {
8989
cout << "=======================================================================" << endl;
9090
cout << endl;
9191

92+
Time wallClockCompute = get_time_now();
93+
9294
WorkloadParams workload_params(argc, argv);
9395

9496
// Since this workload is for inference, we are keeping N=1 which is the usual case for inference.
@@ -172,6 +174,9 @@ int main(int argc, const char** argv) {
172174
#ifdef DEVICE_TIMER
173175
cout << "Final time across all networks: " << timer->getTotalOpTime() << " s" << std::endl;
174176
#endif
177+
178+
double computeTime = calculate_op_time_taken(wallClockCompute);
179+
175180
delete dlNetworkMgr;
176181

177182
delete dht;
@@ -180,7 +185,13 @@ int main(int argc, const char** argv) {
180185

181186
delete timer;
182187

183-
std::cout << "dl-mnist - total time for whole calculation: " << calculate_op_time_taken(wallClockStart) - dataFileReadTimer->getTotalOpTime()<< " s" << std::endl;
188+
double totalTime = calculate_op_time_taken(wallClockStart);
189+
double ioTime = dataFileReadTimer->getTotalOpTime();
190+
191+
std::cout << "dl-mnist - I/O time: " << ioTime << " s" << std::endl;
192+
std::cout << "dl-mnist - compute time: " << computeTime - ioTime << " s" << std::endl;
193+
std::cout << "dl-mnist - total time for whole calculation: " << totalTime - ioTime << " s" << std::endl;
194+
184195
} catch (dnnl::error& e) {
185196
std::cout << e.what() << "\n";
186197
std::cout << "Workload execution failed" << std::endl;

0 commit comments

Comments
 (0)