LightGBM - 亚马逊 SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

LightGBM

LightGBM 是梯度提升决策树 (GBDT) 算法的流行而高效的开源实现。GBDT 是一种监督式学习算法,它试图通过组合一组更简单和更弱的模型的估计值来准确预测目标变量。LightGBM 使用其他技术来显著提高传统 GBDT 的效率和可扩展性。

如何使用 SageMaker LightGBM

你可以使用 LightGBM 作为亚马逊的 SageMaker 内置算法。以下部分介绍如何在 SageMaker Phon 软件开发工具中使用 LightGBM。有关如何从亚马逊 SageMaker Studio 用户界面使用 LightGBM 的信息,请参阅SageMaker JumpStart

  • 使用 LightGBM 作为内置算法

    使用 LightGBM 内置算法构建 LightGBM 训练容器,如以下代码示例所示。你可以使用 API(如果使用亚马逊 SageMaker Python SDK 版本 2,则为 SageMaker image_uris.retrieve API)自动发现 LightGBM 内置算法的图像 URget_image_uri I。

    指定 LightGBM 图像 URI 后,您可以使用 LightGBM 容器使用 Estimator API 构造 SageMaker 估算器并启动训练作业。LightGBM 内置算法在脚本模式下运行,但训练脚本是为你提供的,无需替换。如果您在使用脚本模式创建 SageMaker 训练作业方面有丰富的经验,则可以合并自己的 LightGBM 训练脚本。

    from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "lightgbm-classification-model", "*", "training" training_instance_type = "ml.m5.xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_multiclass/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "num_boost_round" ] = "500" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, # for distributed training, specify an instance_count greater than 1 instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "train": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )

    有关如何将 LightGBM 设置为内置算法的更多信息,请参阅以下笔记本。

LightGBM 算法的输入和输出接口

梯度提升对表格数据进行操作,其中行表示观察、一个列表示目标变量或标签,其余列表示特征。

LightGBM 的 SageMaker 实现支持 CSV 用于训练和推理:

  • 对于培训 ContentType,有效输入必须是文本/csv

  • 对于推理 ContentType,有效的输入必须是文本/csv

注意

对于 CSV 训练,算法假定目标变量在第一列中,而 CSV 没有标头记录。

对于 CSV 推理,算法假定 CSV 输入没有标签列。

训练数据、验证数据和分类特征的输入格式

注意如何格式化训练数据以输入 LightGBM 模型。您必须提供包含训练和验证数据的 Simple Storage(Amazon S3)存储桶的路径。您还可以包含分类要素列表。使用trainvalidation通道提供您的输入数据。或者,您只能使用train频道。

注意

train和都training是 LightGBM 训练的有效频道名称。

同时使用trainvalidation频道

您可以通过两条 S3 路径提供输入数据,一条用于train频道,一条用于validation频道。每个 S3 路径可以是指向一个或多个 CSV 文件的 S3 前缀,也可以是指向一个特定 CSV 文件的完整 S3 路径。目标变量应位于 CSV 文件的第一列中。预测变量(特征)应位于其余列中。如果为trainvalidation频道提供了多个 CSV 文件,LightGBM 算法会将这些文件连接起来。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。

如果您的预测变量包含分类要素,则可以提供一个与您的一个或多个训练数据文件categorical_index.json在相同位置命名的 JSON 文件。如果您为分类功能提供 JSON 文件,则您的train频道必须指向 S3 前缀,而不是特定的 CSV 文件。此文件应包含 Python 字典,其中键是字符串"cat_index_list",值是唯一整数列表。值列表中的每个整数都应表示训练数据 CSV 文件中相应类别特征的列索引。每个值都应为正整数(大于零,因为零代表目标值),小于Int32.MaxValue (2147483647),小于总列数。应该只有一个类别索引 JSON 文件。

仅使用train频道

您也可以通过train频道的单个 S3 路径提供输入数据。此 S3 路径应指向包含一个名为的子目录的目录train/,该目录包含一个或多个 CSV 文件。您可以选择在名为的相同位置添加另一个子目录validation/,该子目录也包含一个或多个 CSV 文件。如果未提供验证数据,则随机抽取 20% 的训练数据作为验证数据。如果您的预测变量包含分类要素,则可以提供与数据子目录categorical_index.json在相同位置命名的 JSON 文件。

注意

对于 CSV 训练输入模式,算法可用的总内存(实例数乘以中的可用内存InstanceType)必须能够容纳训练数据集。

SageMaker LightGBM 使用 Python Joblib 模块对模型进行序列化或反序列化,该模型可用于保存或加载模型。

在 JobLib 模块中使用使用 SageMaker LightGBM 训练的模型
  • 使用以下 Python 代码:

    import joblib import tarfile t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() model = joblib.load(model_file_path) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(dtest)

Amazon EC2 实例推荐

SageMaker LightGBM 目前支持单实例和多实例 CPU 训练。对于多实例 CPU 训练(分布式训练),在定义 Estimator 时指定instance_count大于 1。有关使用 LightGBM 进行分布式训练的更多信息,请参阅使用 Dask 的亚马逊 SageMaker LightGBM 分布式训练

LightGBM 是一种内存密集型算法(而不是计算密集型)算法。因此,通用计算实例(例如 M5)比计算优化实例(例如 C5)更好。此外,我们建议您在选定的实例中有足够的总内存来保存训练数据。

LightGBM 样本笔记本

下表概述了各种示例笔记本电脑,这些笔记本电脑涉及 Amazon SageMaker LightGBM 算法的不同用例。

笔记本标题 描述

使用亚马逊 SageMaker LightGBM 和 CatBoost 算法进行表格分类

本笔记本演示了如何使用 Amazon SageMaker LightGBM 算法来训练和托管表格分类模型。

使用亚马逊 SageMaker LightGBM 和 CatBoost 算法进行表格回归

本笔记本演示了如何使用 Amazon SageMaker LightGBM 算法来训练和托管表格回归模型。

亚马逊 SageMaker LightGBM 使用 Dask 进行分布式训练

本笔记本演示了使用 Dask 框架使用 Amazon SageMaker LightGBM 算法进行分布式训练。

有关如何创建和访问可用于运行示例的 Jupyter 笔记本实例的说明 SageMaker,请参阅亚马逊 SageMaker 笔记本实例。创建并打开笔记本实例后,选择 “SageMaker示例” 选项卡以查看所有 SageMaker 示例的列表。要打开笔记本,请选择其 Use (使用) 选项卡,然后选择 Create copy (创建副本)