将 SageMaker 智能筛选应用到您的 PyTorch 脚本中 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

将 SageMaker 智能筛选应用到您的 PyTorch 脚本中

这些说明演示了如何使用训练脚本启用 SageMaker 智能筛选。

  1. 配置 SageMaker 智能筛选界面。

    SageMaker 智能筛选存储库采用了一种基于相对阈值损失的采样技术,有助于筛选出对降低损失值影响较小的样本。SageMaker 智能筛选算法使用前向传递计算每个输入数据样本的损失值,并计算其与前面数据损失值的相对百分位数。

    以下两个参数是创建筛选配置对象时需要为 RelativeProbabilisticSiftConfig 类指定的参数。

    • 指定用于 beta_value 参数训练的数据比例。

    • 使用 loss_history_length 参数指定用于比较的样本数。

    以下代码示例演示了如何设置 RelativeProbabilisticSiftConfig 类的对象。

    from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )

    有关 loss_based_sift_config 参数和相关类的更多信息,请参阅 SageMaker 智能筛选 Python SDK 参考部分中的 SageMaker 智能筛选配置模块

    前面代码示例中的 sift_config 对象在第 4 步中用于设置 SiftingDataloader 类。

  2. (可选)配置 SageMaker 智能筛选批量转换类。

    不同的训练使用场景需要不同的训练数据格式。鉴于数据格式多种多样,SageMaker 智能筛选算法需要确定如何对特定批次进行筛选。为了解决此问题,SageMaker 智能筛分提供了一个批次转换模块,有助于将批次转换为可以高效筛选的标准化格式。

    1. SageMaker 智能筛选功能可批量转换以下格式的训练数据:Python 列表、字典、元组和张量。对于这些数据格式,SageMaker 智能筛选器会自动处理批量数据格式转换,您可以跳过此步骤的其余部分。如果您跳过此步骤,在配置 SiftingDataloader 的第 4 步中,请将 SiftingDataloaderbatch_transforms 参数保留为默认值 None

    2. 如果您的数据集不是这些格式,则您应继续本步骤的其余部分,使用 SiftingBatchTransform 创建自定义批量转换。

      如果您的数据集不是 SageMaker 智能筛选所支持的格式之一,则您可能会遇到错误。此类数据格式错误可以通过在 SiftingDataloader 类中添加 batch_format_indexbatch_transforms 参数来解决,您可以在第 4 步中进行设置。下面显示了由于数据格式不兼容而导致的错误示例以及解决方法。

      错误消息 解决方案

      默认不支持 {type(batch)} 类型的批次。

      此错误表示默认不支持批次格式。您应该实现一个自定义批次转换类,并通过将其指定给 SiftingDataloader 类的 batch_transforms 参数中来使用它。

      无法为类型为 {type(batch)} 的批次编制索引

      此错误表明无法正常为批次对象编制索引。用户必须实现自定义批次转换,并使用 batch_transforms 参数传递。

      批次大小 {batch_size} 与维度 0 或维度 1 大小不匹配

      当提供的批次大小与批次的维度 0 或维度 1 不匹配时,会出现此错误。用户必须实现自定义批次转换,并使用 batch_transforms 参数传递。

      维度 0 和维度 1 都匹配批次大小

      此错误表明,由于多个维度与提供的批次大小相匹配,因此需要更多信息来筛选批次。用户可提供 batch_format_index 参数,指示批次是否可按样本或特征编制索引。用户也可以实施自定义批次转换,但这比所需的工作量更大。

      要解决上述问题,您需要使用 SiftingBatchTransform 模块创建自定义批处理转换类。批次转换类应由一对转换和反向转换函数组成。此函数对将您的数据格式转换为 SageMaker 智能筛选算法可以处理的格式。创建批次转换类后,此类会返回一个 SiftingBatch 对象,您将在第 4 步中把此对象传递给 SiftingDataloader 类。

      以下是 SiftingBatchTransform 模块中自定义批次转换类的示例。

      • 使用 SageMaker 智能筛选实现自定义列表批次转换的示例,适用于数据加载器块具有输入、掩码和标签的情况。

        from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class ListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch
      • 使用 SageMaker 智能筛选功能实现自定义列表批次转换的示例,适用于不需要标签进行反向转换的情况。

        class ListBatchTransformNoLabels(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch
      • 使用 SageMaker 智能筛选实现自定义张量批次的示例,适用于数据加载器块具有输入、掩码和标签的情况。

        from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class TensorBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch

      在您创建已执行 SiftingBatchTransform 批次转换类后,可在第 4 步中使用 SiftingDataloader 类进行设置。本指南的其余部分假设已创建了一个 ListBatchTransform 类。在第 4 步中,此类将传递给 batch_transforms

  3. 创建一个用于实现 SageMaker 智能筛选 Loss 界面的类。本教程假定此类名为 SiftingImplementedLoss。在设置此类时,我们建议您在模型训练循环中使用相同的损失函数。按照以下子步骤创建 SageMaker 智能筛选 Loss 实现的类。

    1. SageMaker 智能筛选会计算每个训练数据样本的损失值,而不是计算批次的单个损失值。为确保 SageMaker 智能筛选使用相同的损失计算逻辑,使用 SageMaker 智能筛选 Loss 模块(此模块使用您的损失函数并计算每个训练样本的损失)创建智能筛选实现的损失函数。

      提示

      SageMaker 智能筛选算法在每个数据样本上运行,而不是在整个批次数据上运行,因此您应该添加一个初始化函数来设置 PyTorch 损失函数,而不使用任何还原策略。

      class SiftingImplementedLoss(Loss): def __init__(self): self.loss = torch.nn.CrossEntropyLoss(reduction='none')

      以下代码示例也说明了这一点。

    2. 定义一个可接受 original_batch(或 transformed_batch,如果在步骤 2 中设置了批次转换)和 PyTorch 模型的损失函数。SageMaker 智能筛选使用指定的损失函数,对每个数据样本进行前向传递,以评估其损失值。

    以下代码是一个名为 SiftingImplementedLoss 的智能筛选实现的 Loss 界面的示例。

    from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class SiftingImplementedLoss(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction(outputs.logits, batch[2])

    在训练循环进入实际前向传递之前,每次迭代获取批次数据的数据加载阶段都会进行筛选损失计算。然后将单个损失值与之前的损失值进行比较,并根据步骤 1 中设置的 RelativeProbabilisticSiftConfig 对象估算出其相对百分位数。

  4. 使用 SageMaker AI SiftingDataloader 类封装 PyTroch 数据加载器。

    最后,在 SageMaker AI SiftingDataloder 配置类中使用在前面步骤中配置的所有 SageMaker 智能筛选实现类。此类是 PyTorch DataLoader 的封装器。通过对 PyTorch DataLoader 进行封装,SageMaker 智能筛分被注册为 PyTorch 训练作业每次迭代中数据加载的一部分。以下代码示例演示了如何将 SageMaker AI 数据筛选实现到 PyTorch DataLoader

    from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader, batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface model=model, log_batch_data=False )