实现奖励功能 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

实现奖励功能

概述

奖励函数(也称为记分器或评分器)是评估模型响应并为训练提供反馈信号的核心组件。它必须作为接受模型响应并返回奖励分数的 Lambda 函数来实现。

接口格式

您的奖励函数必须接受并返回以下格式的数据:

训练的样本输入示例

{ "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" } ], "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." } }

奖励 lambda 的有效载荷示例

容器会自动转换您的数据,然后再将其发送到 Lambda 函数,方法是:

  1. 为每个提示生成模型响应

  2. 将助手回合(生成的响应)附加到消息数组

  3. 添加用于跟踪的唯一id字段

您的 Lambda 函数将以这种转换后的格式接收数据:

{ "id": "123", "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" }, { "role": "assistant", "content": "As an AI developed by Amazon, I don not have a dedicated security team..." } ], # Following section will be same as your training dataset sample "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." } }

奖励 Lambda 合约

def lambda_handler(event, context): return lambda_grader(event) def lambda_grader(samples: list[dict]) -> list[dict]: """ Args: samples: List of dictionaries in OpenAI format Example input: { "id": "123", "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" }, { "role": "assistant", "content": "As an AI developed by Company, I don nott have a dedicated security team..." } ], # This section will be same as your training dataset "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." } } Returns: List of dictionaries with reward scores: { "id": str, # Same id as input sample "aggregate_reward_score": float, # Overall score for the sample "metrics_list": [ # OPTIONAL: Component scores { "name": str, # Name of the component score "value": float, # Value of the component score "type": str # "Reward" or "Metric" } ] } """

输入和输出字段

输入字段

字段 说明 附加说明
id 样本的唯一标识符 在输出中回声。字符串格式
消息 以 OpenAI 格式排序的聊天记录 消息对象数组
消息 [] .role 留言的发言人 常用值:“用户”、“助手”、“系统”
消息 [] .content 消息的文字内容 纯字符串
**元数据 有助于评分的自由格式信息 对象;从训练数据传递的可选字段

输出字段

字段 说明 附加说明
id 与输入样本相同的标识符 必须匹配输入
聚合_奖励_分数 样本的总分数 浮点型(例如,0.0—1.0 或任务定义的范围)
指标列表 构成汇总的分量分数 指标对象数组

技术限制

  • 超时限制-每次 Lambda 调用的最大执行时间 15 分钟

  • 并@@ -必须处理rollout_worker_replicas * 64并发请求

  • 可靠性-必须实施正确的错误处理并始终如一地返回有效分数

  • 性能-针对快速执行(几秒钟,而不是几分钟)进行优化,从而实现高效训练

最佳实践

  • 尽量减少外部 API 调用

  • 使用高效的算法和数据结构

  • 为暂时失败实现重试逻辑

  • 缓存可重复使用的计算

  • 在训练前进行彻底测试,确保执行无错误

使用自定义奖励功能

当你有特定任务的评估标准时,可以实现自定义奖励函数:

  • 定义评估标准-确定哪些因素可以很好地响应您的任务

  • 实现 Lambda 函数-按照接口格式创建 Lambda 函数

  • 本地测试-验证您的函数返回样本输入的正确分数

  • 部署到 Amazon — 部署您的 Lambda 并记下 ARN

  • 配置食谱 — 将 Lambda ARN 添加到您的食谱字段中 reward_lambda_arn

  • 使用小型数据集进行测试 — 使用最少的数据运行 RFT 以验证集成

IAM 权限

所需的权限

您的 SageMaker 执行角色必须具有调用您的 Lambda 函数的权限。将此策略添加到您的 SageMaker 执行角色中:

{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "lambda:InvokeFunction" ], "Resource": "arn:aws:lambda:region:account-id:function:function-name" } ] }

Lambda 执行角色

您的 Lambda 函数的执行角色需要基本的 Lambda 执行权限:

{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "logs:CreateLogGroup", "logs:CreateLogStream", "logs:PutLogEvents" ], "Resource": "arn:aws:logs:*:*:*" } ] }

其他权限:如果您的 Lambda 函数访问其他Amazon服务(例如,S3 用于参考数据,DynamoDB 用于日志记录),请将这些权限添加到 Lambda 执行角色。

示例:LLM As a Judge 奖励功能

此示例演示如何使用 Amazon Bedrock 模型作为评判,通过将模型响应与参考答案进行比较来评估模型响应。此 Lambda 模板为客户提供了一个框架,用于实现对 Amazon Bedrock 的调用,以请求推理来处理评判评估。Lambda 函数与其他奖励函数保持相同的 input/output 合约。

实施

此 Lambda 函数实现了两个阶段的评估过程:从传入的样本中lambda_handler提取模型响应和参考答案,然后该函数lambda_graded调用 Amazon Bedrock 来对它们之间的语义相似度进行评分。该实现包括强大的错误处理功能,可自动重试临时故障,并支持灵活的参考答案格式(字符串和结构化字典格式)。

实施细节:

  • 重试逻辑:针对限制异常实现指数退避(1s、2s、4s),以处理 Bedrock API 速率限制

  • 错误处理:如果评估失败,则返回 0.0 的分数,而不是引发异常

  • 确定性评分:使用 temperature=0.0 来确保各个评估的分数一致

  • 灵活的参考格式:自动处理字符串和字典参考答案

  • 分数限制:确保所有分数都在有效的 [0.0, 1.0] 范围内

  • 模型不可知论:更改 JUDGE_MODEL_ID 以使用任何亚马逊 Bedrock 模型(Nova、Llama、Mistral 等)

""" LLM Judge Lambda POC - Working implementation using Amazon Bedrock """ import json import time import boto3 bedrock_runtime = boto3.client('bedrock-runtime', region_name='us-east-1') JUDGE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0" SYSTEM_PROMPT = "You must output ONLY a number between 0.0 and 1.0. No explanations, no text, just the number." JUDGE_PROMPT_TEMPLATE = """Compare the following two responses and rate how similar they are on a scale of 0.0 to 1.0, where: - 1.0 means the responses are semantically equivalent (same meaning, even if worded differently) - 0.5 means the responses are partially similar - 0.0 means the responses are completely different or contradictory Response A: {response_a} Response B: {response_b} Output ONLY a number between 0.0 and 1.0. No explanations.""" def lambda_graded(response_a: str, response_b: str, max_retries: int = 3) -> float: """Call Bedrock to compare responses and return similarity score.""" prompt = JUDGE_PROMPT_TEMPLATE.format(response_a=response_a, response_b=response_b) for attempt in range(max_retries): try: response = bedrock_runtime.converse( modelId=JUDGE_MODEL_ID, messages=[{"role": "user", "content": [{"text": prompt}]}], system=[{"text": SYSTEM_PROMPT}], inferenceConfig={"temperature": 0.0, "maxTokens": 10} ) print(f"Bedrock call successful: {response}") output = response['output']['message']['content'][0]['text'].strip() score = float(output) print(f"Score parsed: {score}") return max(0.0, min(1.0, score)) except Exception as e: if "ThrottlingException" in str(e) and attempt < max_retries - 1: time.sleep(2 ** attempt) else: print(f"Bedrock call failed: {e}") return None return None def lambda_handler(event, context): """AWS Lambda handler - processes samples from RFTEvalInvoker.""" try: samples = event if isinstance(event, list) else [event] results = [] for sample in samples: sample_id = sample.get("id", "unknown") messages = sample.get("messages", []) # Extract assistant response (response A) response_a = "" for msg in messages: if msg.get("role") in ["assistant", "nova_assistant"]: response_a = msg.get("content", "") break # Extract reference answer from root level (no longer in metadata) reference_answer = sample.get("reference_answer", "") # Handle both string and dict reference_answer formats if isinstance(reference_answer, dict): # If reference_answer is a dict, extract the explanation or compliant field response_b = reference_answer.get("explanation", reference_answer.get("compliant", "")) else: response_b = reference_answer if not response_a or not response_b: results.append({ "id": sample_id, "aggregate_reward_score": 0.0, "metrics_list": [{"name": "similarity_score", "value": 0.0, "type": "Metric"}] }) continue # Get similarity score score = lambda_graded(response_a, response_b) results.append({ "id": sample_id, "aggregate_reward_score": score, "metrics_list": [ { "name": "similarity_score", "value": score, "type": "Metric" } ] }) return {"statusCode": 200, "body": json.dumps(results)} except Exception as e: print(f"Error: {e}") return {"statusCode": 500, "body": json.dumps({"error": str(e)})}

输入格式

Lambda 接收与其他奖励函数相同的输入格式:

{ "id": "sample-001", "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" }, { "role": "assistant", "content": "As an AI developed by Amazon, I don't have a dedicated security team..." } ], "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." }, "my_custom_field": "custom_value" }

输出格式

{ "id": "sample-001", "aggregate_reward_score": 0.85, "metrics_list": [ { "name": "similarity_score", "value": 0.85, "type": "Metric" } ] }

部署注意事项

您可能还需要根据所选模型的功能和 API 格式调整提示模板和推理参数。

  • IAM 权限:Lambda 执行角色必须拥有您所bedrock:InvokeModel选模型的权限

  • 超时:将 Lambda 超时设置为至少 60 秒,以适应 Bedrock API 延迟和重试次数

  • 区域:在你选择的 Bedrock 模型可用的区域进行部署

  • 成本:监控 Bedrock API 的使用情况,因为每次评估都会对每个样本进行一次 API 调用

  • 吞吐量:对于大规模评估,请请求增加 Bedrock 配额以避免限制

增加基岩吞吐量

如果您在评估期间遇到限制,请增加您的 Bedrock 模型配额:

  • 导航到 S Amazon ervice Quotas 控制台

  • 搜索 “Bedrock” 并选择你所在的地区

  • 找到你所选型号的配额(例如,“Claude 3.5 Sonnet 的每分钟调用次数”)

  • 单击 “请求增加配额” 并指定所需的吞吐量

  • 提供增加的理由(例如,“RFT 评估工作量”)

Lambda 的内置重试逻辑可以处理偶尔的限制,但是持续的大量评估需要适当的配额增加。

所需的 IAM 政策:

{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "bedrock:InvokeModel" ], "Resource": "arn:aws:bedrock:*::foundation-model/*" } ] }