为生产中的模型创建SHAP 基准 - Amazon SageMaker
AWS 文档中描述的 AWS 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 AWS 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

为生产中的模型创建SHAP 基准

说明通常是对比性的,即,它们考虑的是偏离基准的情况。有关可解释性基准的信息,请参阅用于说明的 SAP 基准

除了提供每个实例推理的说明之外, SageMaker Clarify 还支持 ML 模型的全局说明,以帮助您了解模型的整体行为以及其功能。 通过聚合多个实例上的 Shapley 值SageMaker Clarify来生成 ML 模型的全局说明。 SageMaker Clarify 支持以下不同的聚合方式,您可以使用这些方式定义基准:

  • mean_abs – 所有实例的绝对 SAP 值的平均值。

  • median – 所有实例的 SAP 值中值。

  • mean_sq – 所有实例的 SAP 值平方的平均值。

在将应用程序配置为捕获实时推理数据后,第一个监控特征属性偏差的任务是创建要比较的基准。这包括配置数据输入、哪些组敏感、如何捕获预测以及模型及其训练后偏差指标。然后,您需要启动基准设置作业。模型可解释性监视器可以解释所部署模型的预测,该模型将生成推理并定期检测特征属性偏差。

model_explainability_monitor = ModelExplainabilityMonitor( role=role, sagemaker_session=sagemaker_session, max_runtime_in_seconds=1800, )

在此示例中,解释性基准设置作业与偏差基准设置作业共享测试数据集,因此它使用相同的 DataConfig,唯一的区别是作业输出 URI。

model_explainability_baselining_job_result_uri = f"{baseline_results_uri}/model_explainability" model_explainability_data_config = DataConfig( s3_data_input_path=validation_dataset, s3_output_path=model_explainability_baselining_job_result_uri, label=label_header, headers=all_headers, dataset_type=dataset_type, )

目前, SageMaker Clarify 说明程序提供了 SAP 的可扩展且高效的实施,因此,可解释性配置是 SHAPConfig包括以下内容:

  • baseline – 要在内核 SAP 算法中用作基准数据集的行(至少一行)或 S3 对象 URI 的列表。格式应与数据集格式相同。每行应仅包含特征列/值并忽略标签列/值。

  • num_samples – 内核 SAP 算法中要使用的样本数。此数字确定生成的合成数据集的大小,以计算SHAP 值。

  • agg_method – 全局 SAP 值的聚合方法。以下是有效值:

    • mean_abs – 所有实例的绝对 SAP 值的平均值。

    • median – 所有实例的 SAP 值中值。

    • mean_sq – 所有实例的 SAP 值平方的平均值。

  • use_logit – 指示 logit 函数是否应用于模型预测的指标。默认值为 False。如果 use_logitTrue,则 SAP 值将具有 log-ods 单位。

  • save_local_shap_values (布尔) 指示是否在输出位置保存本地SHAP 值的–指示器。默认值为 True

# Here use the mean value of test dataset as SHAP baseline test_dataframe = pd.read_csv(test_dataset, header=None) shap_baseline = [list(test_dataframe.mean())] shap_config = SHAPConfig( baseline=shap_baseline, num_samples=100, agg_method="mean_abs", save_local_shap_values=False, )

启动基准设置作业。这model_config同样是必需的,因为可解释性基准作业需要创建影子终端节点以获取生成的合成数据集的预测。

model_explainability_monitor.suggest_baseline( data_config=model_explainability_data_config, model_config=model_config, explainability_config=shap_config, ) print(f"ModelExplainabilityMonitor baselining job: {model_explainability_monitor.latest_baselining_job_name}")