TabTransformer 超参数 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

TabTransformer 超参数

下表包含 Amazon SageMaker AI TabTransformer 算法必需或最常用的部分超参数。用户可以设置这些参数,以便于从数据中估算模型参数。SageMaker AI TabTransformer 算法是开源 TabTransformer 包的实施。

注意

默认超参数基于 TabTransformer 示例笔记本中的示例数据集。

SageMaker AI TabTransformer 算法会根据分类问题的类型,自动选择评估指标和目标函数。TabTransformer 算法根据数据中的标签数量来检测分类问题的类型。对于回归问题,评估指标为 r 平方,目标函数为均方误差。对于二元分类问题,评估指标和目标函数都是二元交叉熵。对于多元分类问题,评估指标和目标函数都是二元交叉熵。

注意

TabTransformer 评估指标和目标函数目前不能作为超参数使用。而是由 SageMaker AI TabTransformer 内置算法根据标签列中唯一整数的数量,自动检测分类任务的类型(回归、二元或多元),并分配评估指标和目标函数。

参数名称 描述
n_epochs

训练深度神经网络的纪元数。

有效值:整数,范围:正整数。

默认值:5

patience

如果在过去的 patience 轮中,某个验证数据点的某个指标没有改善,则训练将停止。

有效值:整数,范围:(260)。

默认值:10

learning_rate

完成每批训练样本后,更新模型权重的速率。

有效值:浮点型,范围:正浮点数。

默认值:0.001

batch_size

通过网络传播的示例数量。

有效值:整数,范围:(1, 2048)。

默认值:256

input_dim

用于对类别和/或连续列进行编码的嵌入的维度。

有效值:字符串,以下任意值:"16""32""64""128""256""512"

默认值:"32"

n_blocks

转换器编码器块的数量。

有效值:整数,范围:(1, 12)。

默认值:4

attn_dropout

应用于多头注意力层的丢弃比率。

有效值:浮点型,范围:(01)。

默认值:0.2

mlp_dropout

应用于转换器编码器上的编码器层以及最终 MLP 层内的前馈网络的丢弃比率。

有效值:浮点型,范围:(01)。

默认值:0.1

frac_shared_embed

一个特定列的所有不同类别共享的嵌入的比例。

有效值:浮点型,范围:(01)。

默认值:0.25