Exception when converting Unet from PyTorch to ONNX

I'm trying to convert a Unet model from PyTorch to ONNX.

Running the following code:

import torch
from unets import Unet, thin_setup

net = Unet(in_features=3, down=[16, 32, 64, 64, 64], up=[64, 64, 64, 128 + 1],
           setup={**thin_setup, 'bias': True, 'padding': True})
net.eval()

inputs = torch.randn((1, 3, 768, 768))
outputs = net(inputs)
torch.onnx.export(net, inputs, "unet.onnx", opset_version=12)

a RuntimeError is raised: Unsupported: ONNX export of instance_norm for unknown channel size.

How can I solve it?

Remark: I suspect this is caused by a node of an upsample layer that has no output shape: %196 : Float(*, *, *, *, strides=[589824, 9216, 96, 1], requires_grad=1, device=cpu) = onnx::Resize[coordinate_transformation_mode="pytorch_half_pixel", cubic_coeff_a=-0.75, mode="linear", nearest_mode="floor"](%169, %194, %195, %193) # ~/miniconda/envs/my_env/lib/python3.7/site-packages/torch/nn/functional.py:3709:0

Environment: Python 3.7 / torch 1.9.1+cu102 / onnx 1.10.2



Solution 1:[1]

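The problem is that the PyTorch ONNX exporter cannot export the 2D Instance Normalization layer (nn.InstanceNorm2d) when the channel size of its input is unknown at export time, as the error message says. The solution was to copy the relevant UNet code and implement the layer myself with elementary operations (mean and variance):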

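from torch import nn, Tensor


class InstanceNormAlternative(nn.InstanceNorm2d):
    """nn.InstanceNorm2d rewritten with mean/var so that it exports to ONNX."""

    def forward(self, inp: Tensor) -> Tensor:
        self._check_input_dim(inp)
        # Per-sample, per-channel statistics over the spatial dims (H, W),
        # using the biased variance as nn.InstanceNorm2d does.
        # Note: the affine weight/bias (if affine=True) are not applied here.
        desc = 1 / (inp.var(dim=(2, 3), keepdim=True, unbiased=False) + self.eps) ** 0.5
        retval = (inp - inp.mean(dim=(2, 3), keepdim=True)) * desc
        return retval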
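A quick way to confirm that the replacement behaves like the original layer is to run both on the same random input and compare. The sketch below repeats the layer so it runs on its own; max_abs_diff is a hypothetical helper, and the 64 channels at 96×96 are only an assumption taken from the strides in the Resize trace in the question.

```python
import torch
from torch import nn, Tensor


class InstanceNormAlternative(nn.InstanceNorm2d):
    """Same formula as the layer above (biased variance, no affine params)."""

    def forward(self, inp: Tensor) -> Tensor:
        self._check_input_dim(inp)
        desc = 1 / (inp.var(dim=(2, 3), keepdim=True, unbiased=False) + self.eps) ** 0.5
        return (inp - inp.mean(dim=(2, 3), keepdim=True)) * desc


def max_abs_diff(num_channels: int = 64, size: int = 96, seed: int = 0) -> float:
    """Largest elementwise gap between nn.InstanceNorm2d and the alternative."""
    torch.manual_seed(seed)
    x = torch.randn(1, num_channels, size, size)
    reference = nn.InstanceNorm2d(num_channels).eval()  # affine=False by default
    alternative = InstanceNormAlternative(num_channels).eval()
    with torch.no_grad():
        return (reference(x) - alternative(x)).abs().max().item()


print(max_abs_diff())  # expect a value at float32 rounding level, far below 1e-4
```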

Make sure to use the biased variance (unbiased=False, as above) if you wish to match PyTorch's nn.InstanceNorm2d as closely as possible.
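For reference, this is the normalization both layers compute for an N×C×H×W input x (without the optional affine scale and shift), with ε = eps. The 1/(HW) factor in σ² is the biased estimator, i.e. unbiased=False:

```latex
\mu_{nc} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{nchw},\qquad
\sigma^2_{nc} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W}\left(x_{nchw}-\mu_{nc}\right)^2,\qquad
y_{nchw} = \frac{x_{nchw}-\mu_{nc}}{\sqrt{\sigma^2_{nc}+\varepsilon}}
```

With unbiased=True, the variance would use 1/(HW−1) instead. For a 96×96 feature map (HW = 9216), this scales σ² by 9216/9215, a small but non-zero deviation from PyTorch.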

NOTE: CoreML tools cannot convert the variance operator from PyTorch to CoreML. Make sure to use PyTorch's nn.InstanceNorm2d layer (and not the above alternative) when converting to CoreML.
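Instead of keeping two copies of the UNet code, one option is to swap the layers on an already-built model, exporting one variant to ONNX and the other to CoreML. The sketch below assumes this approach. swap_instance_norm is a hypothetical helper, not part of PyTorch or the unets package. It works on a deep copy and refuses to swap layers with affine=True or track_running_stats=True, because the alternative layer ignores the affine parameters and the running statistics.

```python
import copy

import torch
from torch import nn, Tensor


class InstanceNormAlternative(nn.InstanceNorm2d):
    """ONNX-friendly InstanceNorm2d: same formula as the layer in the answer."""

    def forward(self, inp: Tensor) -> Tensor:
        self._check_input_dim(inp)
        desc = 1 / (inp.var(dim=(2, 3), keepdim=True, unbiased=False) + self.eps) ** 0.5
        return (inp - inp.mean(dim=(2, 3), keepdim=True)) * desc


def swap_instance_norm(model: nn.Module, to_alternative: bool) -> nn.Module:
    """Hypothetical helper: return a copy of `model` whose InstanceNorm2d layers
    are all InstanceNormAlternative (to_alternative=True, e.g. for ONNX) or all
    plain nn.InstanceNorm2d (to_alternative=False, e.g. for CoreML)."""
    model = copy.deepcopy(model)  # leave the caller's model untouched
    target_cls = InstanceNormAlternative if to_alternative else nn.InstanceNorm2d
    for parent in list(model.modules()):
        for name, child in list(parent.named_children()):
            if not isinstance(child, nn.InstanceNorm2d) or type(child) is target_cls:
                continue
            # The alternative uses per-instance statistics only and ignores
            # weight/bias, so refuse layers whose output it would silently change.
            if to_alternative and (child.affine or child.track_running_stats):
                raise ValueError(f"{name}: affine/running stats are not supported")
            new = target_cls(child.num_features, eps=child.eps, momentum=child.momentum,
                             affine=child.affine,
                             track_running_stats=child.track_running_stats)
            new.load_state_dict(child.state_dict())
            new.train(child.training)  # keep the original train/eval mode
            setattr(parent, name, new)
    return model


net = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.InstanceNorm2d(4), nn.ReLU()).eval()
onnx_net = swap_instance_norm(net, to_alternative=True)          # export this one to ONNX
coreml_net = swap_instance_norm(onnx_net, to_alternative=False)  # and this one to CoreML
print(type(onnx_net[1]).__name__, type(coreml_net[1]).__name__)
# InstanceNormAlternative InstanceNorm2d

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    print((net(x) - onnx_net(x)).abs().max().item())  # float32 rounding-level difference
```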

FREE TIP: If you convert PyTorch UNets to TF, you will also encounter the following error: RuntimeError: Resize coordinate_transformation_mode=pytorch_half_pixel is not supported in Tensorflow. The remedy is to set align_corners=True in the interpolation call inside TrivialUpsample.forward. In my experience, the change had only a minor effect on the network output.
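The code of TrivialUpsample is not shown here. Judging from the Resize node in the question (mode="linear"), it appears to use bilinear interpolation, so the sketch below assumes a 2x bilinear upsampling with torch.nn.functional.interpolate and shows the suggested change. AlignedUpsample is a hypothetical name. As far as I understand the exporter, align_corners=True should make it emit coordinate_transformation_mode="align_corners" instead of "pytorch_half_pixel". The usage below shows that the two modes sample slightly different positions, which is the source of the small output change mentioned above.

```python
import torch
import torch.nn.functional as F
from torch import nn, Tensor


class AlignedUpsample(nn.Module):
    """Hypothetical stand-in for TrivialUpsample: 2x bilinear upsampling with
    align_corners=True instead of the default align_corners=False."""

    def forward(self, x: Tensor) -> Tensor:
        return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)


x = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
aligned = AlignedUpsample()(x)
half_pixel = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)

print(aligned.shape)  # torch.Size([1, 1, 8, 8])
# With align_corners=True the output corners coincide with the input corners
# (up to float rounding); interior samples differ from the half-pixel mode.
print(aligned[0, 0, 0, 0].item(), aligned[0, 0, -1, -1].item())
print((aligned - half_pixel).abs().max().item())
```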

This answer was written with the help of Michał Tyszkiewicz.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Yohai Devir