Data Parallel Training with KerasNLP and tf.distribute
Introduction
Distributed training is a technique used to train deep learning models on multiple devices or
machines simultaneously. It helps to reduce training time and allows for training larger models with
more data. KerasNLP is a library that provides tools and utilities for natural language processing
tasks, including distributed training.
In this tutorial, we will use KerasNLP to train a BERT-based masked language model (MLM) on the
wikitext-2 dataset (a 2-million-word dataset of Wikipedia articles). The MLM task involves predicting
the masked words in a sentence, which helps the model learn contextual representations of words.
This guide focuses on data parallelism, in particular synchronous data parallelism, where each
accelerator (a GPU or TPU) holds a complete replica of the model and sees a different partial batch
of the input data. Partial gradients are computed on each device, aggregated, and used to compute
a global gradient update.
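The aggregation step above can be illustrated with a toy example in plain Python. This is an illustrative sketch of the synchronous data-parallel update, not the guide's code: each "replica" computes a partial gradient on its shard of the global batch, the partial gradients are averaged (the all-reduce), and one global update is applied.

```python
# Simulate synchronous data parallelism on a toy linear model y = w * x.
# Each "replica" sees a different slice of the global batch, computes a
# partial gradient, and the gradients are averaged before one global update.

def grad_mse(w, xs, ys):
    # d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def sync_data_parallel_step(w, batch_x, batch_y, num_replicas, lr):
    shard = len(batch_x) // num_replicas
    partial_grads = []
    for r in range(num_replicas):
        xs = batch_x[r * shard:(r + 1) * shard]
        ys = batch_y[r * shard:(r + 1) * shard]
        partial_grads.append(grad_mse(w, xs, ys))
    # All-reduce: average the per-replica gradients into one global gradient.
    global_grad = sum(partial_grads) / num_replicas
    return w - lr * global_grad

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # true w = 2
w_parallel = sync_data_parallel_step(0.0, xs, ys, num_replicas=2, lr=0.01)
w_single = 0.0 - 0.01 * grad_mse(0.0, xs, ys)
```

With equal shard sizes, the averaged gradient equals the single-device gradient over the full batch, so the parallel and single-device paths produce the same weight update.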
Specifically, this guide teaches you how to use the tf.distribute API to train Keras models on
multiple GPUs with minimal changes to your code, focusing on single-host, multi-device synchronous
training; strategies for training across multiple machines are touched on at the end.
Imports
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import tensorflow as tf
import keras
import keras_nlp
Before we start any training, let's configure our single GPU to show up as two logical devices.
When you are training with two or more physical GPUs, this is totally unnecessary. This is just a trick
to demonstrate real distributed training on the default Colab GPU runtime, which has only one GPU
available.
!nvidia-smi --query-gpu=memory.total --format=csv,noheader

24576 MiB

logical_devices = tf.config.list_logical_devices("GPU")
logical_devices

EPOCHS = 3
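The step that actually splits the physical GPU into two logical devices is not shown above. A sketch of how it can be done with tf.config.set_logical_device_configuration follows; the memory_limit values are an assumption (roughly half of the reported 24576 MiB each), and this must run before TensorFlow initializes the GPU:

```python
import tensorflow as tf

# Split the single physical GPU into two logical devices so that
# MirroredStrategy later sees two replicas. The memory_limit values are an
# assumption; adjust them for your GPU's actual memory.
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.set_logical_device_configuration(
    physical_devices[0],
    [
        tf.config.LogicalDeviceConfiguration(memory_limit=10240),
        tf.config.LogicalDeviceConfiguration(memory_limit=10240),
    ],
)
```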
To do single-host, multi-device synchronous training with a Keras model, you would use the
tf.distribute.MirroredStrategy API. Here's how it works:
Instantiate a MirroredStrategy, optionally configuring which specific devices you want to use
(by default the strategy will use all GPUs available).
Use the strategy object to open a scope, and within this scope, create all the Keras objects you
need that contain variables. Typically, that means creating & compiling the model inside the
distribution scope.
Train the model via fit() as usual.
strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")

base_batch_size = 32
base_learning_rate = 1e-4

# Scale the batch size and learning rate linearly with the number of replicas,
# since each replica sees its own slice of the global batch.
scaled_batch_size = base_batch_size * strategy.num_replicas_in_sync
scaled_learning_rate = base_learning_rate * strategy.num_replicas_in_sync
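In synchronous data parallelism, the effective (global) batch size is the per-replica batch size times the number of replicas, and the learning rate is commonly scaled by the same factor (the linear scaling rule). A toy sketch with the guide's base values; scale_for_replicas is a hypothetical helper, not part of the tutorial's code:

```python
# Hypothetical helper illustrating the linear scaling rule: the global batch
# grows with the replica count, and the learning rate is scaled by the same
# factor to keep the per-example step size roughly constant.
def scale_for_replicas(base_batch_size, base_learning_rate, num_replicas):
    return base_batch_size * num_replicas, base_learning_rate * num_replicas

batch, lr = scale_for_replicas(32, 1e-4, 2)  # two logical devices
```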
Now, we need to download and preprocess the wikitext-2 dataset. This dataset will be used for
pretraining the BERT model. We will filter out short lines to ensure that the data has enough context
for training.
keras.utils.get_file(
    origin="https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip",
    extract=True,
)
wiki_dir = os.path.expanduser("~/.keras/datasets/wikitext-2/")

# Load wikitext-2 and filter out short lines.
wiki_train_ds = (
    tf.data.TextLineDataset(
        wiki_dir + "wiki.train.tokens",
    )
    .filter(lambda x: tf.strings.length(x) > 100)
    .shuffle(buffer_size=500)
    .batch(scaled_batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
wiki_val_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.valid.tokens")
    .filter(lambda x: tf.strings.length(x) > 100)
    .shuffle(buffer_size=500)
    .batch(scaled_batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
wiki_test_ds = (
    tf.data.TextLineDataset(wiki_dir + "wiki.test.tokens")
    .filter(lambda x: tf.strings.length(x) > 100)
    .shuffle(buffer_size=500)
    .batch(scaled_batch_size)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)
In the above code, we download the wikitext-2 dataset and extract it. Then, we define three
datasets: wiki_train_ds, wiki_val_ds, and wiki_test_ds. These datasets are filtered to remove short
lines and are batched for efficient training.
It's a common practice to use a decayed learning rate in NLP training/tuning. We'll use a
PolynomialDecay schedule here.
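PolynomialDecay ramps the learning rate from an initial value down to an end value over a fixed number of steps. A pure-Python sketch of the curve it produces, assuming the schedule's default power=1.0 (a linear ramp); polynomial_decay here is an illustrative helper, not the Keras API:

```python
# Sketch of the curve tf.keras.optimizers.schedules.PolynomialDecay produces:
# decay from initial_lr to end_lr over decay_steps, then hold end_lr.
def polynomial_decay(step, initial_lr, decay_steps, end_lr=0.0, power=1.0):
    step = min(step, decay_steps)  # the schedule holds end_lr past decay_steps
    fraction = 1.0 - step / decay_steps
    return (initial_lr - end_lr) * (fraction ** power) + end_lr

# Learning rate at the start, midpoint, end, and past the end of decay.
lrs = [polynomial_decay(s, initial_lr=2e-4, decay_steps=100) for s in (0, 50, 100, 200)]
```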
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(
            f"\nLearning rate for epoch {epoch + 1} is "
            f"{model_dist.optimizer.learning_rate.numpy()}"
        )
Let's also make a callback for TensorBoard; this will enable visualization of different metrics while
we train the model in a later part of this tutorial. We put all the callbacks together as follows:
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir="./logs"),
    PrintLR(),
]
print(tf.config.list_physical_devices("GPU"))
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
With the datasets prepared, we now initialize and compile our model and optimizer within the
strategy.scope():
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model_dist = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased")

    # This line just sets the pooled_dense layer as non-trainable; we do this
    # to avoid warnings of this layer being unused.
    model_dist.get_layer("bert_backbone").get_layer("pooled_dense").trainable = False

    model_dist.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.AdamW(learning_rate=scaled_learning_rate),
        weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
        jit_compile=False,
    )
model_dist.fit(
    wiki_train_ds, validation_data=wiki_val_ds, epochs=EPOCHS, callbacks=callbacks
)
Epoch 1/3
Learning rate for epoch 1 is 0.00019999999494757503
239/239 ━━━━━━━━━━━━━━━━━━━━ 43s 136ms/step - loss: 3.7009 -
sparse_categorical_accuracy: 0.1499 - val_loss: 1.1509 -
val_sparse_categorical_accuracy: 0.3485
Epoch 2/3
239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 122ms/step - loss: 2.6094 -
sparse_categorical_accuracy: 0.5284
Learning rate for epoch 2 is 0.00019999999494757503
239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 133ms/step - loss: 2.6038 -
sparse_categorical_accuracy: 0.5274 - val_loss: 0.9812 -
val_sparse_categorical_accuracy: 0.4006
Epoch 3/3
239/239 ━━━━━━━━━━━━━━━━━━━━ 0s 123ms/step - loss: 2.3564 -
sparse_categorical_accuracy: 0.6053
Learning rate for epoch 3 is 0.00019999999494757503
239/239 ━━━━━━━━━━━━━━━━━━━━ 32s 134ms/step - loss: 2.3514 -
sparse_categorical_accuracy: 0.6040 - val_loss: 0.9213 -
val_sparse_categorical_accuracy: 0.4230
model_dist.evaluate(wiki_test_ds)
[0.9470901489257812, 0.4373602867126465]
For distributed training across multiple machines (as opposed to training that only leverages
multiple devices on a single machine), there are two distribution strategies you could use:
MultiWorkerMirroredStrategy and ParameterServerStrategy.
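As a sketch only (not part of this tutorial's runnable code), multi-worker training with MultiWorkerMirroredStrategy is typically wired up through a TF_CONFIG environment variable; the hostnames, ports, and task index below are placeholders, and each worker runs the same script with its own index:

```python
import json
import os

# Placeholder cluster description: two workers. Each machine runs this same
# script, differing only in its "index" (0 on the first worker, 1 on the second).
os.environ["TF_CONFIG"] = json.dumps(
    {
        "cluster": {"worker": ["host1:12345", "host2:12345"]},
        "task": {"type": "worker", "index": 0},
    }
)

import tensorflow as tf

# Must be created early, before other TensorFlow ops.
strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    ...  # build & compile the model, exactly as in the single-host case
```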
Further reading
1. TensorFlow distributed training guide
2. Tutorial on multi-worker training with Keras
3. MirroredStrategy docs
4. MultiWorkerMirroredStrategy docs
5. Distributed training in tf.keras with Weights & Biases