创建模型质量基准 - Amazon SageMaker
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

创建模型质量基准

创建基准作业,将模型预测与存储在 Amazon S3 中的基准数据集中的 Ground Truth 标签进行比较。通常,您使用训练数据集作为基准数据集。基准作业计算模型的指标,并建议用于监控模型质量偏移的约束。

要创建基准作业,您需要有一个数据集,其中包含模型的预测以及代表数据 Ground Truth 的标签。

要创建基准作业,请使用 SageMaker Python SDK 提供的 ModelQualityMonitor 类,然后完成以下步骤。

创建模型质量基准作业
  1. 首先,创建 ModelQualityMonitor 类的实例。以下代码片段演示了如何执行此操作。

    from sagemaker import get_execution_role, session, Session from sagemaker.model_monitor import ModelQualityMonitor role = get_execution_role() session = Session() model_quality_monitor = ModelQualityMonitor( role=role, instance_count=1, instance_type='ml.m5.xlarge', volume_size_in_gb=20, max_runtime_in_seconds=1800, sagemaker_session=session )
  2. 现在调用 ModelQualityMonitor 对象的 suggest_baseline 方法来运行基准作业。以下代码片段假设您有一个基准数据集,其中包含存储在 Amazon S3 中的预测和标签。

    baseline_job_name = "MyBaseLineJob" job = model_quality_monitor.suggest_baseline( job_name=baseline_job_name, baseline_dataset=baseline_dataset_uri, # The S3 location of the validation dataset. dataset_format=DatasetFormat.csv(header=True), output_s3_uri = baseline_results_uri, # The S3 location to store the results. problem_type='BinaryClassification', inference_attribute= "prediction", # The column in the dataset that contains predictions. probability_attribute= "probability", # The column in the dataset that contains probabilities. ground_truth_attribute= "label" # The column in the dataset that contains ground truth labels. ) job.wait(logs=False)
  3. 基准作业完成后,您可以看到作业生成的约束。首先,通过调用 ModelQualityMonitor 对象的 latest_baselining_job 方法来获取基准作业的结果。

    baseline_job = model_quality_monitor.latest_baselining_job
  4. 基准作业建议了一些约束,这些约束是模型监控所测量指标的阈值。如果一项指标超过建议的阈值,则模型监控器会报告违规行为。要查看基准作业生成的约束,请调用基准作业的 suggested_constraints 方法。以下代码片段将二进制分类模型的约束加载到 Pandas 数据框中。

    import pandas as pd pd.DataFrame(baseline_job.suggested_constraints().body_dict["binary_classification_constraints"]).T

    我们建议您先查看生成的约束并根据需要对其进行修改,然后再使用它们进行监控。例如,如果某项约束过于激进,则您收到的违规警报可能会比预期的要多。

    如果您的约束包含以科学记数法表示的数字,则需要将其转换为浮点数。以下 Python 预处理脚本示例显示了如何将科学记数法中的数字转换为浮点数。

    import csv def fix_scientific_notation(col): try: return format(float(col), "f") except: return col def preprocess_handler(csv_line): reader = csv.reader([csv_line]) csv_record = next(reader) #skip baseline header, change HEADER_NAME to the first column's name if csv_record[0] == “HEADER_NAME”: return [] return { str(i).zfill(20) : fix_scientific_notation(d) for i, d in enumerate(csv_record)}

    您可以将预处理脚本作为 record_preprocessor_script 添加到基准或监控计划,如模型监控文档中所定义。

  5. 当您对约束感到满意时,请在创建监控计划时将其作为 constraints 参数传递。有关更多信息,请参阅 计划模型质量监控作业

建议的基准约束包含在您使用 output_s3_uri 指定的位置处的 constraints.json 文件中。有关此文件架构的信息,请参阅约束的架构(constraints.json 文件)