SageMaker 智能筛选 Python SDK 参考 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

SageMaker 智能筛选 Python SDK 参考

本页提供了在训练脚本中应用 SageMaker 智能筛选所需的 Python 模块的参考。

SageMaker 智能筛选配置模块

class smart_sifting.sift_config.sift_configs.RelativeProbabilisticSiftConfig()

SageMaker 智能筛选配置类。

参数

  • beta_value(浮点数)- β(常数)值。它用于根据损失值历史记录中的损失百分位数计算选择样本进行训练的概率。降低 β 值会降低筛选数据的百分比,而提高此值则会提高筛选数据的百分比。β 值没有最小值或最大值之分,但必须是正值。下表提供了与 beta_value 有关的筛选率信息。

    beta_value 保留数据的比例 (%) 筛选出的数据比例 (%)
    0.1 90.91 9.01
    0.25 80 20
    0.5 66.67 33.33
    1 50 50
    2 33.33 66.67
    3 25 75
    10 9.09 90.92
    100 0.99 99.01
  • loss_history_length (int):基于相对阈值损失的采样要存储的先前训练损失的数量。

  • loss_based_sift_config(dict 或LossConfig对象)— 指定返回 SageMaker 智能筛选 Loss 接口配置的LossConfig对象。

class smart_sifting.sift_config.sift_configs.LossConfig()

RelativeProbabilisticSiftConfigloss_based_sift_config 参数的配置类。

参数

  • sift_config(dict 或 SiftingBaseConfig 对象):指定返回筛选基础配置字典的 SiftingBaseConfig 对象。

class smart_sifting.sift_config.sift_configs.SiftingBaseConfig()

LossConfigsift_config 参数的配置类。

参数

  • sift_delay (int):开始筛选之前要等待的训练步骤数。我们建议您在模型中的所有层都有足够的训练数据视图后再开始筛选。默认值为 1000

  • repeat_delay_per_epoch (bool):指定是否延迟筛选每个历时的时间。默认值为 False

SageMaker 智能筛选数据批量转换模块

class smart_sifting.data_model.data_model_interface.SiftingBatchTransform

一个 SageMaker 智能筛选 Python 模块,用于定义如何执行批量转换。使用它,您可以设置一个批处理转换类,将训练数据的数据SiftingBatch格式转换为格式。 SageMaker 智能筛选可以将这种格式的数据筛选并累积成经过筛选的批次。

class smart_sifting.data_model.data_model_interface.SiftingBatch

用于定义可筛选和累积的批次数据类型的界面。

class smart_sifting.data_model.list_batch.ListBatch

用于跟踪列表批次以进行筛选的模块。

class smart_sifting.data_model.tensor_batch.TensorBatch

用于跟踪张量批次以进行筛选的模块。

SageMaker 智能筛选损失实现模块

class smart_sifting.loss.abstract_sift_loss_module.Loss

一个包装模块,用于将 SageMaker 智能筛选接口注册到 PyTorch基于模型的损失函数。

SageMaker 智能筛选数据加载器封装模块

class smart_sifting.dataloader.sift_dataloader.SiftingDataloader

一个封装模块,用于将 SageMaker 智能筛选接口注册到 PyTorch基于模型的数据加载器。

主筛选数据加载器迭代器根据筛选配置从数据加载器中筛选出训练样本。

参数

  • sift_config(dict 或 RelativeProbabilisticSiftConfig 对象):RelativeProbabilisticSiftConfig 对象。

  • orig_dataloader( PyTorch DataLoader 对象)— 指定要封装的 PyTorch Dataloader 对象。

  • batch_transformsSiftingBatchTransform对象)—(可选)如果 SageMaker 智能筛选库的默认转换不支持您的数据格式,则必须使用该SiftingBatchTransform模块创建批处理转换类。此参数用于传递批次转换类。该类用于将数据SiftingDataloader转换为 SageMaker 智能筛选算法可以接受的格式。

  • model( PyTorch 模型对象)-原始 PyTorch模型

  • loss_impl(的筛选损失函数smart_sifting.loss.abstract_sift_loss_module.Loss)— 一种筛选损失函数,它与Loss模块一起配置并封装损失函数。 PyTorch

  • log_batch_data (bool):指定是否记录批次数据。如果设置为True,则 SageMaker 智能筛选会记录保留或筛选的批次的详细信息。我们建议您只在测试训练作业时打开它。开启日志记录时,样本会被加载到 GPU 并传输到 CPU,这会带来开销。默认值为 False