为生产中的模型创建 SHAP 基线 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

为生产中的模型创建 SHAP 基线

解释通常是对比性的,也就是说,它们反映了偏离基准的情况。有关解释基线的信息,请参阅用于解释的 SAP 基线.

除了为每个实例的推理提供解释之外,SageMaker Clelity 还支持对 ML 模型的全局解释,从而帮助您了解模型作为一个整体的特性。SageMaker 澄清通过聚合多个实例上的 Shapley 值来生成 ML 模型的全局解释。SageMaker 澄清支持以下不同的聚合方式,您可以使用这些方法来定义基线:

  • mean_abs— 所有实例的绝对 SHAP 值的平均值。

  • median— 所有实例的 SHAP 值的中位数。

  • mean_sq— 所有实例的平方 SHAP 值的平均值。

将应用程序配置为捕获实时推理数据后,要监视要素属性偏移的第一个任务是创建要比较的基线。这包括配置数据输入、哪些组是敏感的、如何捕获预测,以及模型及其训练后偏差度量。然后你需要启动基线作业。模型解释监视器可以解释已部署模型的预测,该模型正在生成推断并定期检测要素归因漂移。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在此示例中,解释性基线作业将测试数据集与偏差基线作业共享,因此它使用相同的DataConfig,唯一的区别是作业输出 URI。

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前 SageMaker 澄清解释者提供了 SHAP 的可扩展且高效的实现,因此解释性配置是 ShapConfig,包括以下内容:

  • baseline— 用作内核 SHAP 算法中基线数据集的行(至少一个)或 S3 对象 URI 的列表。格式应与数据集格式相同。每行应仅包含要素列/值,并忽略标签列/值。

  • num_samples— 要在内核 SHAP 算法中使用的样本数。此数字确定生成的合成数据集的大小,以计算 SHAP 值。

  • agg_method — 全局 SHAP 值的聚合方法。有效值有:

    • mean_abs— 所有实例的绝对 SHAP 值的平均值。

    • median— 所有实例的 SHAP 值的中位数。

    • mean_sq— 所有实例的平方 SHAP 值的平均值。

  • use_logit— logit 函数是否应用于模型预测的指示器。默认为 False。如果use_logitTrue,则 SHAP 值将具有对数赔率单位。

  • save_local_shap_values(bool) — 指示是否将本地 SHAP 值保存在输出位置。默认为 False

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

启动基线作业。相同model_config是必需的,因为解释基线作业需要创建阴影端点才能获取生成的合成数据集的预测。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")