
Adapt Your PyTorch Training Script

To start collecting model output tensors and debug training issues, make the following modifications to your PyTorch training script.

For PyTorch 1.12.0

If you bring your own PyTorch training script, you can run the training job and extract model output tensors by adding a few lines of code to your training script. To do this, use the hook APIs in the sagemaker-debugger client library. The following instructions break down the steps with code examples.

  1. Create a hook.

    (Recommended) For training jobs within SageMaker

    import smdebug.pytorch as smd
    hook = smd.get_hook(create_if_not_exists=True)

    When you launch a training job in Step 2: Launch and Debug Training Jobs Using SageMaker Python SDK with any of the DebuggerHookConfig, TensorBoardConfig, or Rules in your estimator, SageMaker adds a JSON configuration file to your training instance, and the get_hook function picks it up. Note that if you do not include any of the configuration APIs in your estimator, there is no configuration file for the hook to find, and the function returns None. For a sketch of an estimator configured this way, see the example at the end of this page.

    (Optional) For training jobs outside SageMaker

    If you run training jobs in local mode, directly on SageMaker Notebook instances, Amazon EC2 instances, or your own local devices, use the smd.Hook class to create a hook. However, this approach can only store the tensor collections and make them usable for TensorBoard visualization. SageMaker Debugger’s built-in Rules don’t work with local mode because the Rules require SageMaker ML training instances and Amazon S3 to store outputs from the remote instances in real time. The smd.get_hook API returns None in this case.

    If you want to create a manual hook to save tensors in local mode, use the following code snippet, which checks whether the smd.get_hook API returns None and, if so, creates a manual hook using the smd.Hook class. Note that you can specify any output directory on your local machine.

    import smdebug.pytorch as smd
    hook = smd.get_hook(create_if_not_exists=True)
    if hook is None:
        hook = smd.Hook(
            out_dir='/path/to/your/local/output/',
            export_tensorboard=True
        )
  2. Wrap your model with the hook’s class methods.

    The hook.register_module() method takes your model and iterates through each layer, looking for any tensors that match the regular expressions that you provide through the configuration in Step 2: Launch and Debug Training Jobs Using SageMaker Python SDK. The tensors that you can collect through this hook method are weights, biases, activations, gradients, inputs, and outputs.

    hook.register_module(model)
    Tip

    If you collect the entire set of output tensors from a large deep learning model, the total size of those collections can grow exponentially and might cause bottlenecks. If you want to save only specific tensors, you can also use the hook.save_tensor() method. This method helps you pick the variable for the specific tensor and save it to a custom collection with a name of your choice. For more information, see step 7 of these instructions.

  3. Wrap the loss function with the hook’s class methods.

    Use the hook.register_loss method to wrap the loss function. It extracts loss values at every save_interval that you set during configuration in Step 2: Launch and Debug Training Jobs Using SageMaker Python SDK, and saves them to the "losses" collection.

    hook.register_loss(loss_function)
  4. Add hook.set_mode(ModeKeys.TRAIN) in the train block. This indicates that the tensors are collected during the training phase.

    def train():
        ...
        hook.set_mode(ModeKeys.TRAIN)
  5. Add hook.set_mode(ModeKeys.EVAL) in the validation block. This indicates that the tensors are collected during the validation phase.

    def validation():
        ...
        hook.set_mode(ModeKeys.EVAL)
  6. Use hook.save_scalar() to save custom scalars. You can save scalar values that aren’t in your model. For example, if you want to record the accuracy values computed during evaluation, add the following line of code below the line where you calculate accuracy.

    hook.save_scalar("accuracy", accuracy)

    Note that you need to provide a string as the first argument to name the custom scalar collection. This is the name that will be used for visualizing the scalar values in TensorBoard, and it can be any string you want.

  7. Use hook.save_tensor() to save custom tensors. Similarly to hook.save_scalar(), you can save additional tensors, defining your own tensor collection. For example, you can extract input image data that is passed into the model and save it as a custom tensor by adding the following line of code, where "images" is an example name for the custom tensor and image_inputs is an example variable for the input image data.

    hook.save_tensor("images", image_inputs)

    Note that you must provide a string as the first argument to name the custom tensor. hook.save_tensor() also has a third argument, collections_to_write, which specifies the tensor collection in which to save the custom tensor. The default is collections_to_write="default". If you don't explicitly specify the third argument, the custom tensor is saved to the "default" tensor collection. A consolidated example that brings steps 1 through 7 together in one training script follows this list.
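
The following is a minimal sketch that brings steps 1 through 7 together in a single training script. The model, loss function, optimizer, data loaders, and accuracy computation are simplified placeholders and not part of SageMaker Debugger; only the smdebug hook calls are the adaptations described on this page. The ModeKeys import path and the "custom_inputs" collection name are assumptions for illustration.

    import torch
    import torch.nn as nn
    import smdebug.pytorch as smd
    from smdebug.core.modes import ModeKeys  # assumed import path for ModeKeys

    # Step 1: create the hook, or retrieve the one configured by SageMaker.
    hook = smd.get_hook(create_if_not_exists=True)
    if hook is None:
        # Running outside SageMaker (local mode): create a manual hook.
        hook = smd.Hook(out_dir='/path/to/your/local/output/', export_tensorboard=True)

    # Placeholder model, loss function, and optimizer.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Steps 2 and 3: register the model and the loss function with the hook.
    hook.register_module(model)
    hook.register_loss(loss_function)

    def train(loader):
        # Step 4: indicate the training phase.
        hook.set_mode(ModeKeys.TRAIN)
        model.train()
        for image_inputs, labels in loader:
            optimizer.zero_grad()
            # Step 7: save the input images to a custom tensor collection
            # ("custom_inputs" is an example collection name).
            hook.save_tensor("images", image_inputs, collections_to_write="custom_inputs")
            outputs = model(image_inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

    def validation(loader):
        # Step 5: indicate the validation phase.
        hook.set_mode(ModeKeys.EVAL)
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for image_inputs, labels in loader:
                outputs = model(image_inputs)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)
        accuracy = correct / total
        # Step 6: save the accuracy computed during evaluation as a custom scalar.
        hook.save_scalar("accuracy", accuracy)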

After you have completed adapting your training script, proceed to Step 2: Launch and Debug Training Jobs Using SageMaker Python SDK.
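
As a preview of that step, the following is a minimal sketch, using assumed names and values, of a SageMaker Python SDK estimator that includes a DebuggerHookConfig so that the get_hook function in your adapted script finds a hook configuration on the training instance. The entry point, IAM role, instance type, and collection names are placeholders to replace with your own.

    from sagemaker.pytorch import PyTorch
    from sagemaker.debugger import DebuggerHookConfig, CollectionConfig

    estimator = PyTorch(
        entry_point="train.py",                # your adapted training script
        role="YourSageMakerExecutionRoleArn",  # placeholder IAM role
        instance_count=1,
        instance_type="ml.p3.2xlarge",         # example training instance type
        framework_version="1.12.0",
        py_version="py38",
        # Including debugger_hook_config is what makes get_hook() return a configured hook.
        debugger_hook_config=DebuggerHookConfig(
            collection_configs=[
                CollectionConfig(name="losses"),
                CollectionConfig(name="weights"),
            ]
        ),
    )
    estimator.fit()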