本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。
PyTorch 到 ONNX 到 CNTK 教程
注意
从 v28 版本开始, Amazon Deep Learning AMI 中将不再包含 CNTK、Caffe、Caffe2 和 Theano Conda 环境。包含这些环境 Amazon Deep Learning AMI 的先前版本将继续可用。但是,只有在开源社区针对这些框架发布安全修补程序时,我们才会为这些环境提供更新。
ONNX 概述
开放神经网络交换
本教程介绍如何将带 Conda 的深度学习 AMI 与 ONNX 结合使用。通过执行以下步骤,您可以训练模型或从一个框架中加载预先训练的模型,将此模型导出为 ONNX,然后将此模型导入到另一个框架中。
ONNX 先决条件
要使用此 ONNX 教程,您必须有权访问带 Conda 的深度学习 AMI 版本 12 或更高版本。有关如何开始使用带 Conda 的深度学习 AMI 的更多信息,请参阅 带 Conda 的深度学习 AMI。
重要
这些示例使用可能需要多达 8 GB 内存(或更多)的函数。请务必选择具有足量内存的实例类型。
使用带 Conda 的深度学习 AMI 来启动终端会话以开始以下教程。
将 PyTorch 模型转换为 ONNX,然后将模型加载到 CNTK
首先,激活 PyTorch 环境:
$
source activate pytorch_p36
使用文本编辑器创建新文件,在脚本中使用以下程序训练模拟模型 PyTorch,然后将其导出为 ONNX 格式。
# Build a Mock Model in Pytorch with a convolution and a reduceMean layer\ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.autograd import Variable import torch.onnx as torch_onnx class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False) def forward(self, inputs): x = self.conv(inputs) #x = x.view(x.size()[0], x.size()[1], -1) return torch.mean(x, dim=2) # Use this an input trace to serialize the model input_shape = (3, 100, 100) model_onnx_path = "torch_model.onnx" model = Model() model.train(False) # Export the model to an ONNX file dummy_input = Variable(torch.randn(1, *input_shape)) output = torch_onnx.export(model, dummy_input, model_onnx_path, verbose=False)
在您运行此脚本后,您将在同一目录中看到新创建的 .onnx 文件。现在,切换到 CNTK Conda 环境以使用 CNTK 加载模型。
接下来,激活 CNTK 环境:
$
source deactivate$
source activate cntk_p36
使用文本编辑器创建一个新文件,并在脚本中使用以下程序以在 CNTK 中打开 ONNX 格式文件。
import cntk as C # Import the PyTorch model into CNTK via the CNTK import API z = C.Function.load("torch_model.onnx", device=C.device.cpu(), format=C.ModelFormat.ONNX)
在运行此脚本后,CNTK 将加载模型。
您也可以通过以下方式使用 CNTK 导出为 ONNX:将以下内容追加到上一脚本,然后运行它。
# Export the model to ONNX via the CNTK export API z.save("cntk_model.onnx", format=C.ModelFormat.ONNX)