PyTorch Model Saving and Loading

Once a model is trained and performs well enough, you should save it. It is equally important to be able to load the saved model again so that you can use it to make inferences on new data. In this chapter of the PyTorch Tutorial, you will learn how to save a PyTorch model and how to load the saved model back.

Saving and Loading the Model

To save a model, call the torch.save() function and pass it the model and the path where you want to save it.

Note: By convention, PyTorch model files use the .pt or .pth extension.

Example

In this example, we save the trained mynet model to the file mynet_v1.pt at the given path.

import torch

# Save the whole model object as mynet_v1.pt at the given path
torch.save(mynet, './saved_models/mynet_v1.pt')
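One practical detail: torch.save() does not create the target folder, so the call above raises an error if ./saved_models does not exist yet. The sketch below creates the folder first. The small MyNeuralNetwork class is a hypothetical stand-in for the trained model, and a temporary directory replaces the current working directory so the example can run anywhere.

```python
import os
import tempfile

import torch
from torch import nn


# Hypothetical stand-in for the tutorial's trained MyNeuralNetwork class
class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


mynet = MyNeuralNetwork()

# A temporary base directory replaces the current working directory here
save_dir = os.path.join(tempfile.mkdtemp(), "saved_models")
# torch.save() does not create missing folders, so create the folder first
os.makedirs(save_dir, exist_ok=True)

model_path = os.path.join(save_dir, "mynet_v1.pt")
# Serialize the whole module object to the file
torch.save(mynet, model_path)
```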

To load a model, call the torch.load() function and pass it the path of the file from which you want to load the model.

Example

In this example, we load the model from the file we just created into a variable named mynet_saved.

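# Load the whole model from mynet_v1.pt as mynet_saved.
# weights_only=False is needed to unpickle a full model object in recent
# PyTorch versions; only load files from a source you trust.
mynet_saved = torch.load('./saved_models/mynet_v1.pt', weights_only=False)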
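The round trip can be checked end to end. This is a minimal sketch assuming PyTorch 1.13 or newer (where torch.load() accepts weights_only), a hypothetical single-layer MyNeuralNetwork class, and an in-memory io.BytesIO buffer standing in for the file on disk. Note that the class must still be defined when the model is loaded.

```python
import io

import torch
from torch import nn


# Hypothetical model class; it must be defined (or importable) when loading
class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


mynet = MyNeuralNetwork()

# An in-memory buffer stands in for './saved_models/mynet_v1.pt'
buffer = io.BytesIO()
torch.save(mynet, buffer)
buffer.seek(0)  # rewind so torch.load() reads from the start

# Unpickling a full nn.Module requires weights_only=False
mynet_saved = torch.load(buffer, weights_only=False)
mynet_saved.eval()  # switch layers such as dropout to inference behaviour

# The loaded copy should compute exactly the same outputs
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(mynet(x), mynet_saved(x))
```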

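This is one approach to saving and loading models in PyTorch. However, it is not recommended. torch.save() uses Python's pickle module, which stores only a reference to the model's class (its module and name), not the class code itself. If the file is loaded in another directory or project where that class cannot be imported from the same location, loading fails. This reduces the flexibility with which you can use your model.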
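The root cause is visible inside the saved file. Assuming PyTorch 1.6 or newer, where torch.save() writes a zip archive by default, the sketch below (again using a hypothetical MyNeuralNetwork class) reads the pickled object graph out of that archive and checks that it contains the class name but not the class's code.

```python
import io
import zipfile

import torch
from torch import nn


# Hypothetical model class used only for this illustration
class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


buffer = io.BytesIO()
torch.save(MyNeuralNetwork(), buffer)

# The archive holds the pickled object graph in an entry ending in 'data.pkl'
with zipfile.ZipFile(io.BytesIO(buffer.getvalue())) as archive:
    pkl_name = next(n for n in archive.namelist() if n.endswith("data.pkl"))
    pickled = archive.read(pkl_name)

# The pickle records the class by name ...
stores_class_name = b"MyNeuralNetwork" in pickled
# ... but it does not contain the class's source code
stores_source = b"def forward" in pickled
```

Because only the name is recorded, torch.load() has to import MyNeuralNetwork from the same module path that was used when the model was saved.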

You can overcome this hurdle with a simple workaround: store only the model parameters as a dictionary in the file, and later load them into a newly created instance of the model.


Saving and Loading the Model Parameters

To save only the parameters of the model, call torch.save() and pass it the model's parameters and the path where you want to save them. You obtain the parameters with the model's state_dict() method, which returns the state of the model as a dictionary that maps the name of each learnable parameter and persistent buffer (for example, the running statistics of a batch-normalization layer) to its tensor.
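To see what state_dict() actually contains, the sketch below builds a hypothetical MyNeuralNetwork with one linear layer and one batch-normalization layer and lists each entry with its shape.

```python
import torch
from torch import nn


# Hypothetical model with one linear layer and one batch-norm layer
class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.bn = nn.BatchNorm1d(2)

    def forward(self, x):
        return self.bn(self.fc(x))


mynet = MyNeuralNetwork()
state = mynet.state_dict()

# Each key is a dotted attribute path; each value is a tensor
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

The running_mean, running_var and num_batches_tracked entries are buffers rather than trainable parameters; they are saved because the batch-normalization layer needs them at inference time.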

Example

In this example, we save the parameters of the mynet model to the file mynet_v1_params.pt at the given path.

# Save the model parameters as mynet_v1_params.pt at the given path
torch.save(mynet.state_dict(), './saved_models/mynet_v1_params.pt')
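Because a file saved this way holds only named tensors, it can be read back as a plain dictionary with torch.load(..., weights_only=True), which refuses to unpickle arbitrary objects; reading the dictionary itself does not need the model class. A sketch, assuming PyTorch 1.13 or newer, a hypothetical single-layer MyNeuralNetwork, and a temporary directory in place of ./saved_models:

```python
import os
import tempfile

import torch
from torch import nn


# Hypothetical stand-in for the tutorial's MyNeuralNetwork class
class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


mynet = MyNeuralNetwork()
params_path = os.path.join(tempfile.mkdtemp(), "mynet_v1_params.pt")

# Save only the state dictionary (parameter name -> tensor)
torch.save(mynet.state_dict(), params_path)

# Read the file back as a plain dictionary; weights_only=True restricts
# unpickling to tensors and basic containers
loaded_state = torch.load(params_path, weights_only=True)
```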

To load the model, first create an instance of the model class. Then read the parameters from the file with torch.load() and pass the resulting dictionary to the model's load_state_dict() method.

Example

In this example, we first instantiate the MyNeuralNetwork class to create a model, mynet. Then we load the parameters from the file into this new instance using the load_state_dict() method.

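import torch

# Create an instance of the model (MyNeuralNetwork must be defined or imported)
mynet = MyNeuralNetwork()

# Load the parameters from the file into the model
mynet.load_state_dict(torch.load('./saved_models/mynet_v1_params.pt', weights_only=True))

# Switch to evaluation mode before making inferences
mynet.eval()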
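A complete round trip, as a minimal sketch: it assumes PyTorch 1.13 or newer, a hypothetical MyNeuralNetwork with a dropout layer (to show why eval() matters), and an io.BytesIO buffer in place of the file. A second, freshly initialized instance receives the saved parameters and then produces the same outputs as the original.

```python
import io

import torch
from torch import nn


# Hypothetical stand-in for the tutorial's MyNeuralNetwork class
class MyNeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        return self.dropout(self.fc(x))


trained = MyNeuralNetwork()
buffer = io.BytesIO()  # stands in for './saved_models/mynet_v1_params.pt'
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# A fresh instance starts with its own random weights ...
mynet = MyNeuralNetwork()
# ... until the saved parameters are copied into it
result = mynet.load_state_dict(torch.load(buffer, weights_only=True))

# Disable dropout on both models so their outputs are deterministic
mynet.eval()
trained.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    outputs_match = torch.equal(trained(x), mynet(x))
```

load_state_dict() is strict by default, so a key that is missing from or unexpected in the file raises a RuntimeError. When strict=False is passed instead, the returned missing_keys and unexpected_keys lists report the mismatches.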