教程:构建 XGBoost 模型 - Amazon Redshift
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

教程:构建 XGBoost 模型

在本教程中,您使用来自 Amazon S3 的数据创建模型,并使用 Amazon Redshift ML 对模型运行预测查询。XGBoost 算法是梯度提升树算法的一种优化实施。与其他梯度提升树算法相比,XGBoost 处理的数据类型和关系更多,数据分布更广泛。您可以使用 XGBoost 来处理回归、二进制分类、多类别分类以及排名问题。有关 XGBoost 算法的更多信息,请参阅《Amazon SageMaker 开发人员指南》中的 XGBoost 算法

带有 AUTO OFF 选项的 Amazon Redshift ML CREATE MODEL 操作当前支持将 XGBoost 作为 MODEL_TYPE。根据您的使用案例,您可以在 CREATE MODEL 命令中提供相关信息,例如目标和超参数。

在本教程中,您将使用钞票验证数据集,这是一个二进制分类问题,用于预测给定的钞票是真钞还是假钞。

使用案例示例

您可以使用 Amazon Redshift ML 解决其他二进制分类问题,例如预测接受治疗者是健康的还是患有疾病。您还可以预测电子邮件是否为垃圾邮件。

任务

  • 先决条件

  • 步骤 1:将数据从 Amazon S3 加载到 Amazon Redshift

  • 步骤 2:创建机器学习模型

  • 步骤 3:使用模型执行预测

先决条件

要完成此教程,必须完成 Amazon Redshift ML 的管理设置

步骤 1:将数据从 Amazon S3 加载到 Amazon Redshift

使用 Amazon Redshift 查询器 v2 运行以下查询。

以下查询创建两个表,从 Amazon S3 加载数据,然后将数据拆分为训练集和测试集。您将使用训练集来训练模型并创建预测函数。然后,您将在测试集上测试预测函数。

--create training set table CREATE TABLE banknoteauthentication_train( variance FLOAT, skewness FLOAT, curtosis FLOAT, entropy FLOAT, class INT ); --Load into training table COPY banknoteauthentication_train FROM 's3://redshiftbucket-ml-sagemaker/banknote_authentication/train_data/' IAM_ROLE default REGION 'us-west-2' IGNOREHEADER 1 CSV; --create testing set table CREATE TABLE banknoteauthentication_test( variance FLOAT, skewness FLOAT, curtosis FLOAT, entropy FLOAT, class INT ); --Load data into testing table COPY banknoteauthentication_test FROM 's3://redshiftbucket-ml-sagemaker/banknote_authentication/test_data/' IAM_ROLE default REGION 'us-west-2' IGNOREHEADER 1 CSV;

步骤 2:创建机器学习模型

以下查询根据您在上一步中创建的训练集,在 Amazon Redshift ML 中创建 XGBoost 模型。将 DOC-EXAMPLE-BUCKET 替换为您自己的 S3_BUCKET,它将存储您的输入数据集和其他 Redshift ML 构件。

CREATE MODEL model_banknoteauthentication_xgboost_binary FROM banknoteauthentication_train TARGET class FUNCTION func_model_banknoteauthentication_xgboost_binary IAM_ROLE default AUTO OFF MODEL_TYPE xgboost OBJECTIVE 'binary:logistic' PREPROCESSORS 'none' HYPERPARAMETERS DEFAULT EXCEPT(NUM_ROUND '100') SETTINGS(S3_BUCKET '<DOC-EXAMPLE-BUCKET>');

显示模型训练的状态(可选)

您可以使用 SHOW MODEL 命令来了解模型何时准备就绪。

使用以下查询监控模型训练的进度。

SHOW MODEL model_banknoteauthentication_xgboost_binary;

如果模型为 READY,则 SHOW MODEL 操作还提供 train:error 指标,如以下输出示例所示。train:error 指标用于衡量模型的准确性,精确到小数点后六位。值为 0 表示最准确,值为 1 表示最不准确。

+--------------------------+--------------------------------------------------+ | Model Name | model_banknoteauthentication_xgboost_binary | +--------------------------+--------------------------------------------------+ | Schema Name | public | | Owner | awsuser | | Creation Time | Tue, 21.06.2022 19:07:35 | | Model State | READY | | train:error | 0.000000 | | Estimated Cost | 0.006197 | | | | | TRAINING DATA: | | | Query | SELECT * | | | FROM "BANKNOTEAUTHENTICATION_TRAIN" | | Target Column | CLASS | | | | | PARAMETERS: | | | Model Type | xgboost | | Training Job Name | redshiftml-20220621190735686935-xgboost | | Function Name | func_model_banknoteauthentication_xgboost_binary | | Function Parameters | variance skewness curtosis entropy | | Function Parameter Types | float8 float8 float8 float8 | | IAM Role | default-aws-iam-role | | S3 Bucket | DOC-EXAMPLE-BUCKET | | Max Runtime | 5400 | | | | | HYPERPARAMETERS: | | | num_round | 100 | | objective | binary:logistic | +--------------------------+--------------------------------------------------+

步骤 3:使用模型执行预测

检查模型的准确性

以下预测查询使用在上一步中创建的预测函数来检查模型的准确性。对测试集运行此查询,以确保模型与训练集的对应关系不会过于紧密。这种紧密的对应关系也称为过拟合,而过拟合可能导致模型做出不可靠的预测。

WITH predict_data AS ( SELECT class AS label, func_model_banknoteauthentication_xgboost_binary (variance, skewness, curtosis, entropy) AS predicted, CASE WHEN label IS NULL THEN 0 ELSE label END AS actual, CASE WHEN actual = predicted THEN 1 :: INT ELSE 0 :: INT END AS correct FROM banknoteauthentication_test ), aggr_data AS ( SELECT SUM(correct) AS num_correct, COUNT(*) AS total FROM predict_data ) SELECT (num_correct :: FLOAT / total :: FLOAT) AS accuracy FROM aggr_data;

预测真钞和假钞的数量

以下预测查询返回测试集中预测的真钞和假钞的数量。

WITH predict_data AS ( SELECT func_model_banknoteauthentication_xgboost_binary(variance, skewness, curtosis, entropy) AS predicted FROM banknoteauthentication_test ) SELECT CASE WHEN predicted = '0' THEN 'Original banknote' WHEN predicted = '1' THEN 'Counterfeit banknote' ELSE 'NA' END AS banknote_authentication, COUNT(1) AS count FROM predict_data GROUP BY 1;

找出真钞和假钞的平均观察值

以下预测查询返回测试集中预测为真钞和假钞的钞票的每个特征的平均值。

WITH predict_data AS ( SELECT func_model_banknoteauthentication_xgboost_binary(variance, skewness, curtosis, entropy) AS predicted, variance, skewness, curtosis, entropy FROM banknoteauthentication_test ) SELECT CASE WHEN predicted = '0' THEN 'Original banknote' WHEN predicted = '1' THEN 'Counterfeit banknote' ELSE 'NA' END AS banknote_authentication, TRUNC(AVG(variance), 2) AS avg_variance, TRUNC(AVG(skewness), 2) AS avg_skewness, TRUNC(AVG(curtosis), 2) AS avg_curtosis, TRUNC(AVG(entropy), 2) AS avg_entropy FROM predict_data GROUP BY 1 ORDER BY 2;

有关 Amazon Redshift ML 的更多信息,请参阅以下文档:

有关机器学习的更多信息,请参阅以下文档: