为生产中的模型创建 SHAP 基准 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

为生产中的模型创建 SHAP 基准

解释通常是对比性的,也就是说,它们解释了与基准的偏差。有关可解释性基准的信息,请参阅SHAP 可解释性基准

除了为每个实例的推理提供解释外,SageMaker Clarify 还支持对机器学习模型进行全局解释,以帮助您根据模型的特征了解模型的整体行为。SageMaker Clarify 通过聚合多个实例的 Shapley 值,生成机器学习模型的全局解释。SageMaker Clarify 支持以下不同的聚合方式,您可以用这些方式来定义基准:

  • mean_abs - 所有实例的绝对 SHAP 值的平均值。

  • median - 所有实例的 SHAP 值的中位数。

  • mean_sq - 所有实例的平方 SHAP 值的平均值。

将应用程序配置为捕获实时或批量转换推理数据后,监控特征归因偏移的第一项任务是创建基准以进行比较。这包括配置数据输入、哪些组是敏感组、如何捕获预测,以及模型及其训练后偏差指标。然后,您需要启动设定基准作业。模型可解释性监控器可以解释正在生成推理的已部署模型的预测,并定期检测特征归因偏移。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在本例中,可解释性基准作业与偏差基准作业共享测试数据集,因此它使用相同的 DataConfig,唯一的区别是作业输出 URI。

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前,SageMaker Clarify 解释器提供了 SHAP 的可扩展且高效的实施,因此可解释性配置是 SHAPConfig,包括以下内容:

  • baseline - Kernel SHAP 算法中用作基准数据集的行(至少一行)列表或 S3 对象 URI。其格式应与数据集格式相同。每行应仅包含特征列/值,而省略标签列/值。

  • num_samples - 要在 Kernel SHAP 算法中使用的样本数。该数字决定了生成的用于计算 SHAP 值的合成数据集的大小。

  • agg_method – 全局 SHAP 值的聚合方法。有效值如下所示:

    • mean_abs - 所有实例的绝对 SHAP 值的平均值。

    • median - 所有实例的 SHAP 值的中位数。

    • mean_sq - 所有实例的平方 SHAP 值的平均值。

  • use_logit - 指示是否将 logit 函数应用于模型预测。默认为 False。如果 use_logitTrue,SHAP 值将具有对数几率单位。

  • save_local_shap_values (bool) - 指示是否将局部 SHAP 值保存在输出位置。默认为 False

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

启动设定基准作业。需要相同的 model_config,因为可解释性设定基准作业需要创建一个影子端点,以获取生成的合成数据集的预测。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")