使用 Tensor 并行度进行检查点的说明 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 Tensor 并行度进行检查点的说明

这些区域有: SageMaker 模型 parallel 库支持使用张量并行性来保存部分或完整的检查点。以下指南介绍了在使用 tensor 并行机制时如何修改脚本以保存和加载检查点。

  1. 准备一个模型对象并用库的包装函数包装它smp.DistributedModel().

    model = MyModel(...) model = smp.DistributedModel(model)
  2. 为模型准备优化器。一组模型参数是优化器函数所需的可迭代参数。要准备一组模型参数,必须处理model.parameters()为单个模型参数分配唯一 ID。

    如果模型参数可迭代对象中存在具有重复 ID 的参数,则加载检查点优化程序状态将失败。要为优化程序创建具有唯一 ID 的模型参数的可迭代对象,请参阅以下内容:

    unique_params = [] unique_params_set = set() for p in model.parameters(): if p not in unique_params_set: unique_params.append(p) unique_params_set.add(p) del unique_params_set optimizer = MyOpt(unique_params, ...)
  3. 使用库的包装函数包装优化器smp.DistributedOptimizer().

    optimizer = smp.DistributedOptimizer(optimizer)
  4. 使用保存模型和优化程序状态smp.save(). 根据您要保存检查点的方式,请选择以下两个选项之一:

    • 选项 1:在每个模型上保存部分模型mp_rank单个MP_GROUP.

      model_dict = model.local_state_dict() # save a partial model opt_dict = optimizer.local_state_dict() # save a partial optimizer state # Save the dictionaries at rdp_rank 0 as a checkpoint if smp.rdp_rank() == 0: smp.save( {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, f"/checkpoint.pt", partial=True, )

      使用 tensor 并行性,库保存以下格式命名的检查点文件:checkpoint.pt_{pp_rank}_{tp_rank}.

      注意

      使用张量并行度,请务必将 if 语句设置为if smp.rdp_rank() == 0而不是if smp.dp_rank() == 0. 当优化程序状态与张量并行分区时,所有减少的数据并 parallel 排名都必须保存自己的优化程序状态分区。使用错误如果检查点声明可能会导致培训工作停滞。有关使用的更多信息if smp.dp_rank() == 0没有张量并行性,请参见保存和加载的一般说明中的SageMaker Python 开发工具包文.

    • 选项 2:保存完整模型。

      if smp.rdp_rank() == 0: model_dict = model.state_dict(gather_to_rank0=True) # save the full model if smp.rank() == 0: smp.save( {"model_state_dict": model_dict}, "/checkpoint.pt", partial=False, )
      注意

      请考虑以下内容进行完整的检查点:

      • 如果你设置gather_to_rank0=True,除以外的所有级别0返回空字典。

      • 对于完整的检查点,你只能检查模型。目前不支持优化程序状态的完整检查点。

      • 完整模型只需要保存在smp.rank() == 0.

  5. 使用加载检查点smp.load(). 根据您在上一步中检查点的方式,选择以下两个选项之一:

    • 选项 1:加载部分检查点。

      checkpoint = smp.load("/checkpoint.pt", partial=True) model.load_state_dict(checkpoint["model_state_dict"], same_partition_load=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

      您可以设置same_partition_load=Truemodel.load_state_dict()为了更快地加载,如果你知道分区不会改变。

    • 选项 2:加载完整的检查点。

      if smp.rdp_rank() == 0: checkpoint = smp.load("/checkpoint.pt", partial=False) model.load_state_dict(checkpoint["model_state_dict"])

      这些区域有:if smp.rdp_rank() == 0条件不是必需的,但它可以帮助避免不同之间的冗余装载MP_GROUP。张量并行性目前不支持完整的检查点优化器状态 dict。