Using PyTorch-Neuron and the Amazon Neuron Compiler - Deep Learning AMI


The PyTorch-Neuron compilation API provides a method to compile a model graph that you can run on an Amazon Inferentia device.

A trained model must be compiled to an Inferentia target before it can be deployed on Inf1 instances. The following tutorial compiles the torchvision ResNet50 model and exports it as a saved TorchScript module. This model is then used to run inference.

For convenience, this tutorial uses an Inf1 instance for both compilation and inference. In practice, you can compile your model on another instance type, such as one from the c5 instance family, and then deploy the compiled model to your Inf1 inference server. For more information, see the Amazon Neuron PyTorch SDK Documentation.


Before starting this tutorial, complete the setup steps in Launching a DLAMI Instance with Amazon Neuron. You should also be familiar with deep learning and with using the DLAMI.

Activate the Conda Environment

Activate the PyTorch-Neuron conda environment using the following command:

source activate aws_neuron_pytorch_p36

To exit the current conda environment, run:

source deactivate
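Before running the scripts, you can confirm from Python that the right environment is active. The sketch below assumes that conda exports the active environment's name in the `CONDA_DEFAULT_ENV` variable, which current conda releases do. The helper names are hypothetical and are not part of PyTorch-Neuron.

```python
import os
from typing import Mapping, Optional

# Environment name used in this tutorial.
EXPECTED_ENV = "aws_neuron_pytorch_p36"


def active_conda_env(environ: Mapping[str, str] = os.environ) -> Optional[str]:
    """Return the active conda environment name, or None if none is set.

    Assumption: conda stores the active environment name in CONDA_DEFAULT_ENV.
    """
    return environ.get("CONDA_DEFAULT_ENV")


def is_neuron_env_active(environ: Mapping[str, str] = os.environ) -> bool:
    """True when the PyTorch-Neuron environment from this tutorial is active."""
    return active_conda_env(environ) == EXPECTED_ENV


if __name__ == "__main__":
    env = active_conda_env()
    if is_neuron_env_active():
        print(f"OK: {env} is active")
    else:
        print(f"Expected {EXPECTED_ENV!r}, found {env!r}; activate it first")
```

The functions take the environment mapping as a parameter, so you can test them without changing your shell.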

ResNet50 Compilation

Create a Python script with the following content. This script uses the PyTorch-Neuron compilation Python API to compile a ResNet50 model.


Each torchvision release depends on a specific torch release, so keep the two versions matched when you compile torchvision models. pip can manage these dependency rules. For example, torchvision==0.6.1 matches the torch==1.5.1 release, and torchvision==0.8.2 matches the torch==1.7.1 release.
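You can check the installed pair before compiling. The sketch below compares two version strings against a lookup table. The table holds only the two pairings named above, and the helper names are hypothetical. Local build suffixes such as `+cpu` are ignored when comparing.

```python
from typing import Dict

# torchvision release -> matching torch release.
# Only the two pairs named in this tutorial; extend it for other releases.
TORCHVISION_TO_TORCH: Dict[str, str] = {
    "0.6.1": "1.5.1",
    "0.8.2": "1.7.1",
}


def strip_local(version: str) -> str:
    """Drop a local build suffix such as '+cpu' from a version string."""
    return version.split("+", 1)[0]


def versions_match(torch_version: str, torchvision_version: str) -> bool:
    """Return True if the pair is one of the known matching releases.

    Unknown torchvision versions return False instead of guessing.
    """
    expected = TORCHVISION_TO_TORCH.get(strip_local(torchvision_version))
    return expected is not None and expected == strip_local(torch_version)


def pip_pin(torchvision_version: str) -> str:
    """Build `pip install` arguments that pin both packages together."""
    torch_version = TORCHVISION_TO_TORCH[torchvision_version]
    return f"torch=={torch_version} torchvision=={torchvision_version}"
```

Inside the activated environment, you could call `versions_match(torch.__version__, torchvision.__version__)` after importing both packages.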

import torch
import numpy as np
import os
import torch_neuron
from torchvision import models

## Example input used for tracing: a batch of one 3x224x224 image
image = torch.zeros([1, 3, 224, 224], dtype=torch.float32)

## Load a pretrained ResNet50 model
model = models.resnet50(pretrained=True)

## Tell the model we are using it for evaluation (not training)
model.eval()

## Compile the model for Inferentia by tracing it with the example input
model_neuron = torch.neuron.trace(model, example_inputs=[image])

## Export to saved model
model_neuron.save("")

Run the compilation script.


Compilation takes a few minutes. When it finishes, the compiled model is saved in the local directory.

ResNet50 Inference

Create a Python script with the following content. This script downloads a sample image and uses it to run inference with the compiled model.

import os
import time
import torch
import torch_neuron
import json
import numpy as np
from urllib import request
from torchvision import models, transforms, datasets

## Create an image directory containing a small kitten
os.makedirs("./torch_neuron_test/images", exist_ok=True)
request.urlretrieve("",
                    "./torch_neuron_test/images/kitten_small.jpg")

## Fetch labels to output the top classifications
request.urlretrieve("", "imagenet_class_index.json")
idx2label = []

with open("imagenet_class_index.json", "r") as read_file:
    class_idx = json.load(read_file)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

## Import a sample image and normalize it into a tensor
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225])

eval_dataset = datasets.ImageFolder(
    os.path.dirname("./torch_neuron_test/"),
    transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        normalize,
    ])
)

image, _ = eval_dataset[0]
image = torch.tensor(image.numpy()[np.newaxis, ...])

## Load model
model_neuron = torch.jit.load( '' )

## Predict
results = model_neuron( image )

# Get the top 5 results
top5_idx = results[0].sort()[1][-5:]

# Lookup and print the top 5 labels
top5_labels = [idx2label[idx] for idx in top5_idx]

print("Top 5 labels:\n {}".format(top5_labels) )
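The preprocessing in this script first uses `transforms.ToTensor`. That step scales 8-bit channel values to [0, 1] and reorders the image to channels-first. `transforms.Normalize` then computes `(x - mean) / std` for each channel. The plain-Python sketch below applies the same arithmetic to a single RGB pixel, using the mean and std values from the script. It is only an illustration; the helper names are hypothetical, and real preprocessing should use the torchvision transforms shown above.

```python
from typing import List, Sequence

# Per-channel ImageNet statistics passed to transforms.Normalize in the script.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def to_unit_range(rgb: Sequence[int]) -> List[float]:
    """Scale 8-bit channel values to [0, 1], as transforms.ToTensor does."""
    return [c / 255.0 for c in rgb]


def normalize_pixel(rgb: Sequence[int]) -> List[float]:
    """Apply (x - mean) / std to each channel of one RGB pixel."""
    scaled = to_unit_range(rgb)
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


if __name__ == "__main__":
    # A mid-grey pixel after scaling and per-channel normalization.
    print([round(v, 3) for v in normalize_pixel([128, 128, 128])])
```

Black pixels become negative in every channel and white pixels become positive. This centers the inputs the way the pretrained ResNet50 expects.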

Run inference with the compiled model using the following command:


Your output should look like the following:

Top 5 labels:
 ['tiger', 'lynx', 'tiger_cat', 'Egyptian_cat', 'tabby']
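The labels print in ascending score order. `results[0].sort()` sorts from lowest to highest, and `[-5:]` keeps the last five indices. In the output above, `tabby` is therefore the most likely class and `tiger` the least likely of the five. The plain-Python sketch below reproduces that selection on a hypothetical score list and also shows the reversed, most-likely-first order. The function names and scores are illustrative only.

```python
from typing import List, Sequence


def top_k_ascending(scores: Sequence[float], k: int = 5) -> List[int]:
    """Mirror results[0].sort()[1][-k:]: indices of the k largest scores,
    ordered from lowest to highest score (most likely class last)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    return order[-k:]


def top_k_descending(scores: Sequence[float], k: int = 5) -> List[int]:
    """The same k indices, with the most likely class first."""
    return list(reversed(top_k_ascending(scores, k)))


if __name__ == "__main__":
    # Hypothetical scores for seven classes.
    scores = [0.1, 2.5, 0.7, 3.2, 1.9, 0.3, 2.8]
    print(top_k_ascending(scores))   # [2, 4, 1, 6, 3]
    print(top_k_descending(scores))  # [3, 6, 1, 4, 2]
```

In PyTorch itself, `torch.topk(results[0], 5)` returns the largest values and their indices, sorted largest first by default. That is an alternative if you want the most likely label printed first.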