 # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
-from __future__ import absolute_import, print_function
-
 import os
-import subprocess
-
-import keras
-from keras.datasets import mnist
-from keras.models import Sequential
-from keras.layers import Dense, Dropout, Flatten
-from keras.layers import Conv2D, MaxPooling2D
-from keras import backend as K
 import tensorflow as tf
-import horovod.keras as hvd
-
+import horovod.tensorflow as hvd
 
 # Horovod: initialize Horovod.
 hvd.init()
 
 # Horovod: pin GPU to be used to process local rank (one GPU per process)
-config = tf.compat.v1.ConfigProto()
-config.gpu_options.allow_growth = True
-config.gpu_options.visible_device_list = str(hvd.local_rank())
-K.set_session(tf.compat.v1.Session(config=config))
-
-batch_size = 128
-num_classes = 10
-
-epochs = 1
-
-# Input image dimensions
-img_rows, img_cols = 28, 28
-
-# The data, shuffled and split between train and test sets
-(x_train, y_train), (x_test, y_test) = mnist.load_data()
-
-x_train = x_train[:600]
-y_train = y_train[:600]
-x_test = x_test[:100]
-y_test = y_test[:100]
-
-if K.image_data_format() == 'channels_first':
-    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
-    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
-    input_shape = (1, img_rows, img_cols)
-else:
-    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
-    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
-    input_shape = (img_rows, img_cols, 1)
-
-x_train = x_train.astype('float32')
-x_test = x_test.astype('float32')
-x_train /= 255
-x_test /= 255
-print('x_train shape:', x_train.shape)
-print(x_train.shape[0], 'train samples')
-print(x_test.shape[0], 'test samples')
-
-# Convert class vectors to binary class matrices
-y_train = keras.utils.to_categorical(y_train, num_classes)
-y_test = keras.utils.to_categorical(y_test, num_classes)
-
-model = Sequential()
-model.add(Conv2D(32, kernel_size=(3, 3),
-                 activation='relu',
-                 input_shape=input_shape))
-model.add(Conv2D(64, (3, 3), activation='relu'))
-model.add(MaxPooling2D(pool_size=(2, 2)))
-model.add(Dropout(0.25))
-model.add(Flatten())
-model.add(Dense(128, activation='relu'))
-model.add(Dropout(0.5))
-model.add(Dense(num_classes, activation='softmax'))
+gpus = tf.config.experimental.list_physical_devices('GPU')
+for gpu in gpus:
+    tf.config.experimental.set_memory_growth(gpu, True)
+if gpus:
+    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
+
+(mnist_images, mnist_labels), _ = \
+    tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())
+
+dataset = tf.data.Dataset.from_tensor_slices(
+    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
+     tf.cast(mnist_labels, tf.int64))
+)
+dataset = dataset.repeat().shuffle(10000).batch(128)
+
+mnist_model = tf.keras.Sequential([
+    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
+    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
+    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
+    tf.keras.layers.Dropout(0.25),
+    tf.keras.layers.Flatten(),
+    tf.keras.layers.Dense(128, activation='relu'),
+    tf.keras.layers.Dropout(0.5),
+    tf.keras.layers.Dense(10, activation='softmax')
+])
+loss = tf.losses.SparseCategoricalCrossentropy()
 
 # Horovod: adjust learning rate based on number of GPUs.
-opt = keras.optimizers.Adadelta(1.0 * hvd.size())
+opt = tf.optimizers.Adam(0.001 * hvd.size())
 
-# Horovod: add Horovod Distributed Optimizer.
-opt = hvd.DistributedOptimizer(opt)
+checkpoint_dir = './checkpoints'
+checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)
 
-model.compile(loss=keras.losses.categorical_crossentropy,
-              optimizer=opt,
-              metrics=['accuracy'])
 
-callbacks = [
-    # Horovod: broadcast initial variable states from rank 0 to all other processes.
-    # This is necessary to ensure consistent initialization of all workers when
-    # training is started with random weights or restored from a checkpoint.
-    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
-]
+@tf.function
+def training_step(images, labels, first_batch):
+    with tf.GradientTape() as tape:
+        probs = mnist_model(images, training=True)
+        loss_value = loss(labels, probs)
 
-# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
-if hvd.rank() == 0:
-    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))
+    # Horovod: add Horovod Distributed GradientTape.
+    tape = hvd.DistributedGradientTape(tape)
 
-model.fit(x_train, y_train,
-          batch_size=batch_size,
-          callbacks=callbacks,
-          epochs=epochs,
-          verbose=1,
-          validation_data=(x_test, y_test))
-score = model.evaluate(x_test, y_test, verbose=0)
-print('Test loss:', score[0])
-print('Test accuracy:', score[1])
+    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
+    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
 
+    # Horovod: broadcast initial variable states from rank 0 to all other processes.
+    # This is necessary to ensure consistent initialization of all workers when
+    # training is started with random weights or restored from a checkpoint.
+    #
+    # Note: broadcast should be done after the first gradient step to ensure optimizer
+    # initialization.
+    if first_batch:
+        hvd.broadcast_variables(mnist_model.variables, root_rank=0)
+        hvd.broadcast_variables(opt.variables(), root_rank=0)
 
-if hvd.rank() == 0:
-    # Exports the keras model as TensorFlow Serving Saved Model
-    with K.get_session() as session:
+    return loss_value
 
-        init = tf.compat.v1.global_variables_initializer()
-        session.run(init)
 
-        tf.compat.v1.saved_model.simple_save(
-            session,
-            os.path.join('/opt/ml/model/mnist/1'),
-            inputs={'input_image': model.input},
-            outputs={t.name: t for t in model.outputs})
+# Horovod: adjust number of steps based on number of GPUs.
+for batch, (images, labels) in enumerate(dataset.take(600 // hvd.size())):
+    loss_value = training_step(images, labels, batch == 0)
+
+    if batch % 10 == 0 and hvd.local_rank() == 0:
+        print('Step #%d\tLoss: %.6f' % (batch, loss_value))
+
+# Horovod: save checkpoints only on worker 0 to prevent other workers from
+# corrupting them.
+if hvd.rank() == 0:
+    # Export the Keras model as a TensorFlow SavedModel
+    mnist_model.save(
+        os.path.join('/opt/ml/model/mnist/1'),
+        save_format='tf')
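
A script converted this way is launched with one training process per GPU, e.g. `horovodrun -np 4 python train.py` (the filename here is hypothetical). The new code also builds a `tf.train.Checkpoint` around the model and optimizer; a minimal sketch of writing it during or after training, assuming the `checkpoint` and `checkpoint_dir` objects defined above (the `'ckpt'` file prefix is an invented choice):

# Sketch, not part of the diff: write a checkpoint from worker 0 only,
# mirroring the rank-0 guard used for the SavedModel export.
if hvd.rank() == 0:
    checkpoint.save(os.path.join(checkpoint_dir, 'ckpt'))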
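Once worker 0 has exported the model, it can be reloaded for single-process inference through the standard Keras API. A minimal sketch, assuming the export path used above and 28x28 grayscale MNIST inputs:

import numpy as np
import tensorflow as tf

# Load the SavedModel written by worker 0 during training.
model = tf.keras.models.load_model('/opt/ml/model/mnist/1')

# Predict on a dummy batch of one 28x28x1 image; the model outputs
# softmax probabilities over the 10 digit classes.
sample = np.zeros((1, 28, 28, 1), dtype=np.float32)
probs = model.predict(sample)
print('Predicted digit:', probs.argmax(axis=-1)[0])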