有关将 PyTorch 转换为 ONNX,然后加载到 MXNet 的教程 - 深度学习 AMI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

有关将 PyTorch 转换为 ONNX,然后加载到 MXNet 的教程

ONNX 概述

开放神经网络交换 (ONNX) 是一种用于表示深度学习模型的开放格式。ONNX 受到 Amazon Web Services、Microsoft、Facebook 和其他多个合作伙伴的支持。您可以使用任何选定的框架来设计、训练和部署深度学习模型。ONNX 模型的好处是,它们可以在框架之间轻松移动。

本教程介绍如何将深度学习 AMI 与 ONNX 结合使用。通过执行以下步骤,您可以训练模型或从一个框架中加载预先训练的模型,将此模型导出为 ONNX,然后将此模型导入到另一个框架中。

ONNX 先决条件

要使用此 ONNX 教程,您必须有权访问具有 Conda 版本 12 或更高版本的深度学习 AMI。有关如何在 Conda 上开始使用深度学习 AMI 的更多信息,请参阅。使用 Conda 的深度学习 AMI.

重要

这些示例使用可能需要多达 8 GB 内存(或更多)的函数。请务必选择具有足量内存的实例类型。

使用 Conda 启动终端会话以开始以下教程。

将 PyTorch 模型转换为 ONNX,然后将模型加载到 MXNet 中

首先,激活 PyTorch 环境:

$ source activate pytorch_p36

使用文本编辑器创建一个新文件,并在脚本中使用以下程序来训练 PyTorch 中的模拟模型,然后将它导出为 ONNX 格式。

# Build a Mock Model in PyTorch with a convolution and a reduceMean layer import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms from torch.autograd import Variable import torch.onnx as torch_onnx class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3,3), stride=1, padding=0, bias=False) def forward(self, inputs): x = self.conv(inputs) #x = x.view(x.size()[0], x.size()[1], -1) return torch.mean(x, dim=2) # Use this an input trace to serialize the model input_shape = (3, 100, 100) model_onnx_path = "torch_model.onnx" model = Model() model.train(False) # Export the model to an ONNX file dummy_input = Variable(torch.randn(1, *input_shape)) output = torch_onnx.export(model, dummy_input, model_onnx_path, verbose=False) print("Export of torch_model.onnx complete!")

在您运行此脚本后,您将在同一目录中看到新创建的 .onnx 文件。现在,切换到 MXNet Conda 环境以使用 MXNet 加载模型。

接下来,激活 MXNet 环境:

$ source deactivate $ source activate mxnet_p36

使用文本编辑器创建一个新文件,并在脚本中使用以下程序以在 MXNet 中打开 ONNX 格式文件。

import mxnet as mx from mxnet.contrib import onnx as onnx_mxnet import numpy as np # Import the ONNX model into MXNet's symbolic interface sym, arg, aux = onnx_mxnet.import_model("torch_model.onnx") print("Loaded torch_model.onnx!") print(sym.get_internals())

运行此脚本后,MXNet 将拥有加载的模型,并打印一些基本模型信息。