使用 JumpStartEstimator 类微调公开可用的基础模型 - 亚马逊 SageMaker AI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 JumpStartEstimator 类微调公开可用的基础模型

注意

有关在私有策划中心微调基础模型的说明,请参阅。微调精选集线器模型

您只需使用几行代码即可对内置算法或预训练模型进行微调 SageMaker Python SDK。

  1. 首先,在内置算法与预训练模型表中找到所选模型的模型 ID。

  2. 使用模型 ID,将您的训练作业定义为 JumpStart估算器。

    from sagemaker.jumpstart.estimator import JumpStartEstimator model_id = "huggingface-textgeneration1-gpt-j-6b" estimator = JumpStartEstimator(model_id=model_id)
  3. 在模型上运行 estimator.fit(),指向用于微调的训练数据。

    estimator.fit( {"train": training_dataset_s3_path, "validation": validation_dataset_s3_path} )
  4. 然后,使用 deploy 方法自动部署模型进行推理。在这个例子中,我们使用来自的 GPT-J 6B 模型 Hugging Face.

    predictor = estimator.deploy()
  5. 然后,您就可以使用 predict 方法对已部署的模型进行推理。

    question = "What is Southern California often abbreviated as?" response = predictor.predict(question) print(response)
注意

此示例使用基础模型 GPT-J 6B,该模型适用于各种文本生成使用场景,包括问题解答、命名实体识别、摘要等。有关模型使用场景的更多信息,请参阅 可用的基础模型

创建 JumpStartEstimator 时,您可以选择指定模型版本或实例类型。有关该JumpStartEstimator 类及其参数的更多信息,请参见JumpStartEstimator

检查默认实例类型

在使用 JumpStartEstimator 类对预训练模型进行微调时,您可以选择包含特定的模型版本或实例类型。所有 JumpStart 模型都有默认的实例类型。使用以下代码读取默认训练实例类型:

from sagemaker import instance_types instance_type = instance_types.retrieve_default( model_id=model_id, model_version=model_version, scope="training") print(instance_type)

您可以使用instance_types.retrieve()方法查看给定 JumpStart 模型的所有支持的实例类型。

检查默认超参数

要检查用于训练的默认超参数,可以使用 hyperparameters 类中的 retrieve_default() 方法。

from sagemaker import hyperparameters my_hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) print(my_hyperparameters) # Optionally override default hyperparameters for fine-tuning my_hyperparameters["epoch"] = "3" my_hyperparameters["per_device_train_batch_size"] = "4" # Optionally validate hyperparameters for the model hyperparameters.validate(model_id=model_id, model_version=model_version, hyperparameters=my_hyperparameters)

有关可用超参数的更多信息,请参阅 通常支持的微调超参数

检查默认指标定义

您还可以检查默认指标定义:

print(metric_definitions.retrieve_default(model_id=model_id, model_version=model_version))