How can we convert a .pth model into a .pb file?

I already have the complete model in PyTorch, but I want to convert the .pth file into a .pb file that can be used in TensorFlow. Does anyone have any ideas?



Solution 1:[1]

You can use ONNX (Open Neural Network Exchange format).

To convert a .pth file to .pb, first export the model defined in PyTorch to ONNX, then import the ONNX model into TensorFlow (PyTorch => ONNX => TensorFlow).

The following example, based on the MNIST model from onnx/tutorials, converts a PyTorch model to TensorFlow using ONNX.

Save the trained model to a file

import torch
torch.save(model.state_dict(), 'output/mnist.pth')  # save only the learned parameters

Load the trained model from file

trained_model = Net()  # Net is the MNIST model class defined in the tutorial
trained_model.load_state_dict(torch.load('output/mnist.pth'))
trained_model.eval()  # switch layers such as dropout to inference mode before exporting

# Export the trained model to ONNX
dummy_input = torch.randn(1, 1, 28, 28)  # one grayscale 28 x 28 image: (batch, channels, height, width)
torch.onnx.export(trained_model, dummy_input, "output/mnist.onnx")

Load the ONNX file

import onnx
from onnx_tf.backend import prepare  # provided by the onnx-tf package

model = onnx.load('output/mnist.onnx')

# Import the ONNX model to TensorFlow
tf_rep = prepare(model)

Save the Tensorflow model into a file

tf_rep.export_graph('output/mnist.pb')
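Depending on the onnx-tf version, export_graph may not produce a single .pb file. As far as I know, the releases built on TF 1.x write one serialized graph file, while the newer TF 2.x-based releases write a SavedModel directory (containing a saved_model.pb) at the given path, even if that path ends in .pb. The helper describe_export below is hypothetical, a stdlib-only sketch that inspects the output path so you can tell which of the two you got; it does not call onnx-tf itself.

```python
import os


def describe_export(path: str) -> str:
    """Classify what export_graph() left at `path` (hypothetical helper).

    Returns "frozen_graph" for a single .pb file, "saved_model" for a
    directory containing saved_model.pb, and "missing" otherwise.
    """
    if os.path.isfile(path) and path.endswith(".pb"):
        # TF 1.x-style output: one serialized graph file
        return "frozen_graph"
    if os.path.isdir(path) and os.path.isfile(os.path.join(path, "saved_model.pb")):
        # TF 2.x-style output: a SavedModel directory (its name may still end in .pb)
        return "saved_model"
    return "missing"


if __name__ == "__main__":
    print(describe_export("output/mnist.pb"))
```

If it reports "saved_model", load the result with TensorFlow's SavedModel loader rather than treating it as a frozen graph.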

As noted by @tsveti_iko in the comments:

NOTE: prepare() is built into onnx-tf, so you first need to install that package from the console:

pip install onnx-tf

then import it in your code:

import onnx
from onnx_tf.backend import prepare

After that you can use it as described in the answer.
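Because prepare() lives in a separate package (installed as onnx-tf, imported as onnx_tf), a missing install only shows up as an ImportError at the import line. The helper missing_modules below is hypothetical, a small stdlib sketch that checks up front which of the modules used in this answer can be found, via importlib.util.find_spec; the module list is an assumption based on the imports above.

```python
import importlib.util

# Top-level modules used in this answer (assumed list; note the pip name
# onnx-tf maps to the import name onnx_tf).
REQUIRED_MODULES = ["torch", "onnx", "onnx_tf", "tensorflow"]


def missing_modules(names):
    """Return the subset of top-level module `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules, install them first:", ", ".join(missing))
    else:
        print("All conversion dependencies are importable.")
```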

Solution 2:[2]

If you are using TF 1.15 or below, the code above may not work as-is, because you are likely to run into version-mismatch errors.
Here is a set of package versions that work together with TF 1.x:

Keras                2.3.0
Keras-Applications   1.0.8
Keras-Preprocessing  1.1.2
numpy                1.21.5
onnx                 1.8.0
onnx-tf              1.3.0
protobuf             3.19.4
tensorboard          1.15.0
tensorflow           1.15.0
tensorflow-estimator 1.15.1
torch                1.6.0+cpu
torchvision          0.7.0+cpu
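Since the point of this answer is that the versions must match, it can help to compare what is actually installed against the list above before running the conversion. The sketch below uses importlib.metadata.version (Python 3.8+; on Python 3.7, which TF 1.15 targets, the importlib_metadata backport offers the same functions). The helpers parse_pins and check_pins are hypothetical, and PINS is just a few lines copied from the list.

```python
try:
    from importlib.metadata import PackageNotFoundError, version
except ImportError:  # Python 3.7: pip install importlib_metadata
    from importlib_metadata import PackageNotFoundError, version

# A few pins copied from the list above; extend with the rest as needed.
PINS = """
onnx                 1.8.0
onnx-tf              1.3.0
tensorflow           1.15.0
torch                1.6.0+cpu
"""


def parse_pins(text):
    """Turn 'name   version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        parts = line.split()
        if len(parts) == 2:
            pins[parts[0]] = parts[1]
    return pins


def check_pins(pins, lookup=version):
    """Return {name: (wanted, installed)} for every pin that does not match.

    `installed` is None when the distribution is not installed at all.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for name, (wanted, installed) in check_pins(parse_pins(PINS)).items():
        print(f"{name}: want {wanted}, found {installed}")
```

An empty report means the pinned subset matches; anything printed is a candidate cause of version-mismatch errors.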

With all of these packages installed, follow the answer by Dishin.

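Note: Variable is deprecated in newer versions of PyTorch; since PyTorch 0.4 you can pass a plain tensor (e.g. torch.randn(1, 1, 28, 28)) as the dummy input.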

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1
Solution 2