深度学习 AMI
开发人员指南
AWS 文档中描述的 AWS 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 AWS 服务入门

有关将 PyTorch 转换为 ONNX,然后加载到 MXNet 的教程

ONNX 概述

开放神经网络交换 (ONNX) 是一种用于表示深度学习模型的开放格式。ONNX 受到 Amazon Web Services、Microsoft、Facebook 和其他多个合作伙伴的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。

本教程介绍如何将 采用 Conda 的 Deep Learning AMI 与 ONNX 结合使用。通过执行以下步骤,您可以训练模型或从一个框架中加载预先训练的模型,将此模型导出为 ONNX,然后将此模型导入到另一个框架中。

ONNX 先决条件

要使用此 ONNX 教程,您必须有权访问 采用 Conda 的 Deep Learning AMI 版本 12 或更高版本。有关如何开始使用 采用 Conda 的 Deep Learning AMI 的更多信息,请参阅 采用 Conda 的 Deep Learning AMI

重要

这些示例使用可能需要多达 8 GB 内存(或更多)的函数。请务必选择具有足量内存的实例类型。

使用 采用 Conda 的 Deep Learning AMI 启动终端会话以开始以下教程。

将 PyTorch 模型转换为 ONNX,然后将模型加载到 MXNet 中

首先,激活 PyTorch 环境:

$ source activate pytorch_p36

使用文本编辑器创建一个新文件,并在脚本中使用以下程序来训练 PyTorch 中的模拟模型,然后将它导出为 ONNX 格式。

# Build a Mock Model in PyTorch with a convolution and a reduceMean layer import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.autograd import Variable import torch.onnx as torch_onnx class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False) def forward(self, inputs): x = self.conv(inputs) #x = x.view(x.size()[0], x.size()[1], -1) return torch.mean(x, dim=2) # Use this an input trace to serialize the model input_shape = (3, 100, 100) model_onnx_path = "torch_model.onnx" model = Model() model.train(False) # Export the model to an ONNX file dummy_input = Variable(torch.randn(1, *input_shape)) output = torch_onnx.export(model, dummy_input, model_onnx_path, verbose=False) print("Export of torch_model.onnx complete!")

在您运行此脚本后,您将在同一目录中看到新创建的 .onnx 文件。现在,切换到 MXNet Conda 环境以使用 MXNet 加载模型。

接下来,激活 MXNet 环境:

$ source deactivate $ source activate mxnet_p36

使用文本编辑器创建一个新文件,并在脚本中使用以下程序以在 MXNet 中打开 ONNX 格式文件。

import mxnet as mx from mxnet.contrib import onnx as onnx_mxnet import numpy as np # Import the ONNX model into MXNet's symbolic interface sym, arg, aux = onnx_mxnet.import_model("torch_model.onnx") print("Loaded torch_model.onnx!") print(sym.get_internals())

运行此脚本后,MXNet 将拥有加载的模型,并打印一些基本模型信息。