CatBoost - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

CatBoost

CatBoost 是梯度提升决策树 (GBDT) 算法的一种热门的开源实施,性能非常出色。GBDT 是一种有监督学习算法,它尝试将一组较简单且较弱模型的一系列估计值结合在一起,从而准确地预测目标变量。

CatBoost 向 GBDT 引入了两项关键的算法改进:

  1. 实施了有序提升,这是对经典算法的以排列驱动的替代方案

  2. 用于处理分类特征的创新算法

这两种技术都是为了对抗由一种特殊目标泄漏所引起的预测偏移,这种偏移存在于梯度提升算法的所有现有实施中。

如何使用 SageMaker CatBoost

您可以使用 CatBoost 作为 Amazon SageMaker 的内置算法。以下部分介绍如何将 CatBoost 与 SageMaker Python SDK 结合使用。有关如何通过 Amazon SageMaker Studio 用户界面使用 CatBoost 的信息,请参阅 SageMaker JumpStart

  • 使用 CatBoost 作为内置算法

    使用 CatBoost 内置算法构建 CatBoost 训练容器,如以下代码示例所示。您可以使用 SageMaker image_uris.retrieve API(如果使用 Amazon SageMaker Python SDK 版本 2,则为 get_image_uri API),自动发现 CatBoost 内置算法映像 URI。

    指定 CatBoost 映像 URI 后,您可以使用 CatBoost 容器,通过 SageMaker Estimator API 构造评估程序并启动训练作业。CatBoost 内置算法运行在脚本模式下,不过训练脚本是为您提供的,无需替换。如果您在使用脚本模式创建 SageMaker 训练作业方面有丰富的经验,您也可以合并自己的 CatBoost 训练脚本。

    from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "catboost-classification-model", "*", "training" training_instance_type = "ml.m5.xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_multiclass/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "iterations" ] = "500" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "training": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )

    有关如何将 CatBoost 设置为内置算法的更多信息,请参阅以下笔记本示例。

CatBoost 算法的输入和输出接口

梯度提升对表格数据进行操作,其中行表示观察、一个列表示目标变量或标签,其余列表示特征。

CatBoost 的 SageMaker 实施支持使用 CSV 进行训练和推理:

  • 对于训练 ContentType,有效的输入必须是文本/csv

  • 对于推理 ContentType,有效的输入必须是文本/csv

注意

对于 CSV 训练,算法假定目标变量在第一列中,而 CSV 没有标头记录。

对于 CSV 推理,算法假定 CSV 输入没有标签列。

训练数据、验证数据和类别特征的输入格式

请注意如何对训练数据进行格式化,以便输入 CatBoost 模型。您必须提供包含训练和验证数据的 Amazon S3 存储桶的路径。您还可以包含类别特征列表。请使用 trainingvalidation 通道来提供您的输入数据。您也可以只使用 training 通道。

使用 trainingvalidation 通道

您可以通过两条 S3 路径来提供输入数据,一条用于 training 通道,一条用于 validation 通道。每个 S3 路径可以是指向一个或多个 CSV 文件的 S3 前缀,也可以是指向一个特定 CSV 文件的完整 S3 路径。目标变量应位于 CSV 文件的第一列。预测器变量(特征)应位于其余列。如果为 trainingvalidation 通道提供了多个 CSV 文件,则 CatBoost 算法会将这些文件连接起来。验证数据用于在每次提升迭代结束时计算验证分数。当验证分数停止提高时,将应用提前停止。

如果您的预测器包含类别特征,则可以在与您的训练数据文件相同的位置,提供一个名为 categorical_index.json 的 JSON 文件。如果您为类别特征提供 JSON 文件,则您的 training 通道必须指向 S3 前缀而不是特定 CSV 文件。此文件应包含一个 Python 字典,其中的键是字符串 "cat_index_list",值是唯一整数列表。值列表中的每个整数都应指示训练数据 CSV 文件中对应分类特征的列索引。每个值都应为正整数(大于零,因为零表示目标值),小于 Int32.MaxValue (2147483647),并且小于列的总数。只应有一个类别索引 JSON 文件。

仅使用 training 通道

您也可以通过单个 S3 路径,为 training 通道提供输入数据。此 S3 路径指向的目录中应包含一个名为 training/ 的子目录,而该子目录中包含一个或多个 CSV 文件。您可以选择在相同位置添加另一个名为 validation/ 的子目录,该子目录同样包含一个或多个 CSV 文件。如果未提供验证数据,则会随机采样 20% 的训练数据作为验证数据。如果您的预测器包含类别特征,则可以在与您的数据子目录相同的位置,提供一个名为 categorical_index.json 的 JSON 文件。

注意

对于 CSV 训练输入模式,供算法使用的内存总量(实例计数乘以 InstanceType 中的可用内存)必须能够容纳训练数据集。

SageMaker CatBoost 使用 catboost.CatBoostClassifiercatboost.CatBoostRegressor 模块来序列化或反序列化模型,这可用于保存或加载模型。

将通过 SageMaker CatBoost 训练过的模型与 catboost 结合使用
  • 使用以下 Python 代码:

    import tarfile from catboost import CatBoostClassifier t = tarfile.open('model.tar.gz', 'r:gz') t.extractall() file_path = os.path.join(model_file_path, "model") model = CatBoostClassifier() model.load_model(file_path) # prediction with test data # dtest should be a pandas DataFrame with column names feature_0, feature_1, ..., feature_d pred = model.predict(dtest)

适用于 CatBoost 算法的 Amazon EC2 实例推荐

SageMaker CatBoost 目前仅使用 CPU 进行训练。CatBoost 是一种内存限制型(而不是计算限制型)算法。因此,通用计算实例(例如 M5)是比计算优化型实例(例如 C5)更适合的选择。此外,我们建议您在选定的实例中有足够的总内存来保存训练数据。

CatBoost 示例笔记本

下表概述了解决 Amazon SageMaker CatBoost 算法的不同使用场景的各种示例笔记本。

笔记本标题 描述

使用 Amazon SageMaker LightGBM 和 CatBoost 算法进行表格分类

本笔记本演示了如何使用 Amazon SageMaker CatBoost 算法来训练和托管表格分类模型。

使用 Amazon SageMaker LightGBM 和 CatBoost 算法进行表格回归

本笔记本演示了如何使用 Amazon SageMaker CatBoost 算法来训练和托管表格回归模型。

有关如何创建和访问可用于在 SageMaker 中运行示例的 Jupyter 笔记本实例的说明,请参阅 Amazon SageMaker 笔记本实例。创建笔记本实例并将其打开后,选择 SageMaker 示例选项卡以查看所有 SageMaker 示例的列表。要打开笔记本,请选择其 Use (使用) 选项卡,然后选择 Create copy (创建副本)