
Refine data during training with Amazon SageMaker smart sifting

SageMaker smart sifting is a capability of SageMaker Training that helps improve the efficiency of your training datasets and reduce total training time and cost.

Modern deep learning models such as large language models (LLMs) or vision transformer models often require massive datasets to achieve acceptable accuracy. For example, LLMs often require trillions of tokens or petabytes of data to converge. The growing size of training datasets, along with the size of state-of-the-art models, can increase the compute time and cost of model training.

Not all samples in a dataset contribute equally to the learning process during model training. A significant proportion of the computational resources provisioned for training can be spent processing easy samples that contribute little to the model's overall accuracy. Ideally, a training dataset would include only the samples that actually improve model convergence, because filtering out less helpful data reduces training time and compute cost. However, identifying less helpful data is challenging and risky: it is difficult to know before training which samples are less informative, and model accuracy can suffer if the wrong samples, or too many samples, are excluded.

Smart sifting of data with Amazon SageMaker can help reduce training time and cost by improving data efficiency. The SageMaker smart sifting algorithm evaluates the loss value of each data sample during the data loading stage of a training job and excludes samples that are less informative to the model. Training on the refined data eliminates unnecessary forward and backward passes on non-improving samples, which reduces the total time and cost of training your model. Because only less informative samples are excluded, there is minimal or no impact on the accuracy of the model.

SageMaker smart sifting is available through SageMaker Training Deep Learning Containers (DLCs) and supports PyTorch workloads through the PyTorch DataLoader. Only a few lines of code change are needed to implement SageMaker smart sifting, and you do not need to change your existing training or data processing workflows.
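To give a sense of the scale of change involved, the following is a minimal sketch of adopting smart sifting in an existing PyTorch script. The module paths, class names, and parameters below follow the pattern of the smart_sifting package shipped in SageMaker Training DLCs, but treat the exact signatures and values as illustrative assumptions rather than a definitive reference; model and train_dataset stand in for objects your script already defines.

```python
# Minimal sketch of enabling SageMaker smart sifting in an existing PyTorch
# script. Names and signatures are illustrative assumptions, not a reference.
import torch
from torch.utils.data import DataLoader

from smart_sifting.dataloader.sift_dataloader import SiftingDataloader
from smart_sifting.loss.abstract_sift_loss_module import Loss
from smart_sifting.sift_config.sift_configs import (
    RelativeProbabilisticSiftConfig,
    LossConfig,
    SiftingBaseConfig,
)

class SiftingLoss(Loss):
    """Wraps the training loss so the sifter can score individual samples."""
    def __init__(self):
        # reduction="none" yields one loss value per sample
        self.ce = torch.nn.CrossEntropyLoss(reduction="none")

    def loss(self, model, transformed_batch, original_batch=None):
        inputs, labels = transformed_batch
        return self.ce(model(inputs), labels)

# Sifting behavior: how much data to keep, how much loss history to compare
# against, and how many steps to wait before sifting begins.
sift_config = RelativeProbabilisticSiftConfig(
    beta_value=3,
    loss_history_length=500,
    loss_based_sift_config=LossConfig(
        sift_config=SiftingBaseConfig(sift_delay=10)
    ),
)

# The only change to the script: wrap the original DataLoader. The training
# loop consumes train_dataloader exactly as before.
train_dataloader = SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=DataLoader(train_dataset, batch_size=32, shuffle=True),
    loss_impl=SiftingLoss(),
    model=model,
)
```

The rest of the training loop, optimizer, and data pre-processing are untouched; the wrapper intercepts batches as they are loaded and yields refined batches in their place.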

How SageMaker smart sifting works

The goal of SageMaker smart sifting is to sift through your training data during the training process and only feed more informative samples to the model. During typical training with PyTorch, data is iteratively sent in batches to the training loop and to accelerator devices (such as GPUs or Trainium chips) by the PyTorch DataLoader. SageMaker smart sifting is implemented at this data loading stage, and is therefore independent of any upstream data pre-processing in your training pipeline.

SageMaker smart sifting uses your model and a user-specified loss function to run an evaluative forward pass over each data sample as it is loaded. Samples that return low loss values contribute less to the model's learning, because the model already makes correct, high-confidence predictions on them, so they are excluded from training. Samples with relatively high loss values are what the model still needs to learn from, so these are kept for training.

A key input you set for SageMaker smart sifting is the proportion of data to exclude. For example, if you set the proportion to 25%, samples that fall in the lowest quartile of the loss distribution (computed over a user-specified number of previous samples) are excluded from training. The high-loss samples are accumulated into a refined data batch, which is sent to the training loop (forward and backward pass), and the model learns and trains on that refined batch.
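To make the mechanism concrete, the following is a conceptual sketch of loss-percentile sifting. It is illustrative only, not the library's actual implementation: it keeps a rolling window of recent loss values and drops a sample whenever its loss falls below the configured percentile of that window.

```python
# Conceptual sketch of loss-percentile sifting (illustrative only; not the
# smart_sifting library's actual implementation).
from collections import deque

import torch

class PercentileSifter:
    def __init__(self, exclude_proportion=0.25, history_length=500):
        self.exclude_proportion = exclude_proportion
        # rolling window of the most recent per-sample loss values
        self.loss_history = deque(maxlen=history_length)

    def keep(self, loss_value: float) -> bool:
        """Return True if the sample is informative enough to train on."""
        self.loss_history.append(loss_value)
        if len(self.loss_history) < self.loss_history.maxlen:
            return True  # not enough history yet; keep everything
        # e.g. with exclude_proportion=0.25, the threshold is the 25th
        # percentile of recent losses; samples below it are sifted out
        threshold = torch.quantile(
            torch.tensor(list(self.loss_history)), self.exclude_proportion
        ).item()
        return loss_value >= threshold

# During data loading, an evaluative forward pass scores each sample:
#   loss = loss_fn(model(sample), label)   # no backward pass here
#   if sifter.keep(loss.item()):
#       refined_batch.append(sample)       # accumulate until the batch is full
```

Note that the evaluative pass is forward-only; the expensive backward pass runs only on the refined batches that survive sifting, which is where the compute savings come from.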

The following diagram shows an overview of how the SageMaker smart sifting algorithm is designed.

[Architecture diagram: SageMaker smart sifting operating during training as data is loaded.]

In short, SageMaker smart sifting operates during training as data is loaded. The SageMaker smart sifting algorithm runs a loss calculation over incoming batches and sifts non-improving data out before the forward and backward pass of each iteration. The refined data batch is then used for the forward and backward pass.

SageMaker smart sifting works for PyTorch-based training jobs with classic distributed data parallelism, which creates model replicas on each GPU worker and performs AllReduce operations. It works with PyTorch DDP and the SageMaker distributed data parallel library, as in the sketch below.
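In a data parallel setup, each rank can wrap the DataLoader for its own data shard. The following sketch assumes a standard PyTorch DDP job launched with torchrun and reuses the hypothetical sift_config and SiftingLoss from the earlier sketch; build_model and train_dataset stand in for your own code, and whether to hand the wrapper the local replica or the DDP-wrapped model is an assumption here, not a documented requirement.

```python
# Sketch: smart sifting in a PyTorch DDP job launched with torchrun.
# sift_config and SiftingLoss come from the earlier (hypothetical) sketch.
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

from smart_sifting.dataloader.sift_dataloader import SiftingDataloader

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = build_model().cuda(local_rank)   # build_model: your model factory
ddp_model = DDP(model, device_ids=[local_rank])

# Each rank sifts its own shard; DDP still synchronizes gradients with
# AllReduce during the backward pass on the refined batches.
sampler = DistributedSampler(train_dataset)
train_dataloader = SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=DataLoader(train_dataset, batch_size=32, sampler=sampler),
    loss_impl=SiftingLoss(),
    model=model,  # local replica used for the evaluative forward pass (assumption)
)
```

Because sifting happens independently on each rank before the forward and backward pass, the gradient synchronization path of DDP or the SageMaker distributed data parallel library is unchanged.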