CatBoost - 亚马逊 SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

CatBoost

CatBoost是梯度提升决策树 (GBDT) 算法的流行且高性能的开源实现。GBDT 是一种监督式学习算法,它试图通过组合一组更简单和更弱的模型的估计值来准确预测目标变量。

CatBoost 介绍了 GBDT 的两项关键算法进步:

  1. 实现有序提升,这是传统算法的置换驱动替代方案

  2. 一种用于处理分类特征的创新算法

这两种技术都是为了对抗当前所有现有的梯度提升算法实现中存在的特殊目标泄漏引起的预测偏移。

如何使用 SageMaker CatBoost

您可以用 CatBoost 作亚马逊的 SageMaker 内置算法。以下部分介绍如何与 SageMaker Python SDK CatBoost 一起使用。有关如何在亚马逊 SageMaker Studio 用户界面 CatBoost 中使用的信息,请参阅SageMaker JumpStart

  • 用 CatBoost 作内置算法

    使用 CatBoost 内置算法构建 CatBoost 训练容器,如以下代码示例所示。您可以使用 SageMaker image_uris.retrieve API(如果使用亚马逊 SageMaker Python SDK 版本 2,则get_image_uri为 API)自动发现 CatBoost内置算法图像 URI。

    指定 CatBoost 图像 URI 后,您可以使用 CatBoost 容器使用 Estimator API 构造 SageMaker 估算器并启动训练作业。 CatBoost 内置算法在脚本模式下运行,但训练脚本是为您提供的,无需替换。如果您在使用脚本模式创建 SageMaker 训练作业方面有丰富的经验,则可以合并自己的 CatBoost 训练脚本。

    from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "catboost-classification-model", "*", "training" training_instance_type = "ml.m5.xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_multiclass/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "iterations" ] = "500" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "training": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )

    有关如何设置算法的更多信息 CatBoost ,请参阅以下笔记本电脑示例。

CatBoost算法的输入和输出接口

梯度提升对表格数据进行操作,其中行表示观察、一个列表示目标变量或标签,其余列表示特征。

的 SageMaker 实现 CatBoost 支持 CSV 用于训练和推理:

  • 对于培训 ContentType,有效输入必须是文本/csv

  • 对于推理 ContentType,有效的输入必须是文本/csv

注意

对于 CSV 训练,算法假定目标变量在第一列中,而 CSV 没有标头记录。

对于 CSV 推理,算法假定 CSV 输入没有标签列。

训练数据、验证数据和分类特征的输入格式

注意如何格式化训练数据以输入到 CatBoost 模型中。您必须提供包含训练和验证数据的路径。您还可以包含分类要素列表。使用trainingvalidation通道提供您的输入数据。或者,您只能使用training频道。

同时使用trainingvalidation频道

您可以通过两条 S3 路径提供输入数据,一条用于training频道,一条用于validation频道。每个 S3 路径可以是指向一个或多个 CSV 文件的 S3 前缀,也可以是指向一个特定 CSV 文件的完整 S3 路径。目标变量应位于 CSV 文件的第一列中。预测变量(特征)应位于其余列中。如果为trainingvalidation频道提供了多个 CSV 文件,则 CatBoost 算法会连接这些文件。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。

如果您的预测变量包含分类要素,则可以提供一个与您的一个或多个训练数据文件categorical_index.json在相同位置命名的 JSON 文件。如果您为分类功能提供 JSON 文件,则您的training频道必须指向 S3 前缀,而不是特定的 CSV 文件。此文件应包含 Python 字典,其中键是字符串"cat_index_list",值是唯一整数列表。值列表中的每个整数都应表示训练数据 CSV 文件中相应类别特征的列索引。每个值都应为正整数(大于零,因为零代表目标值),小于Int32.MaxValue (2147483647),小于总列数。应该只有一个类别索引 JON 文件。

仅使用training频道

您也可以通过training频道的单个 S3 路径提供输入数据。此 S3 路径应指向包含一个名为的子目录的目录training/,该目录包含一个或多个 CSV 文件。您可以选择在名为的相同位置添加另一个子目录validation/,该子目录也包含一个或多个 CSV 文件。如果未提供验证数据,则随机抽取 20% 的训练数据作为验证数据。如果您的预测变量包含分类要素,则可以提供与数据子目录categorical_index.json在相同位置命名的 JSON 文件。

注意

对于 CSV 训练输入模式,算法可用的总内存(实例数乘以中的可用内存InstanceType)必须能够容纳训练数据集。

SageMaker CatBoost 使用catboost.CatBoostClassifiercatboost.CatBoostRegressor模块序列化或反序列化模型,可用于保存或加载模型。

使用使用训练过 SageMaker CatBoost 的模型catboost
  • 使用以下 Python 代码:

    import tarfile from catboost import CatBoostClassifier t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() file_path = os.path.join(model_file_path, "model") model = CatBoostClassifier() model.load_model(file_path) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(dtest)

CatBoost算法的Amazon EC2 实例推荐

SageMaker CatBoost 目前仅使用 CPU 进行列车。 CatBoost 是一种内存密集型算法(而不是计算密集型)算法。因此,通用计算实例(例如 M5)比计算优化实例(例如 C5)更好。此外,我们建议您在选定的实例中有足够的总内存来保存训练数据。

CatBoost 样本笔记本

下表概述了针对亚马逊 SageMaker CatBoost 算法不同用例的各种示例笔记本。

笔记本标题 描述

使用亚马逊 SageMaker LightGBM 和 CatBoost 算法进行表格分类

本笔记本演示了如何使用 Amazon SageMaker CatBoost 算法来训练和托管表格分类模型。

使用亚马逊 SageMaker LightGBM 和 CatBoost 算法进行表格回归

本笔记本演示了如何使用 Amazon SageMaker CatBoost 算法来训练和托管表格回归模型。

有关如何创建和访问可用于运行示例的 Jupyter 笔记本实例的说明 SageMaker,请参阅亚马逊 SageMaker 笔记本实例。创建并打开笔记本实例后,选择 “SageMaker示例” 选项卡以查看所有 SageMaker 示例的列表。要打开笔记本,请选择其 Use (使用) 选项卡,然后选择 Create copy (创建副本)