IBM Large Model Support (LMS)

LMS seamlessly moves layers of a model between the GPU and CPU to overcome GPU memory limits. This allows training of:

  • Deeper models

  • Higher resolution data

  • Larger batch sizes

Satori nodes have a fast NVLink 2.0 connection between the CPU and GPU, so data can be swapped with far less overhead than on traditional x86 GPU-accelerated systems, where the CPU and GPU are connected over PCIe Gen3.

# Example: enabling TFLMSv2 in TensorFlow
# -----------------------------------------
# Import the TF LMS module
from tensorflow_large_model_support import LMS
# Instantiate the LMS object, with maximum swapping parameters
# To use the auto-tuning feature instead, simply
# initialise the LMS object without any arguments
# e.g. lms_hook = LMS()
lms_hook = LMS(swapout_threshold=1)
# Make LMS aware of the train_batch_size parameter
lms_hook.batch_size = FLAGS.train_batch_size
# Include the lms_hook object in the estimator hooks list

NOTE: TFLMSv2 introduces four hyper-parameters. Typically you don't need to worry about them, because LMS provides an auto-tuning feature that evaluates your computational graph and sets appropriate values for these hyper-parameters, based on estimated memory consumption throughout training. However, manual tuning gives closer control and can squeeze out maximum performance. The four hyper-parameters are:

  • swapout_threshold: The number of tensors to hold within GPU memory before pushing them to system memory.

  • swapin_ahead: The larger swapin_ahead is, the earlier a tensor is swapped in to the GPU memory from the host memory.

  • swapin_groupby: Multiple swap-in operations of the same tensor will be grouped or fused into one swap-in operation for better performance if they are close to each other (the distance between them is within swapin_groupby).

  • sync_mode: Whether to do synchronisation between data transfer and kernel computation or not.
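The grouping rule behind swapin_groupby can be pictured with a small toy model. The sketch below is not the TFLMS implementation: the function name group_swap_ins is hypothetical, consumer positions are assumed to be indices in the graph's execution order, and distance is assumed to be measured from the previous consumer in the group.

```python
def group_swap_ins(consumer_positions, swapin_groupby):
    """Group the consumers of one swapped-out tensor (toy model).

    Each inner list is a set of consumers that would share a single
    swap-in operation. A consumer joins the current group when it is at
    most `swapin_groupby` steps after the previous consumer; otherwise
    it starts a new group (and needs its own swap-in).
    """
    groups = []
    for pos in sorted(consumer_positions):
        if groups and pos - groups[-1][-1] <= swapin_groupby:
            groups[-1].append(pos)
        else:
            groups.append([pos])
    return groups


# Hypothetical execution-order positions at which one tensor is consumed.
positions = [10, 12, 13, 40, 41, 90]

# With swapin_groupby=0 nothing is fused: six separate swap-ins.
print(group_swap_ins(positions, swapin_groupby=0))

# With swapin_groupby=5 nearby consumers are fused: three swap-ins.
print(group_swap_ins(positions, swapin_groupby=5))
```

In this toy model a larger swapin_groupby means fewer transfers, but the tensor stays resident on the GPU for longer between the first and last consumer in a group, which is the trade-off manual tuning has to balance.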

Installing PyTorch Large Model Support

LMS is built into the pytorch conda package, so it is installed by default when you install the GPU-enabled PyTorch from WML CE. The support is currently available in the WML CE conda channel. For the best LMS performance, use PyTorch 1.4 from the Watson ML Community Edition Early Access channel (pytorch 1.4.0==23447.g18a1a27).

How to enable LMS in PyTorch 1.4?

The LMS functionality is disabled by default in PyTorch and needs to be enabled before your model creates tensors. Enabling LMS is as simple as calling the enablement API at the start of your program:

import torch
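As a minimal sketch of the enablement step described above: the call torch.cuda.set_enabled_lms(True) is assumed here from IBM's LMS-enabled PyTorch builds and is not present in stock PyTorch, so the hypothetical helper enable_lms checks for it before calling. Treat it as an illustration and confirm the exact API against the WML CE documentation for your installed version.

```python
def enable_lms(cuda_module):
    """Enable LMS if this PyTorch build exposes the switch.

    `cuda_module` is expected to be `torch.cuda`. Returns True when LMS
    was switched on, False when the build has no LMS support.
    """
    # Assumption: LMS-enabled builds provide torch.cuda.set_enabled_lms.
    setter = getattr(cuda_module, "set_enabled_lms", None)
    if setter is None:
        return False  # e.g. a stock PyTorch build without LMS
    setter(True)
    return True


try:
    import torch
except ImportError:
    torch = None  # PyTorch is not installed in this environment

if torch is not None:
    # Call this before the model creates any tensors.
    print("LMS enabled:", enable_lms(torch.cuda))
```

The guard makes the same script usable on machines without the WML CE build, where it simply reports that LMS is unavailable instead of raising an AttributeError.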

Additional Documentation and Tutorial: