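Configure and Launch a Hyperparameter Tuning Job
To configure and launch a hyperparameter tuning job, complete the following steps.
Specify the Hyperparameter Tuning Job Settings
To specify the settings for the hyperparameter tuning job, define a JSON object. When you create the tuning job, pass this object to CreateHyperParameterTuningJob as the value of the HyperParameterTuningJobConfig parameter.
In this JSON object, you specify the hyperparameter ranges to search, the resource limits for the tuning job, the search strategy, and the objective metric.
In this example, the tuning job defines ranges for the eta, alpha, min_child_weight, and max_depth hyperparameters of the built-in XGBoost algorithm. The objective of the tuning job is to maximize the validation:auc metric, which the algorithm sends to CloudWatch Logs.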
tuning_job_config = {
    "ParameterRanges": {
        "CategoricalParameterRanges": [],
        "ContinuousParameterRanges": [
            {
                "MaxValue": "1",
                "MinValue": "0",
                "Name": "eta"
            },
            {
                "MaxValue": "2",
                "MinValue": "0",
                "Name": "alpha"
            },
            {
                "MaxValue": "10",
                "MinValue": "1",
                "Name": "min_child_weight"
            }
        ],
        "IntegerParameterRanges": [
            {
                "MaxValue": "10",
                "MinValue": "1",
                "Name": "max_depth"
            }
        ]
    },
    "ResourceLimits": {
        "MaxNumberOfTrainingJobs": 20,
        "MaxParallelTrainingJobs": 3
    },
    "Strategy": "Bayesian",
    "HyperParameterTuningJobObjective": {
        "MetricName": "validation:auc",
        "Type": "Maximize"
    }
}
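Because the range bounds in this object are strings, a typo such as a swapped minimum and maximum is easy to miss until the API call fails. The following is a minimal sketch of a local sanity check, not a reproduction of the service's own validation. The function name check_tuning_job_config and the sample_config values are hypothetical and exist only for illustration.

```python
def check_tuning_job_config(config):
    """Return a list of problems found in a tuning job config (empty if none)."""
    problems = []
    ranges = config.get("ParameterRanges", {})
    # Numeric bounds are passed as strings, so convert them before comparing.
    for key in ("ContinuousParameterRanges", "IntegerParameterRanges"):
        for r in ranges.get(key, []):
            if float(r["MinValue"]) > float(r["MaxValue"]):
                problems.append(f"{r['Name']}: MinValue exceeds MaxValue")
    limits = config.get("ResourceLimits", {})
    max_jobs = limits.get("MaxNumberOfTrainingJobs", 0)
    max_parallel = limits.get("MaxParallelTrainingJobs", 0)
    # Assumption: asking for more parallel jobs than total jobs is a mistake.
    if max_parallel > max_jobs:
        problems.append("MaxParallelTrainingJobs exceeds MaxNumberOfTrainingJobs")
    objective = config.get("HyperParameterTuningJobObjective", {})
    if objective.get("Type") not in ("Maximize", "Minimize"):
        problems.append("objective Type must be 'Maximize' or 'Minimize'")
    return problems


# Hypothetical, trimmed-down config used only to exercise the check.
sample_config = {
    "ParameterRanges": {
        "ContinuousParameterRanges": [
            {"MaxValue": "1", "MinValue": "0", "Name": "eta"}
        ],
        "IntegerParameterRanges": [
            {"MaxValue": "10", "MinValue": "1", "Name": "max_depth"}
        ],
    },
    "ResourceLimits": {"MaxNumberOfTrainingJobs": 20, "MaxParallelTrainingJobs": 3},
    "HyperParameterTuningJobObjective": {
        "MetricName": "validation:auc",
        "Type": "Maximize",
    },
}

print(check_tuning_job_config(sample_config))
```

Running the same function on tuning_job_config before creating the job gives an early, local signal. The service remains the authority on what it accepts.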
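Configure the Training Jobs
To configure the training jobs that the tuning job launches, define a JSON object and pass it as the value of the TrainingJobDefinition parameter of the CreateHyperParameterTuningJob call.
In this JSON object, you specify the training image, the input data channels, the output location, the compute resources, the IAM role, the stopping condition, and the values of any hyperparameters that the tuning job does not tune.
In this example, we set static values for the eval_metric, num_round, objective, rate_drop, and tweedie_variance_power parameters of the built-in XGBoost algorithm.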
SageMaker Python SDK v1
from sagemaker.amazon.amazon_estimator import get_image_uri

# region, bucket, prefix, and role are assumed to be defined in earlier steps
training_image = get_image_uri(region, 'xgboost', repo_version='1.0-1')

s3_input_train = 's3://{}/{}/train'.format(bucket, prefix)
s3_input_validation = 's3://{}/{}/validation/'.format(bucket, prefix)

training_job_definition = {
    "AlgorithmSpecification": {
        "TrainingImage": training_image,
        "TrainingInputMode": "File"
    },
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "CompressionType": "None",
            "ContentType": "csv",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": s3_input_train
                }
            }
        },
        {
            "ChannelName": "validation",
            "CompressionType": "None",
            "ContentType": "csv",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": s3_input_validation
                }
            }
        }
    ],
    "OutputDataConfig": {
        "S3OutputPath": "s3://{}/{}/output".format(bucket, prefix)
    },
    "ResourceConfig": {
        "InstanceCount": 2,
        "InstanceType": "ml.c4.2xlarge",
        "VolumeSizeInGB": 10
    },
    "RoleArn": role,
    # Hyperparameters that stay fixed across all training jobs
    "StaticHyperParameters": {
        "eval_metric": "auc",
        "num_round": "100",
        "objective": "binary:logistic",
        "rate_drop": "0.3",
        "tweedie_variance_power": "1.4"
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 43200
    }
}
SageMaker Python SDK v2
import sagemaker

# region, bucket, prefix, and role are assumed to be defined in earlier steps
training_image = sagemaker.image_uris.retrieve('xgboost', region, '1.0-1')

s3_input_train = 's3://{}/{}/train'.format(bucket, prefix)
s3_input_validation = 's3://{}/{}/validation/'.format(bucket, prefix)

training_job_definition = {
    "AlgorithmSpecification": {
        "TrainingImage": training_image,
        "TrainingInputMode": "File"
    },
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "CompressionType": "None",
            "ContentType": "csv",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": s3_input_train
                }
            }
        },
        {
            "ChannelName": "validation",
            "CompressionType": "None",
            "ContentType": "csv",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": s3_input_validation
                }
            }
        }
    ],
    "OutputDataConfig": {
        "S3OutputPath": "s3://{}/{}/output".format(bucket, prefix)
    },
    "ResourceConfig": {
        "InstanceCount": 2,
        "InstanceType": "ml.c4.2xlarge",
        "VolumeSizeInGB": 10
    },
    "RoleArn": role,
    # Hyperparameters that stay fixed across all training jobs
    "StaticHyperParameters": {
        "eval_metric": "auc",
        "num_round": "100",
        "objective": "binary:logistic",
        "rate_drop": "0.3",
        "tweedie_variance_power": "1.4"
    },
    "StoppingCondition": {
        "MaxRuntimeInSeconds": 43200
    }
}
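The train and validation channels in the definition differ only in their name and S3 URI, so the repeated structure can be factored into a small helper. The following is one possible sketch, not part of the SageMaker SDK. The function name csv_channel and the placeholder bucket and prefix values are hypothetical.

```python
def csv_channel(name, s3_uri):
    """Build one InputDataConfig entry for uncompressed CSV data under an S3 prefix."""
    return {
        "ChannelName": name,
        "CompressionType": "None",
        "ContentType": "csv",
        "DataSource": {
            "S3DataSource": {
                "S3DataDistributionType": "FullyReplicated",
                "S3DataType": "S3Prefix",
                "S3Uri": s3_uri,
            }
        },
    }


# Placeholder values for illustration only; substitute your own bucket and prefix.
bucket = "example-bucket"
prefix = "example-prefix"

input_data_config = [
    csv_channel("train", "s3://{}/{}/train".format(bucket, prefix)),
    csv_channel("validation", "s3://{}/{}/validation/".format(bucket, prefix)),
]

for channel in input_data_config:
    print(channel["ChannelName"], channel["DataSource"]["S3DataSource"]["S3Uri"])
```

The resulting list has the same shape as the InputDataConfig value shown above and could be used in its place.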
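Name and Launch the Hyperparameter Tuning Job
Now you can give the hyperparameter tuning job a name and launch it by calling the CreateHyperParameterTuningJob API. Pass the tuning_job_config and training_job_definition objects that you created in the previous steps as the parameter values.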
tuning_job_name = "MyTuningJob"

# smclient is assumed to be the SageMaker client created in an earlier step
smclient.create_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=tuning_job_name,
    HyperParameterTuningJobConfig=tuning_job_config,
    TrainingJobDefinition=training_job_definition,
)
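A fixed name such as "MyTuningJob" is likely to collide if you run the example more than once. The following sketch appends a UTC timestamp to a base name. It assumes that tuning job names must be unique within an account and Region, are limited to 32 characters, and contain only alphanumeric characters and hyphens; check these constraints against the API reference before relying on them. The function name make_tuning_job_name is hypothetical.

```python
import re
import time

# Assumed constraints on tuning job names; verify against the API reference.
NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9])*")
MAX_NAME_LENGTH = 32


def make_tuning_job_name(base="MyTuningJob", now=None):
    """Return base plus a UTC timestamp suffix, or raise ValueError if invalid."""
    # time.gmtime(None) uses the current time; pass `now` for a fixed result.
    stamp = time.strftime("%Y%m%d-%H%M%S", time.gmtime(now))
    name = "{}-{}".format(base, stamp)
    if len(name) > MAX_NAME_LENGTH:
        raise ValueError("name is longer than {} characters: {}".format(MAX_NAME_LENGTH, name))
    if not NAME_PATTERN.fullmatch(name):
        raise ValueError("name contains unsupported characters: {}".format(name))
    return name


print(make_tuning_job_name())
```

The returned string could be assigned to tuning_job_name before the create_hyper_parameter_tuning_job call.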
Next Step
Monitor the Progress of a Hyperparameter Tuning Job