Skip to content

Commit cdb67b2

Browse files
authored
infra: add training script to benchmark directory (#349)
1 parent 42218eb commit cdb67b2

File tree

5 files changed

+72
-7
lines changed

5 files changed

+72
-7
lines changed

benchmarks/tf_benchmarks/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# TensorFlow benchmarking scripts
22

3-
This folder contains the TF training scripts https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks.
3+
This folder contains a copy of [TensorFlow's `tf_cnn_benchmarks.py` script](https://github.com/tensorflow/benchmarks/blob/e3bd1370ba21b02c4d34340934ffb4941977d96f/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py).
44

55
## Basic usage
6-
**execute_tensorflow_training.py train** uses SageMaker python sdk to start a training job.
6+
**execute_tensorflow_training.py train** uses SageMaker python sdk to start a training job.
77

88
```bash
99
./execute_tensorflow_training.py train --help
@@ -26,7 +26,7 @@ Options:
2626
--help Show this message and exit.
2727

2828
```
29-
**execute_tensorflow_training.py generate_reports** generate benchmark reports.
29+
**execute_tensorflow_training.py generate_reports** generate benchmark reports.
3030

3131
## Examples:
3232

benchmarks/tf_benchmarks/benchmarks

Lines changed: 0 additions & 1 deletion
This file was deleted.

benchmarks/tf_benchmarks/execute_tensorflow_training.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1212
# ANY KIND, either express or implied. See the License for the specific
1313
# language governing permissions and limitations under the License.
14-
1514
from __future__ import absolute_import
1615

1716
import argparse
@@ -107,4 +106,4 @@ def create_hyperparameters(model_dir, script_args):
107106

108107
if __name__ == '__main__':
109108
args, script_args = get_args()
110-
main(args, script_args)
109+
main(args, script_args)

benchmarks/tf_benchmarks/models

Lines changed: 0 additions & 1 deletion
This file was deleted.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Benchmark script for TensorFlow.
14+
15+
Originally copied from:
16+
https://github.com/tensorflow/benchmarks/blob/e3bd1370ba21b02c4d34340934ffb4941977d96f/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py
17+
"""
18+
from __future__ import absolute_import, division, print_function
19+
20+
from absl import app
21+
from absl import flags as absl_flags
22+
import tensorflow.compat.v1 as tf
23+
24+
import benchmark_cnn
25+
import cnn_util
26+
import flags
27+
import mlperf
28+
from cnn_util import log_fn
29+
30+
31+
flags.define_flags()
32+
for name in flags.param_specs.keys():
33+
absl_flags.declare_key_flag(name)
34+
35+
absl_flags.DEFINE_boolean(
36+
'ml_perf_compliance_logging', False,
37+
'Print logs required to be compliant with MLPerf. If set, must clone the '
38+
'MLPerf training repo https://github.com/mlperf/training and add '
39+
'https://github.com/mlperf/training/tree/master/compliance to the '
40+
'PYTHONPATH')
41+
42+
43+
def main(positional_arguments):
44+
# Command-line arguments like '--distortions False' are equivalent to
45+
# '--distortions=True False', where False is a positional argument. To prevent
46+
# this from silently running with distortions, we do not allow positional
47+
# arguments.
48+
assert len(positional_arguments) >= 1
49+
if len(positional_arguments) > 1:
50+
raise ValueError('Received unknown positional arguments: %s'
51+
% positional_arguments[1:])
52+
53+
params = benchmark_cnn.make_params_from_flags()
54+
with mlperf.mlperf_logger(absl_flags.FLAGS.ml_perf_compliance_logging,
55+
params.model):
56+
params = benchmark_cnn.setup(params)
57+
bench = benchmark_cnn.BenchmarkCNN(params)
58+
59+
tfversion = cnn_util.tensorflow_version_tuple()
60+
log_fn('TensorFlow: %i.%i' % (tfversion[0], tfversion[1]))
61+
62+
bench.print_info()
63+
bench.run()
64+
65+
66+
if __name__ == '__main__':
67+
tf.disable_v2_behavior()
68+
app.run(main) # Raises error on invalid flags, unlike tf.app.run()

0 commit comments

Comments
 (0)