强化学习 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

强化学习

注意

订阅后即提供详细文档

Nova Forge 提供高级强化学习功能,可以选择在自己的环境中使用远程奖励功能。客户可以选择集成自己的端点来执行验证以获得即时现实世界的反馈,甚至可以使用自己的协调器来协调环境中的代理多回合评估。

自带协调器进行代理多回合评估

对于需要多回合对话或奖励功能超过 15 分钟超时的 Forge 用户,Nova Forge 提供自带编排 (BYOO) 功能。这使您可以协调环境中的代理多回合评估(例如,使用化学工具对分子设计进行评分,或者使用机器人模拟来奖励高效完成任务并惩罚碰撞)。

架构概述

BYOO 架构通过客户管理的基础架构提供对部署和生成过程的完全控制。

训练 VPC:

  • 部署:通过将部署生成委托给客户基础架构来协调培训

  • Trainer:根据收到的推出量执行模型权重更新

客户 VPC(例如开启的 ECS EC2):

  • 代理 Lambda:接收部署请求并与客户基础设施进行协调

  • Rollout Response SQS:将已完成的部署返回到培训基础设施的队列

  • 生成请求 SQS:模型生成请求队列

  • 生成响应 SQS:模型生成响应队列

  • 客户容器:实现自定义编排逻辑(可以使用提供的入门套件)

  • D ynamoDB:在整个编排过程中存储和检索状态

工作流程:

  1. 部署委托代理 Lambda 部署生成

  2. 代理 Lambda 向生成请求 SQS 推送推出 API 请求

  3. 客户容器处理请求、管理多回合互动和呼叫奖励功能

  4. 容器根据需要存储和检索来自 DynamoDB 的状态

  5. 容器将推出响应推送到部署响应 SQS

  6. Rollout 将完成的部署发送给 Trainer 以获取体重更新

自带设备设置

先决条件:

部署步骤:

为您并行运行的每个新环境部署它。将使用此代码创建三个 Lambda 函数和四个 SQS 队列。部署对应于架构图的中间部分,以实现训练集群和客户端之间的通信。

sam build sam deploy --guided \ --stack-name <Your Stack Name> \ --capabilities CAPABILITY_IAM \ --parameter-overrides ProjectName=<your-project-name>

为自定义 RL 环境准备数据

重要

带有 BYOO 的自定义 RL 环境是在训练期间使用rollout.delegate: true设置和 BYOO 基础架构参数配置的。某些示例中提到的rl_env字段仅在评估期间用于指定如何评估训练后的模型,而不是在训练本身期间使用。

对于需要自定义 RL 环境或代理的用例,messagestools字段是可选的。使用以下格式来构建您的数据集:

{ "id": "wordle_001", "messages": [], "tools": [], "metadata": { "answer": "crane", "problem": "Guess: crane" } }

元数据将在推出请求中按原样传递。有关更多详细信息,请参阅推出请求文档。

BYOO 配方配置

注意

您自己Amazon环境中的远程奖励功能使用该rl_env字段。相反,他们会rollout.delegate: true将编排移交给您的自定义基础架构。该rl_env字段仅在评估期间用于指定如何评估训练后的模型。请注意,data_s3_path这是强制性的,并且与用于启动 Nova 模型与环境之间对话的初始提示相对应。

max_seq_length包括模型在多回合对话期间预期获得的完整上下文长度。代币的数量在每回合中都会迅速增加,并且在设置时应考虑响应时长。

同样,rollout.timeout这是培训师与环境之间整个对话预计花费的最长时间(以秒为单位)。超时将导致训练失败。

该配方专为培训师与环境之间的高通量通信而设计。根据设计,训练集群将并行创建对环境的许多请求,而 BYOO 环境应设计为处理此类请求。

run: name: <run-name> model_type: amazon.nova-2-lite-v1:0:256k model_name_or_path: nova-lite-2/prod data_s3_path: s3://<bucket-name>/train.jsonl # required output_s3_path: s3://path/to/output/checkpoint replicas: 4 generation_replicas: 2 rollout_worker_replicas: 1 rollout_request_arn: <rollout-proxy-lambda-arn> rollout_response_sqs_url: <rollout-response-queue-url> generate_request_sqs_url: <generate-request-queue-url> generate_response_sqs_url: <generate-response-queue-url> training_config: max_steps: 100 max_seq_length: 9392 global_batch_size: 1024 reasoning_effort: high # Options: low, high, or omit for no reasoning data: # Or multi-turn for multi-turn conversations shuffle: false rollout: delegate: true # Enables BYOO mode timeout: 600 # 10 minutes timeout for rollout completion rollout_strategy: type: off_policy_async age_tolerance: 2 advantage_strategy: number_generation: 16 generator: server_count: ${run.generation_replicas} timeout: 1000 max_model_len: ${training_config.max_seq_length} max_new_tokens: 18000 set_random_seed: true temperature: 1 top_k: 0 train: replicas: ${run.replicas} max_steps: ${training_config.max_steps} global_batch_size: ${training_config.global_batch_size} save_steps: 10 save_top_k: 5 # RL parameters [Advanced] clip_ratio_high: 0.2 ent_coeff: 0.001 loss_scale: 1 # Optimizer settings optim: lr: 1e-7 optimizer: 'adam' weight_decay: 0.01 adam_beta1: 0.9 adam_beta2: 0.95 warmup_steps: 5 min_lr: 1e-5

使用与其他配方类似的hyperpod start-job命令启动 HyperPod 配方。并行运行 BYOO 环境,其详细信息如下所示。

BYOO 请求和响应格式

推出请求:

从训练基础设施发送到您的代理 Lambda:

{ "version": "v0", "timestamp": "2025-10-28T...", "sample_id": "sample-000_0", "max_length": 10240, "rewards": { "range": [0.0, 1.0] }, "data": { "problem": "How many six-digit numbers are there in which all digits are odd? Let's think step by step and output the final answer within \\boxed{}.", "answer": "15625" } }

生成请求:

从您的编排逻辑发送到请求生成模型:

{ "version": "v0", "sample_id": "sample-000_0", "step_id": "sample-000_0_0", "messages": [ { "role": "user", "content": "How many six-digit numbers are there in which all digits are odd? Let's think step by step and output the final answer within \\boxed{}." } ] }

生成响应:

从模型生成服务返回:

{ "version": "v0", "sample_id": "sample-000_0", "step_id": "sample-000_0_0", "data": { "choices": [{ "message": { "content": "To determine how many six-digit numbers..." }, "finish_reason": "stop", "logprobs": { "content": [ {"token": "token_id:123", "logprob": -0.5} ] } }], "serving_model_num": 0 }, "finish_reason": "stop" }

推出响应:

从您的编排逻辑发送回训练基础架构:

{ "version": "v0", "sample_id": "sample-000_0", "stop_reason": "end_of_conversation", "rewards": { "aggregate_score": 0.85 } }

BYOO 环境

为了便于您设置自己的环境,Nova Forge 提供了示例环境示例以及使用适当配置启动环境的代码。

安装:

  • 安装 Nova Forge 附带的verifiers软件包。您也可以安装您感兴趣的测试环境,例如导航wordleverifiers/environments/wordle/并运行 pip install -e .

  • 导航到NovaRFTEnvBundles/nova-rl-async-client并运行 pip install -e .

  • 导航NovaRFTEnvBundles/nova-rl-async-client/src到查看培训和评估客户的示例。下面给出了 wordle 环境的配置示例。

训练客户端配置:

@chz.chz class CLIConfig: # SQS configuration # rollout request queue queue_url: str = "https://sqs.us-east-1.amazonaws.com/<account_id>/<project-name>-SageMaker-RolloutRequestQueue.fifo" region_name: str = "us-east-1" groups_per_batch: int = 4 max_messages_per_poll: int = 10 # Client configuration (for model inference) # proxy lambda client_base_url: str = "https://<proxy-lambda-id>.lambda-url.<region>.on.aws/" client_region: str = "us-east-1" client_service: str = "lambda" client_timeout: float = 600.0 client_poll_interval: float = 0.5 # environment configuration vf_env_id: str = "wordle" vf_env_args: str | None = None # rollout configuration group_size: int = 1 model_name: str = "nova-rl" # processing control max_batches: int | None = None # None = process until queue empty continuous: bool = True # If True, keep polling forever

评估客户端配置:

@chz.chz class CLIConfig: # Model configuration model_name: str = "nova" # Environment configuration vf_env_id: str = "wordle" vf_env_args: str | None = None # '{"max_examples": 1, "max_turns": 5}' # JSON string # Evaluation configuration num_examples: int = 1 rollouts_per_example: int = 1 max_concurrent: int = 32 # Sampling configuration max_tokens: int = 1024 temperature: float = 0.0 # Client configuration # proxy lambda client_base_url: str = "https://<proxy-lambda-id>.lambda-url.us-east-1.on.aws/" client_region: str = "us-east-1" client_service: str = "lambda" client_timeout: float = 3000.0 client_poll_interval: float = 0.5