调整 TensorFlow 训练脚本 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

调整 TensorFlow 训练脚本

要开始收集模型输出张量并调试训练问题,请在 TensorFlow 训练脚本中进行以下修改。

在 SageMaker AI 中为训练作业创建钩子

import smdebug.tensorflow as smd hook=smd.get_hook(hook_type="keras", create_if_not_exists=True)

当您启动 SageMaker 训练作业时,这会创建一个钩子。当您在使用 SageMaker Python SDK Debugger 启动训练作业中,在估算器中使用任意 DebuggerHookConfigTensorBoardConfigRules 启动训练作业时,SageMaker AI 将 JSON 配置文件添加到由 smd.get_hook 方法选取的训练实例中。请注意,如果您的估算器中未包含任何配置 API,则不会有配置文件可供钩子查找,并且函数会返回 None

(可选)为 SageMaker AI 之外的训练作业创建钩子

如果您在本地模式下,直接在 SageMaker Notebook 实例、Amazon EC2 实例或您自己的本地设备上运行训练作业,请使用 smd.Hook 类来创建钩子。但是,这种方法只能存储张量集合,并且可用于 TensorBoard 的可视化。SageMaker Debugger 的内置规则不可用于本地模式。在这种情况下,smd.get_hook 方法也会返回 None

如果您要创建手动钩子,请使用以下带有逻辑的代码片段来检查钩子是否返回 None,并使用 smd.Hook 类创建手动钩子。

import smdebug.tensorflow as smd hook=smd.get_hook(hook_type="keras", create_if_not_exists=True) if hook is None: hook=smd.KerasHook( out_dir='/path/to/your/local/output/', export_tensorboard=True )

添加钩子创建代码后,请继续到以下有关 TensorFlow Keras 的主题。

注意

SageMaker Debugger 目前仅支持 TensorFlow Keras。

在您的 TensorFlow Keras 训练脚本中注册钩子

以下过程介绍如何使用钩子及其方法,从模型和优化器中收集输出标量和张量。

  1. 用钩子的类方法包装您的 Keras 模型和优化器。

    hook.register_model() 方法获取您的模型并遍历每一层,寻找与您通过 使用 SageMaker Python SDK Debugger 启动训练作业 中配置提供的正则表达式匹配的任何张量。通过这种钩子方法可以收集到的张量包括权重、偏差和激活。

    model=tf.keras.Model(...) hook.register_model(model)
  2. hook.wrap_optimizer() 方法包装优化器。

    optimizer=tf.keras.optimizers.Adam(...) optimizer=hook.wrap_optimizer(optimizer)
  3. 在 TensorFlow 中以急切模式编译模型。

    要从模型中收集张量,例如每层的输入和输出张量,必须在急切模式下运行训练。否则,SageMaker AI Debugger 将无法收集张量。但是,模型权重、偏差和损失等其他张量,无需在急切模式下运行即可收集。

    model.compile( loss="categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"], # Required for collecting tensors of each layer run_eagerly=True )
  4. 将钩子注册到 tf.keras.Model.fit() 方法。

    要从您注册的钩子中收集张量,请将 callbacks=[hook] 添加到 Keras model.fit() 类方法中。这会将 sagemaker-debugger 钩子作为 Keras 回调传递。

    model.fit( X_train, Y_train, batch_size=batch_size, epochs=epoch, validation_data=(X_valid, Y_valid), shuffle=True, callbacks=[hook] )
  5. TensorFlow 2.x 仅提供无法访问其值的符号梯度变量。要收集梯度,请使用 hook.wrap_tape() 方法包装 tf.GradientTape,这要求您如下所示编写自己的训练步骤。

    def training_step(model, dataset): with hook.wrap_tape(tf.GradientTape()) as tape: pred=model(data) loss_value=loss_fn(labels, pred) grads=tape.gradient(loss_value, model.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables))

    通过对磁带进行包装,sagemaker-debugger 钩子可以识别输出张量,例如梯度、参数和损失。对磁带进行包装可确保围绕磁带对象函数(例如,push_tape()pop_tape()gradient())的 hook.wrap_tape() 方法将设置 SageMaker Debugger 的编写器,并将提供的张量保存作为 gradient() 的输入(可训练的变量和损失)以及 gradient() 的输出(梯度)。

    注意

    要使用自定义训练循环进行收集,请确保使用急切模式。否则,SageMaker Debugger 将无法收集任何张量。

有关 sagemaker-debugger 钩子 API 提供的用于构造钩子和保存张量的完整操作列表,请参阅 sagemaker-debugger Python SDK 文档中的钩子方法

调整完训练脚本后,继续到 使用 SageMaker Python SDK Debugger 启动训练作业