PyTorch到 ONNX 到 MXNet 教程 - 深度学习 AMI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅 中国的 Amazon Web Services 服务入门 (PDF)

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

PyTorch到 ONNX 到 MXNet 教程

ONNX 概述

开放神经网络交换 (ONNX) 是一种用于表示深度学习模型的开放格式。ONNX 得到亚马逊网络服务、微软、Facebook 和其他几个合作伙伴的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。

本教程向你展示了如何在 ONNX 上使用 Conda 的深度学习 AMI。通过执行以下步骤,您可以训练模型或从一个框架中加载预先训练的模型,将此模型导出为 ONNX,然后将此模型导入到另一个框架中。

ONNX 先决条件

要使用此 ONNX 教程,你必须有权访问装有 Conda 版本 12 或更高版本的深度学习 AMI。有关如何使用 Conda 开始使用深度学习 AMI 的更多信息,请参阅使用 Conda 进行深度学习 AMI

重要

这些示例使用可能需要多达 8 GB 内存(或更多)的函数。请务必选择具有足量内存的实例类型。

使用 Conda 启动深度学习 AMI 的终端会话,开始以下教程。

将PyTorch模型转换为 ONNX,然后将模型加载到 MXNet

首先,激活PyTorch环境:

$ source activate pytorch_p36

使用文本编辑器创建一个新文件,在脚本中使用以下程序训练模拟模型PyTorch,然后将其导出为 ONNX 格式。

# Build a Mock Model in PyTorch with a convolution and a reduceMean layer import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.autograd import Variable import torch.onnx as torch_onnx class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False) def forward(self, inputs): x = self.conv(inputs) #x = x.view(x.size()[0], x.size()[1], -1) return torch.mean(x, dim=2) # Use this an input trace to serialize the model input_shape = (3, 100, 100) model_onnx_path = "torch_model.onnx" model = Model() model.train(False) # Export the model to an ONNX file dummy_input = Variable(torch.randn(1, *input_shape)) output = torch_onnx.export(model, dummy_input, model_onnx_path, verbose=False) print("Export of torch_model.onnx complete!")

在您运行此脚本后,您将在同一目录中看到新创建的 .onnx 文件。现在,切换到 MXNet Conda 环境以使用 MXNet 加载模型。

接下来,激活 MXNet 环境:

$ source deactivate $ source activate mxnet_p36

使用文本编辑器创建一个新文件,并在脚本中使用以下程序以在 MXNet 中打开 ONNX 格式文件。

import mxnet as mx from mxnet.contrib import onnx as onnx_mxnet import numpy as np # Import the ONNX model into MXNet's symbolic interface sym, arg, aux = onnx_mxnet.import_model("torch_model.onnx") print("Loaded torch_model.onnx!") print(sym.get_internals())

运行此脚本后,MXNet 将拥有加载的模型,并打印一些基本模型信息。