对脚本应用 SageMaker 智能筛选 PyTorch - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

对脚本应用 SageMaker 智能筛选 PyTorch

这些说明演示了如何使用训练脚本启用 SageMaker 智能筛选。

  1. 配置 SageMaker 智能筛选界面。

    SageMaker 智能筛选库实现了一种基于相对阈值损耗的采样技术,该技术有助于筛选出对降低损耗值影响较小的样本。 SageMaker 智能筛选算法使用正向传递计算每个输入数据样本的损失值,并根据先前数据的损失值计算其相对百分位数。

    以下两个参数是创建筛选配置对象时需要为 RelativeProbabilisticSiftConfig 类指定的参数。

    • 指定用于 beta_value 参数训练的数据比例。

    • 使用 loss_history_length 参数指定用于比较的样本数。

    以下代码示例演示了如何设置 RelativeProbabilisticSiftConfig 类的对象。

    from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )

    有关loss_based_sift_config参数和相关类的更多信息,请参阅SageMaker 智能筛选配置模块 SageMaker 智能筛选 Python SDK 参考部分中的。

    前面代码示例中的 sift_config 对象在第 4 步中用于设置 SiftingDataloader 类。

  2. (可选)配置 SageMaker 智能筛选批量转换类。

    不同的训练使用场景需要不同的训练数据格式。鉴于数据格式多种多样, SageMaker 智能筛选算法需要确定如何对特定批次进行筛选。为了解决这个问题, SageMaker 智能筛选提供了一个批量转换模块,可以帮助将批次转换为可以高效筛选的标准化格式。

    1. SageMaker 智能筛选处理以下格式的训练数据的批量转换:Python 列表、字典、元组和张量。对于这些数据格式, SageMaker 智能筛选会自动处理批量数据格式转换,您可以跳过此步骤的其余部分。如果您跳过此步骤,在配置 SiftingDataloader 的第 4 步中,请将 SiftingDataloaderbatch_transforms 参数保留为默认值 None

    2. 如果您的数据集不是这些格式,则您应继续本步骤的其余部分,使用 SiftingBatchTransform 创建自定义批量转换。

      如果您的数据集不是 SageMaker 智能筛选支持的格式之一,则可能会遇到错误。此类数据格式错误可以通过在 SiftingDataloader 类中添加 batch_format_indexbatch_transforms 参数来解决,您可以在第 4 步中进行设置。下面显示了由于数据格式不兼容而导致的错误示例以及解决方法。

      错误消息 解决方案

      默认情况下,{type(batch)}不支持该类型的批处理。

      此错误表示默认不支持批次格式。您应该实现一个自定义批次转换类,并通过将其指定给 SiftingDataloader 类的 batch_transforms 参数中来使用它。

      无法为该批次编制索引 {type(batch)}

      此错误表明无法正常为批次对象编制索引。用户必须实现自定义批次转换,并使用 batch_transforms 参数传递。

      Batch 大小与维度 0 或维度 1 的大小{batch_size}不匹配

      当提供的批次大小与批次的维度 0 或维度 1 不匹配时,会出现此错误。用户必须实现自定义批次转换,并使用 batch_transforms 参数传递。

      维度 0 和维度 1 都匹配批次大小

      此错误表明,由于多个维度与提供的批次大小相匹配,因此需要更多信息来筛选批次。用户可提供 batch_format_index 参数,指示批次是否可按样本或特征编制索引。用户也可以实施自定义批次转换,但这比所需的工作量更大。

      要解决上述问题,您需要使用 SiftingBatchTransform 模块创建自定义批处理转换类。批次转换类应由一对转换和反向转换函数组成。函数对将您的数据格式转换为 SageMaker 智能筛选算法可以处理的格式。创建批次转换类后,此类会返回一个 SiftingBatch 对象,您将在第 4 步中把此对象传递给 SiftingDataloader 类。

      以下是 SiftingBatchTransform 模块中自定义批次转换类的示例。

      • 使用 SageMaker 智能筛选实现自定义列表批量转换的示例,适用于数据加载器块包含输入、掩码和标签的情况。

        from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class ListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch
      • 使用 SageMaker 智能筛选实现自定义列表批量转换的示例,适用于不需要标签进行反向转换的情况。

        class ListBatchTransformNoLabels(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch
      • 在数据加载器块有输入、掩码和标签的情况下,使用 SageMaker 智能筛选的自定义张量批处理实现示例。

        from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch class TensorBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch

      在您创建已执行 SiftingBatchTransform 批次转换类后,可在第 4 步中使用 SiftingDataloader 类进行设置。本指南的其余部分假设已创建了一个 ListBatchTransform 类。在第 4 步中,此类将传递给 batch_transforms

  3. 创建用于实现 SageMaker 智能筛选Loss接口的类。本教程假定此类名为 SiftingImplementedLoss。在设置此类时,我们建议您在模型训练循环中使用相同的损失函数。按照以下子步骤创建 SageMaker 智能筛选Loss实现的类。

    1. SageMaker 智能筛选计算每个训练数据样本的损失值,而不是计算批次的单个损失值。为确保 SageMaker 智能筛选使用相同的损失计算逻辑,请使用 SageMaker 智能筛选Loss模块创建 smart-sifting-implemented损失函数,该模块使用您的损失函数并计算每个训练样本的损失。

      提示

      SageMaker 智能筛选算法在每个数据样本上运行,而不是在整个批次上运行,因此您应该添加一个初始化函数来设置 PyTorch 损失函数,而无需任何还原策略。

      class SiftingImplementedLoss(Loss): def __init__(self): self.loss = torch.nn.CrossEntropyLoss(reduction='none')

      以下代码示例也说明了这一点。

    2. 定义一个接受original_batch(或者transformed_batch如果您在步骤 2 中设置了批量变换)和 PyTorch模型的损失函数。 SageMaker 智能筛选使用不减值的指定损失函数,对每个数据样本进行正向传递,以评估其损失值。

    以下代码是一个名为的 smart-sifting-implementedLoss接口的示例SiftingImplementedLoss

    from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module class SiftingImplementedLoss(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction(outputs.logits, batch[2])

    在训练循环进入实际前向传递之前,每次迭代获取批次数据的数据加载阶段都会进行筛选损失计算。然后将单个损失值与之前的损失值进行比较,并根据步骤 1 中设置的 RelativeProbabilisticSiftConfig 对象估算出其相对百分位数。

  4. 按 SageMaker AI SiftingDataloader 类封装 PyTroch 数据加载器。

    最后,将您在前面步骤中配置的所有 SageMaker 智能筛选实现的类用于 SageMaker AI SiftingDataloder 配置类。这个类是的封装器。 PyTorch DataLoader通过封装 PyTorchDataLoader, SageMaker 智能筛选被注册为在 PyTorch 训练作业的每次迭代中作为数据加载的一部分运行。以下代码示例演示如何实现 SageMaker AI 数据筛选到. PyTorch DataLoader

    from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader, batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface model=model, log_batch_data=False )