Commit 7768c2b

sboshin authored and chuyang-deng committed
horovod_mnist rewrite (#252)
* fix: tensorflow-2.0 library code changes (#247)
* change: tensorflow-2.0 tests
* fix: tensorflow-2.0 library code changes
* remove >=2.0 off tensorflow restrictions
* fix: update mnist scripts for tf-2.0
* add dockerfiles
* fix: update scripts to support tf-2.0 (#250)
* Upgrading Keras
* Upgrading horovod_mnist for v2, based on horovod mnist example for tf2 on horovod github
1 parent aeeb116 commit 7768c2b

File tree

2 files changed: +64 -100 lines changed

docker/2.0.0/py3/Dockerfile.gpu

Lines changed: 1 addition & 1 deletion
@@ -143,7 +143,7 @@ RUN ${PIP} install --no-cache-dir -U \
     keras_preprocessing==1.1.0 \
     requests==2.22.0 \
     keras==2.3.1 \
-    awscli \
+    awscli==1.16.196 \
     mpi4py==3.0.2 \
     "sagemaker-tensorflow>=2.0,<2.1" \
     # Let's install TensorFlow separately in the end to avoid
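The only Dockerfile change is pinning awscli to 1.16.196 rather than letting it float to the latest release. A minimal smoke test along these lines (hypothetical, not part of this commit) could be run inside the built image to confirm the pins resolved as expected:

```python
# Hypothetical smoke test: verify the pinned packages resolved to the
# expected versions inside the built image.
import pkg_resources

PINS = {
    'awscli': '1.16.196',
    'keras': '2.3.1',
    'mpi4py': '3.0.2',
}

for name, expected in PINS.items():
    installed = pkg_resources.get_distribution(name).version
    assert installed == expected, '%s is %s, expected %s' % (name, installed, expected)
print('All pinned versions match.')
```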

test/resources/mnist/horovod_mnist.py

Lines changed: 63 additions & 99 deletions
@@ -10,120 +10,84 @@
 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-from __future__ import absolute_import, print_function
-
 import os
-import subprocess
-
-import keras
-from keras.datasets import mnist
-from keras.models import Sequential
-from keras.layers import Dense, Dropout, Flatten
-from keras.layers import Conv2D, MaxPooling2D
-from keras import backend as K
 import tensorflow as tf
-import horovod.keras as hvd
-
+import horovod.tensorflow as hvd
 
 # Horovod: initialize Horovod.
 hvd.init()
 
 # Horovod: pin GPU to be used to process local rank (one GPU per process)
-config = tf.compat.v1.ConfigProto()
-config.gpu_options.allow_growth = True
-config.gpu_options.visible_device_list = str(hvd.local_rank())
-K.set_session(tf.compat.v1.Session(config=config))
-
-batch_size = 128
-num_classes = 10
-
-epochs = 1
-
-# Input image dimensions
-img_rows, img_cols = 28, 28
-
-# The data, shuffled and split between train and test sets
-(x_train, y_train), (x_test, y_test) = mnist.load_data()
-
-x_train = x_train[:600]
-y_train = y_train[:600]
-x_test = x_test[:100]
-y_test = y_test[:100]
-
-if K.image_data_format() == 'channels_first':
-    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
-    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
-    input_shape = (1, img_rows, img_cols)
-else:
-    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
-    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
-    input_shape = (img_rows, img_cols, 1)
-
-x_train = x_train.astype('float32')
-x_test = x_test.astype('float32')
-x_train /= 255
-x_test /= 255
-print('x_train shape:', x_train.shape)
-print(x_train.shape[0], 'train samples')
-print(x_test.shape[0], 'test samples')
-
-# Convert class vectors to binary class matrices
-y_train = keras.utils.to_categorical(y_train, num_classes)
-y_test = keras.utils.to_categorical(y_test, num_classes)
-
-model = Sequential()
-model.add(Conv2D(32, kernel_size=(3, 3),
-                 activation='relu',
-                 input_shape=input_shape))
-model.add(Conv2D(64, (3, 3), activation='relu'))
-model.add(MaxPooling2D(pool_size=(2, 2)))
-model.add(Dropout(0.25))
-model.add(Flatten())
-model.add(Dense(128, activation='relu'))
-model.add(Dropout(0.5))
-model.add(Dense(num_classes, activation='softmax'))
+gpus = tf.config.experimental.list_physical_devices('GPU')
+for gpu in gpus:
+    tf.config.experimental.set_memory_growth(gpu, True)
+if gpus:
+    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
+
+(mnist_images, mnist_labels), _ = \
+    tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())
+
+dataset = tf.data.Dataset.from_tensor_slices(
+    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
+     tf.cast(mnist_labels, tf.int64))
+)
+dataset = dataset.repeat().shuffle(10000).batch(128)
+
+mnist_model = tf.keras.Sequential([
+    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
+    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
+    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
+    tf.keras.layers.Dropout(0.25),
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(128, activation='relu'),
+    tf.keras.layers.Dropout(0.5),
+    tf.keras.layers.Dense(10, activation='softmax')
+])
+loss = tf.losses.SparseCategoricalCrossentropy()
 
 # Horovod: adjust learning rate based on number of GPUs.
-opt = keras.optimizers.Adadelta(1.0 * hvd.size())
+opt = tf.optimizers.Adam(0.001 * hvd.size())
 
-# Horovod: add Horovod Distributed Optimizer.
-opt = hvd.DistributedOptimizer(opt)
+checkpoint_dir = './checkpoints'
+checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)
 
-model.compile(loss=keras.losses.categorical_crossentropy,
-              optimizer=opt,
-              metrics=['accuracy'])
 
-callbacks = [
-    # Horovod: broadcast initial variable states from rank 0 to all other processes.
-    # This is necessary to ensure consistent initialization of all workers when
-    # training is started with random weights or restored from a checkpoint.
-    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
-]
+@tf.function
+def training_step(images, labels, first_batch):
+    with tf.GradientTape() as tape:
+        probs = mnist_model(images, training=True)
+        loss_value = loss(labels, probs)
 
-# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
-if hvd.rank() == 0:
-    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))
+    # Horovod: add Horovod Distributed GradientTape.
+    tape = hvd.DistributedGradientTape(tape)
 
-model.fit(x_train, y_train,
-          batch_size=batch_size,
-          callbacks=callbacks,
-          epochs=epochs,
-          verbose=1,
-          validation_data=(x_test, y_test))
-score = model.evaluate(x_test, y_test, verbose=0)
-print('Test loss:', score[0])
-print('Test accuracy:', score[1])
+    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
+    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
 
+    # Horovod: broadcast initial variable states from rank 0 to all other processes.
+    # This is necessary to ensure consistent initialization of all workers when
+    # training is started with random weights or restored from a checkpoint.
+    #
+    # Note: broadcast should be done after the first gradient step to ensure optimizer
+    # initialization.
+    if first_batch:
+        hvd.broadcast_variables(mnist_model.variables, root_rank=0)
+        hvd.broadcast_variables(opt.variables(), root_rank=0)
 
-if hvd.rank() == 0:
-    # Exports the keras model as TensorFlow Serving Saved Model
-    with K.get_session() as session:
+    return loss_value
 
-        init = tf.compat.v1.global_variables_initializer()
-        session.run(init)
 
-        tf.compat.v1.saved_model.simple_save(
-            session,
-            os.path.join('/opt/ml/model/mnist/1'),
-            inputs={'input_image': model.input},
-            outputs={t.name: t for t in model.outputs})
+# Horovod: adjust number of steps based on number of GPUs.
+for batch, (images, labels) in enumerate(dataset.take(600 // hvd.size())):
+    loss_value = training_step(images, labels, batch == 0)
+
+    if batch % 10 == 0 and hvd.local_rank() == 0:
+        print('Step #%d\tLoss: %.6f' % (batch, loss_value))
+
+# Horovod: save checkpoints only on worker 0 to prevent other workers from
+# corrupting it.
+if hvd.rank() == 0:
+    # Export the keras model as Tensorflow SavedModelBundle
+    mnist_model.save(
+        os.path.join('/opt/ml/model/mnist/1'),
+        save_format='tf')
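The rewritten script follows Horovod's TF2 MNIST example: each rank trains the model through a `hvd.DistributedGradientTape`, initial variables are broadcast from rank 0 after the first optimizer step, and rank 0 exports a SavedModel to `/opt/ml/model/mnist/1`. It is typically launched with something like `horovodrun -np 2 python horovod_mnist.py`. As a rough sketch (assuming TF 2.x is installed and a training run has already written the SavedModel), the export can be checked by loading it back and running one inference:

```python
# Rough sketch (not part of this commit): load the exported SavedModel and
# run a single prediction to confirm the export is servable.
import numpy as np
import tensorflow as tf

model = tf.keras.models.load_model('/opt/ml/model/mnist/1')

# One dummy 28x28 grayscale image, shaped like the training batches (N, 28, 28, 1).
image = np.zeros((1, 28, 28, 1), dtype='float32')
probs = model.predict(image)
print('Predicted digit:', int(probs.argmax(axis=-1)[0]))
```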
