使用 Apache MXNet(孵化中)在 ONNX 模型中进行推理 - 深度学习 AMI
Amazon Web Services 文档中描述的 Amazon Web Services 服务或功能可能因区域而异。要查看适用于中国区域的差异,请参阅中国的 Amazon Web Services 服务入门

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

使用 Apache MXNet(孵化中)在 ONNX 模型中进行推理

如何使用 ONNX 模型通过 Apache MXNet 进行图像推理(孵化中)

    • (Python 3 的选项)-激活 Python 3 Apache MXNet(孵化)环境:

      $ source activate mxnet_p36
    • (Python 2 的选项)-激活 Python 2 Apache MXNet(孵化)环境:

      $ source activate mxnet_p27
  1. 其余步骤假定您使用的是 mxnet_p36 环境。

  2. 下载一张哈士奇的照片。

    $ curl -O https://upload.wikimedia.org/wikipedia/commons/b/b5/Siberian_Husky_bi-eyed_Flickr.jpg
  3. 下载使用此模型的类的列表。

    $ curl -O https://gist.githubusercontent.com/yrevar/6135f1bd8dcf2e0cc683/raw/d133d61a09d7e5a3b36b8c111a8dd5c4b5d560ee/imagenet1000_clsid_to_human.pkl
  4. 下载 ONNX 格式的预训练的 VGG 16 模型。

    $ wget -O vgg16.onnx https://github.com/onnx/models/raw/master/vision/classification/vgg/model/vgg16-7.onnx
  5. 使用您的首选文本编辑器来创建具有以下内容的脚本。此脚本将使用哈士奇的图片,从预训练模型中获得预测结果,然后在类文件中查找,并返回一个图片分类结果。

    import mxnet as mx import mxnet.contrib.onnx as onnx_mxnet import numpy as np from collections import namedtuple from PIL import Image import pickle # Preprocess the image img = Image.open("Siberian_Husky_bi-eyed_Flickr.jpg") img = img.resize((224,224)) rgb_img = np.asarray(img, dtype=np.float32) - 128 bgr_img = rgb_img[..., [2,1,0]] img_data = np.ascontiguousarray(np.rollaxis(bgr_img,2)) img_data = img_data[np.newaxis, :, :, :].astype(np.float32) # Define the model's input data_names = ['data'] Batch = namedtuple('Batch', data_names) # Set the context to cpu or gpu ctx = mx.cpu() # Load the model sym, arg, aux = onnx_mxnet.import_model("vgg16.onnx") mod = mx.mod.Module(symbol=sym, data_names=data_names, context=ctx, label_names=None) mod.bind(for_training=False, data_shapes=[(data_names[0],img_data.shape)], label_shapes=None) mod.set_params(arg_params=arg, aux_params=aux, allow_missing=True, allow_extra=True) # Run inference on the image mod.forward(Batch([mx.nd.array(img_data)])) predictions = mod.get_outputs()[0].asnumpy() top_class = np.argmax(predictions) print(top_class) labels_dict = pickle.load(open("imagenet1000_clsid_to_human.pkl", "rb")) print(labels_dict[top_class])
  6. 然后运行脚本,您应看到一个如下所示的结果:

    248 Eskimo dog, husky