本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
LightGBM
LightGBM
如何使用 SageMaker LightGBM
你可以使用 LightGBM 作为亚马逊的 SageMaker 内置算法。以下部分介绍如何在 SageMaker Phon 软件开发工具中使用 LightGBM。有关如何从亚马逊 SageMaker Studio 用户界面使用 LightGBM 的信息,请参阅SageMaker JumpStart。
-
使用 LightGBM 作为内置算法
使用 LightGBM 内置算法构建 LightGBM 训练容器,如以下代码示例所示。你可以使用 API(如果使用亚马逊 SageMaker Python SDK
版本 2,则为 SageMaker image_uris.retrieve
API)自动发现 LightGBM 内置算法的图像 URget_image_uri
I。指定 LightGBM 图像 URI 后,您可以使用 LightGBM 容器使用 Estimator API 构造 SageMaker 估算器并启动训练作业。LightGBM 内置算法在脚本模式下运行,但训练脚本是为你提供的,无需替换。如果您在使用脚本模式创建 SageMaker 训练作业方面有丰富的经验,则可以合并自己的 LightGBM 训练脚本。
from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "lightgbm-classification-model", "*", "training" training_instance_type = "ml.m5.xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_multiclass/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "num_boost_round" ] = "500" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, # for distributed training, specify an instance_count greater than 1 instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "train": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )
有关如何将 LightGBM 设置为内置算法的更多信息,请参阅以下笔记本。
LightGBM 算法的输入和输出接口
梯度提升对表格数据进行操作,其中行表示观察、一个列表示目标变量或标签,其余列表示特征。
LightGBM 的 SageMaker 实现支持 CSV 用于训练和推理:
-
对于培训 ContentType,有效输入必须是文本/csv。
-
对于推理 ContentType,有效的输入必须是文本/csv。
注意
对于 CSV 训练,算法假定目标变量在第一列中,而 CSV 没有标头记录。
对于 CSV 推理,算法假定 CSV 输入没有标签列。
训练数据、验证数据和分类特征的输入格式
注意如何格式化训练数据以输入 LightGBM 模型。您必须提供包含训练和验证数据的 Simple Storage(Amazon S3)存储桶的路径。您还可以包含分类要素列表。使用train
和validation
通道提供您的输入数据。或者,您只能使用train
频道。
注意
train
和都training
是 LightGBM 训练的有效频道名称。
同时使用train
和validation
频道
您可以通过两条 S3 路径提供输入数据,一条用于train
频道,一条用于validation
频道。每个 S3 路径可以是指向一个或多个 CSV 文件的 S3 前缀,也可以是指向一个特定 CSV 文件的完整 S3 路径。目标变量应位于 CSV 文件的第一列中。预测变量(特征)应位于其余列中。如果为train
或validation
频道提供了多个 CSV 文件,LightGBM 算法会将这些文件连接起来。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。
如果您的预测变量包含分类要素,则可以提供一个与您的一个或多个训练数据文件categorical_index.json
在相同位置命名的 JSON 文件。如果您为分类功能提供 JSON 文件,则您的train
频道必须指向 S3 前缀,而不是特定的 CSV 文件。此文件应包含 Python 字典,其中键是字符串"cat_index_list"
,值是唯一整数列表。值列表中的每个整数都应表示训练数据 CSV 文件中相应类别特征的列索引。每个值都应为正整数(大于零,因为零代表目标值),小于Int32.MaxValue
(2147483647),小于总列数。应该只有一个类别索引 JSON 文件。
仅使用train
频道:
您也可以通过train
频道的单个 S3 路径提供输入数据。此 S3 路径应指向包含一个名为的子目录的目录train/
,该目录包含一个或多个 CSV 文件。您可以选择在名为的相同位置添加另一个子目录validation/
,该子目录也包含一个或多个 CSV 文件。如果未提供验证数据,则随机抽取 20% 的训练数据作为验证数据。如果您的预测变量包含分类要素,则可以提供与数据子目录categorical_index.json
在相同位置命名的 JSON 文件。
注意
对于 CSV 训练输入模式,算法可用的总内存(实例数乘以中的可用内存InstanceType
)必须能够容纳训练数据集。
SageMaker LightGBM 使用 Python Joblib 模块对模型进行序列化或反序列化,该模型可用于保存或加载模型。
在 JobLib 模块中使用使用 SageMaker LightGBM 训练的模型
-
使用以下 Python 代码:
import joblib import tarfile t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() model = joblib.load(
model_file_path
) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(dtest
)
Amazon EC2 实例推荐
SageMaker LightGBM 目前支持单实例和多实例 CPU 训练。对于多实例 CPU 训练(分布式训练),在定义 Estimator 时指定instance_count
大于 1。有关使用 LightGBM 进行分布式训练的更多信息,请参阅使用 Dask 的亚马逊 SageMaker LightGBM 分布式训练
LightGBM 是一种内存密集型算法(而不是计算密集型)算法。因此,通用计算实例(例如 M5)比计算优化实例(例如 C5)更好。此外,我们建议您在选定的实例中有足够的总内存来保存训练数据。
LightGBM 样本笔记本
下表概述了各种示例笔记本电脑,这些笔记本电脑涉及 Amazon SageMaker LightGBM 算法的不同用例。
笔记本标题 | 描述 |
---|---|
本笔记本演示了如何使用 Amazon SageMaker LightGBM 算法来训练和托管表格分类模型。 |
|
本笔记本演示了如何使用 Amazon SageMaker LightGBM 算法来训练和托管表格回归模型。 |
|
本笔记本演示了使用 Dask 框架使用 Amazon SageMaker LightGBM 算法进行分布式训练。 |
有关如何创建和访问可用于运行示例的 Jupyter 笔记本实例的说明 SageMaker,请参阅亚马逊 SageMaker 笔记本实例。创建并打开笔记本实例后,选择 “SageMaker示例” 选项卡以查看所有 SageMaker 示例的列表。要打开笔记本,请选择其 Use (使用) 选项卡,然后选择 Create copy (创建副本)。