There are generally two ways to distribute computation across multiple devices:
Data parallelism, where a single model gets replicated on multiple devices or multiple machines. Each replica processes different batches of data, then the replicas merge their results. There are many variants of this setup, which differ in how the model replicas merge results, in whether they stay in sync at every batch or are more loosely coupled, etc.
Model parallelism, where different parts of a single model run on different devices, processing a single batch of data together. This works best with models that have a naturally-parallel architecture, such as models that feature multiple branches.
This guide focuses on data parallelism, in particular synchronous data parallelism, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the model convergence behavior identical to what you would see for single-device training.
Specifically, this guide teaches you how to use the
tf.distribute API to train Keras models on multiple GPUs
installed on a single machine (single host, multi-device training),
with minimal changes to your code. This
is the most common setup for researchers and small-scale industry
workflows.
In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each device will run a copy of your model (called a replica). For simplicity, in what follows, we’ll assume we’re dealing with 8 GPUs, without loss of generality.
How it works
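At each step of training:
The current batch of data (the global batch) is split into 8 sub-batches (the local batches). For instance, a global batch of 512 samples yields 8 local batches of 64 samples each.
Each of the 8 replicas independently processes one local batch: it runs a forward pass, then a backward pass, producing the gradient of the loss on its local batch with respect to the weights.
The weight updates derived from the local gradients are merged across the 8 replicas. Because this happens at the end of every step, the replicas always stay in sync.
In practice, the process of synchronously updating the weights of the model replicas is handled at the level of each individual weight variable. This is done through a mirrored variable object.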
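The synchronous update can be illustrated with a toy sketch in base R, with no TensorFlow involved. It assumes a linear model with a mean-squared-error loss, 8 replicas, and equally sized local batches; the names `mse_gradient`, `replica_id`, and `merged_gradient` are hypothetical and chosen for this illustration only. It shows that averaging the local gradients reproduces the gradient of the full global batch, which is why synchronous training behaves like single-device training. The real strategy merges gradients with an all-reduce across devices, which this sketch does not model.

```r
# Toy illustration in base R (no TensorFlow): one synchronous
# data-parallel step for a linear model with mean-squared-error loss.
set.seed(42)

num_replicas <- 8          # assumed number of GPUs
global_batch_size <- 512   # 64 samples per replica

# One global batch of synthetic data and the current (shared) weights.
x <- matrix(rnorm(global_batch_size * 3), ncol = 3)
y <- rnorm(global_batch_size)
w <- c(0.5, -0.2, 0.1)

# Gradient of mean((x %*% w - y)^2) with respect to w.
mse_gradient <- function(x, y, w) {
  residual <- as.vector(x %*% w) - y
  as.vector(2 * crossprod(x, residual) / nrow(x))
}

# Split the global batch into equally sized local batches, one per replica.
replica_id <- rep(seq_len(num_replicas),
                  each = global_batch_size / num_replicas)

# Each replica computes the gradient on its own local batch.
local_gradients <- lapply(seq_len(num_replicas), function(i) {
  rows <- replica_id == i
  mse_gradient(x[rows, , drop = FALSE], y[rows], w)
})

# Merge step: average the local gradients across replicas.
merged_gradient <- Reduce(`+`, local_gradients) / num_replicas

# Every replica applies the same update, so all copies of w stay identical.
learning_rate <- 0.01
w_new <- w - learning_rate * merged_gradient

# The merged gradient matches the gradient computed on the whole batch.
full_gradient <- mse_gradient(x, y, w)
stopifnot(isTRUE(all.equal(merged_gradient, full_gradient)))
```

The equality holds exactly only because the local batches have the same size; with uneven shards, a weighted average would be needed.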
How to use it
To do single-host, multi-device synchronous training with a Keras
model, you would use the tf$distribute$MirroredStrategy
API. Here’s how it works:
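Instantiate a MirroredStrategy, optionally configuring which specific devices you want to use (by default the strategy will use all GPUs available).
Use the strategy object to open a scope, and within this scope, create all the Keras objects that contain variables. Typically, that means creating and compiling the model inside the scope. In some cases, the first call to fit() may also create variables, so it’s a good idea to put your fit() call in the scope as well.
Train the model via fit() as usual.
Importantly, we recommend that you use tf.data.Dataset
objects to load data in a multi-device or distributed workflow.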
Schematically, it looks like this:
# Create a MirroredStrategy.
strategy <- tf$distribute$MirroredStrategy()
cat(sprintf('Number of devices: %d\n', strategy$num_replicas_in_sync))
# Open a strategy scope.
with(strategy$scope(), {
  # Everything that creates variables should be under the strategy scope.
  # In general this is only model construction & `compile()`.
  model <- Model(...)
  model |> compile(...)
  # Train the model on all available devices.
  model |> fit(train_dataset, validation_data=val_dataset, ...)
  # Test the model on all available devices.
  model |> evaluate(test_dataset)
})
Here’s a simple end-to-end runnable example:
get_compiled_model <- function() {
  inputs <- keras_input(shape = 784)
  outputs <- inputs |>
    layer_dense(units = 256, activation = "relu") |>
    layer_dense(units = 256, activation = "relu") |>
    layer_dense(units = 10)
  model <- keras_model(inputs, outputs)
  model |> compile(
    optimizer = optimizer_adam(),
    loss = loss_sparse_categorical_crossentropy(from_logits = TRUE),
    metrics = list(metric_sparse_categorical_accuracy()),
    # XLA compilation is temporarily disabled due to a bug
    # https://github.com/keras-team/keras/issues/19005
    jit_compile = FALSE
  )
  model
}
get_dataset <- function(batch_size = 64) {
  c(c(x_train, y_train), c(x_test, y_test)) %<-% dataset_mnist()
  x_train <- array_reshape(x_train, c(-1, 784))
  x_test <- array_reshape(x_test, c(-1, 784))
  # Reserve 10,000 samples for validation.
  val_i <- sample.int(nrow(x_train), 10000)
  x_val <- x_train[val_i,]
  y_val <- y_train[val_i]
  x_train <- x_train[-val_i,]
  y_train <- y_train[-val_i]
  # Prepare the training dataset.
  train_dataset <- list(x_train, y_train) |>
    tensor_slices_dataset() |>
    dataset_batch(batch_size)
  # Prepare the validation dataset.
  val_dataset <- list(x_val, y_val) |>
    tensor_slices_dataset() |>
    dataset_batch(batch_size)
  # Prepare the test dataset.
  test_dataset <- list(x_test, y_test) |>
    tensor_slices_dataset() |>
    dataset_batch(batch_size)
  list(train_dataset, val_dataset, test_dataset)
}
# Create a MirroredStrategy.
strategy <- tf$distribute$MirroredStrategy()
cat(sprintf('Number of devices: %d\n', strategy$num_replicas_in_sync))
## Number of devices: 2
# Open a strategy scope.
with(strategy$scope(), {
  # Everything that creates variables should be under the strategy scope.
  # In general this is only model construction & `compile()`.
  model <- get_compiled_model()
  c(train_dataset, val_dataset, test_dataset) %<-% get_dataset()
  # Train the model on all available devices.
  model |> fit(train_dataset, epochs = 2, validation_data = val_dataset)
  # Test the model on all available devices.
  model |> evaluate(test_dataset)
})
## Epoch 1/2
## 782/782 - 5s - 7ms/step - loss: 3.0622 - sparse_categorical_accuracy: 0.8615 - val_loss: 1.1367 - val_sparse_categorical_accuracy: 0.9006
## Epoch 2/2
## 782/782 - 3s - 4ms/step - loss: 0.5774 - sparse_categorical_accuracy: 0.9259 - val_loss: 0.6612 - val_sparse_categorical_accuracy: 0.9210
## 157/157 - 0s - 2ms/step - loss: 0.6729 - sparse_categorical_accuracy: 0.9150
## $loss
## [1] 0.6728871
##
## $sparse_categorical_accuracy
## [1] 0.915
When using distributed training, you should always make sure you have
a strategy to recover from failure (fault tolerance). The simplest way
to handle this is to pass a ModelCheckpoint callback to
fit(), to save your model at regular intervals (e.g. every
100 batches or every epoch). You can then restart training from your
saved model.
Here’s a simple example:
# Prepare a directory to store all the checkpoints.
checkpoint_dir <- "./ckpt"
if (!dir.exists(checkpoint_dir)) {
  dir.create(checkpoint_dir)
}
make_or_restore_model <- function() {
  # Either restore the latest model, or create a fresh one
  # if there is no checkpoint available.
  # Only consider files written by the checkpoint callback below.
  checkpoints <- list.files(checkpoint_dir, pattern = "^ckpt-[0-9]+\\.keras$",
                            full.names = TRUE)
  if (length(checkpoints) > 0) {
    checkpoint_epochs <- as.integer(sub("ckpt-([0-9]+)\\.keras", "\\1",
                                        basename(checkpoints)))
    latest_checkpoint <- checkpoints[which.max(checkpoint_epochs)]
    load_model(latest_checkpoint)
  } else {
    get_compiled_model()
  }
}
run_training <- function(epochs = 1) {
  # Create a MirroredStrategy.
  strategy <- tf$distribute$MirroredStrategy()
  # Open a strategy scope and create/restore the model
  with(strategy$scope(), {
    model <- make_or_restore_model()
    callbacks <- list(
      # This callback saves the model to a `.keras` file every epoch.
      # We include the current epoch in the file name.
      callback_model_checkpoint(
        filepath = paste0(checkpoint_dir, "/ckpt-{epoch}.keras"),
        save_freq = "epoch"
      ))
    model |> fit(
      train_dataset,
      epochs = epochs,
      callbacks = callbacks,
      validation_data = val_dataset,
      verbose = 2
    )
  })
}
# Running the first time creates the model
run_training(epochs = 1)
## 782/782 - 4s - 6ms/step - loss: 2.9519 - sparse_categorical_accuracy: 0.8655 - val_loss: 1.3110 - val_sparse_categorical_accuracy: 0.8836
# Calling the same function again resumes from the latest checkpoint
run_training(epochs = 1)
## 782/782 - 3s - 4ms/step - loss: 0.5998 - sparse_categorical_accuracy: 0.9270 - val_loss: 0.8736 - val_sparse_categorical_accuracy: 0.9128
tf$data performance tips
When doing distributed training, the efficiency with which you load
data can often become critical. Here are a few tips to make sure your
tf$data pipelines run as fast as possible.
Note about dataset batching
When creating your dataset, make sure it is batched with the global batch size. For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you can use a global batch size of 512.
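As a small worked example of that arithmetic, the sketch below derives the global batch size from a per-replica batch size. The variable names are hypothetical, and the replica count is hard-coded to 8 so the snippet runs without TensorFlow; with a real strategy object you would read it from strategy$num_replicas_in_sync instead.

```r
# Per-replica batch size that fits on one GPU (assumed value).
per_replica_batch_size <- 64L

# Hard-coded here for illustration; with a real strategy, use
# strategy$num_replicas_in_sync.
num_replicas <- 8L

# The dataset should be batched with this value.
global_batch_size <- per_replica_batch_size * num_replicas
global_batch_size
```

With these values, the global batch size is 512, and this is the number to pass to dataset_batch().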
Calling dataset_cache()
If you call dataset_cache() on a dataset, its data will
be cached during the first iteration over the data. Every
subsequent iteration will use the cached data. The cache can be in
memory (default) or in a local file you specify.
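This can improve performance when:
Your data is not expected to change from iteration to iteration.
You are reading data from a remote distributed filesystem.
You are reading data from local disk, but your data would fit in memory and your workflow is significantly IO-bound (e.g. reading and decoding image files).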
Calling dataset_prefetch(buffer_size)
You should almost always call
dataset_prefetch(buffer_size) after creating a dataset. It
means your data pipeline will run asynchronously from your model, with
new samples being preprocessed and stored in a buffer while the current
batch samples are used to train the model. The next batch will be
prefetched in GPU memory by the time the current batch is over.
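The batching, caching, and prefetching tips can be combined into one input pipeline. The following is a minimal sketch, assuming the tensorflow and tfdatasets packages are installed; the synthetic x and y arrays are hypothetical stand-ins for real training data, and tf$data$AUTOTUNE is used so that TensorFlow picks the prefetch buffer size itself.

```r
library(tensorflow)
library(tfdatasets)

# Hypothetical in-memory data standing in for a real training set.
x <- matrix(runif(1000 * 4), nrow = 1000)
y <- sample(0:1, 1000, replace = TRUE)

# Global batch size: 64 samples per replica on 8 replicas (assumed setup).
global_batch_size <- 64L * 8L

train_dataset <- list(x, y) |>
  tensor_slices_dataset() |>
  # Cache in memory after the first pass over the data.
  dataset_cache() |>
  # Batch with the global batch size, not the per-replica size.
  dataset_batch(global_batch_size) |>
  # Prepare upcoming batches while the current one is being used.
  dataset_prefetch(buffer_size = tf$data$AUTOTUNE)
```

Placing dataset_prefetch() last lets the whole upstream pipeline run ahead of the training step; whether caching helps depends on the conditions listed for dataset_cache().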
That’s it!