本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
步骤 6:评估模型
现在,您已使用 训练和部署模型Amazon SageMaker,请评估模型以确保它对新数据生成准确的预测。对于模型评估,请使用您在中创建的测试数据集步骤 3:下载、浏览和转换数据集。
评估部署到 SageMaker 托管服务的模型
要评估模型并将其用于生产,请使用测试数据集调用终端节点,并检查推理是否返回要实现的目标准确性。
评估模型
-
设置以下函数以预测测试集的每一行。在以下示例代码中,
rows
参数指定要一次预测的行数。您可以更改其值以执行完全利用实例硬件资源的批处理推理。import numpy as np def predict(data, rows=1000): split_array = np.array_split(data, int(data.shape[0] / float(rows) + 1)) predictions = '' for array in split_array: predictions = ','.join([predictions, xgb_predictor.predict(array).decode('utf-8')]) return np.fromstring(predictions[1:], sep=',')
-
运行以下代码以预测测试数据集并绘制直方图。您只需获取测试数据集的特征列,不包括实际值0th 列。
import matplotlib.pyplot as plt predictions=predict(test.to_numpy()[:,1:]) plt.hist(predictions) plt.show()
-
预测值为浮点类型。要
True
根据浮点值确定False
或 ,您需要设置一个截断值。如以下示例代码中所示,使用 Scikit-learn 库返回截断值为 0.5 的输出混淆指标和分类报告。import sklearn cutoff=0.5 print(sklearn.metrics.confusion_matrix(test.iloc[:, 0], np.where(predictions > cutoff, 1, 0))) print(sklearn.metrics.classification_report(test.iloc[:, 0], np.where(predictions > cutoff, 1, 0)))
这应返回以下混淆矩阵:
-
要查找给定测试集的最佳截断值,请计算逻辑回归的日志损失函数。日志丢失函数的定义是返回其基本实际情况标签的预测概率的逻辑模型的负对数似然。以下示例代码的数值和迭代方式计算日志损失值 (
-(y*log(p)+(1-y)log(1-p)
),其中y
是真正的标签,p
是相应测试样本的概率估计值。它返回日志丢失与截止图表。import matplotlib.pyplot as plt cutoffs = np.arange(0.01, 1, 0.01) log_loss = [] for c in cutoffs: log_loss.append( sklearn.metrics.log_loss(test.iloc[:, 0], np.where(predictions > c, 1, 0)) ) plt.figure(figsize=(15,10)) plt.plot(cutoffs, log_loss) plt.xlabel("Cutoff") plt.ylabel("Log loss") plt.show()
这应返回以下日志丢失曲线。
-
使用 NumPy
argmin
和min
函数查找错误曲线的最低点:print( 'Log loss is minimized at a cutoff of ', cutoffs[np.argmin(log_loss)], ', and the log loss value at the minimum is ', np.min(log_loss) )
这应返回:
Log loss is minimized at a cutoff of 0.53, and the log loss value at the minimum is 4.348539186773897
。您可以估计成本函数作为替代方案,而不是计算和最大程度地减少日志损失函数。例如,如果要训练模型以针对业务问题 (例如客户流失预测问题) 执行二进制分类,可以对混淆矩阵的元素设置权重并相应地计算成本函数。
现在,您已在 中训练、部署和评估您的第一个模型SageMaker。
要监控模型质量、数据质量和偏差,请使用 Amazon SageMaker 模型监控器 和 SageMaker 阐明。要了解更多信息,请参阅Amazon SageMaker 模型监控器监控数据质量、监控模型质量、监控偏差和监控功能属性偏差。
要获得对低置信度 ML 预测或随机预测示例的人工审核,请使用Amazon Augmented AI人工审核工作流。有关更多信息,请参阅使用 Amazon Augmented AI 进行人工审核。