使用 smdebug 客户端库以 Python 脚本创建自定义规则 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 smdebug 客户端库以 Python 脚本创建自定义规则

smdebug 规则 API 提供了一个接口,用于设置自己的自定义规则。以下 Python 脚本示例演示了如何构造自定义规则 CustomGradientRule。本教程的自定义规则监控梯度变是否太大并将默认阈值设置为 10。自定义规则采用 A SageMaker I 估算器在启动训练作业时创建的基础试验。

from smdebug.rules.rule import Rule class CustomGradientRule(Rule): def __init__(self, base_trial, threshold=10.0): super().__init__(base_trial) self.threshold = float(threshold) def invoke_at_step(self, step): for tname in self.base_trial.tensor_names(collection="gradients"): t = self.base_trial.tensor(tname) abs_mean = t.reduction_value(step, "mean", abs=True) if abs_mean > self.threshold: return True return False

您可以在同一个 python 脚本中按需要添加任意数量的自定义规则类,并通过在下个部分中构造自定义规则对象,来将它们部署到任何训练作业试验中。