Can't trace the model using torch.jit.trace

I can't trace my model using torch.jit.trace. It is a ResNet-101-based segmentation model. I am using Python 3.7, PyTorch 1.8, and an RTX 3070 (8 GB). My code:

Net = FCN.Net(CatDic.CatNum)
Net.load_state_dict(torch.load('./model.torch', map_location=torch.device('cuda')), strict=False)
Net.eval()

c = torch.jit.trace(Net, torch.randn(1, 640, 640, 3).cuda())
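
The example input is NHWC on purpose, since forward transposes it to NCHW internally. A quick sanity check of the layout (just a sketch):

x = torch.randn(1, 640, 640, 3)
print(x.transpose(2, 3).transpose(1, 2).shape)  # torch.Size([1, 3, 640, 640])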

My neural network structure:

import numpy as np
import torch
import torch.nn as nn
from torchvision import models


class Net(nn.Module):
    def __init__(self, CatDict):
        super(Net, self).__init__()
        self.Encoder = models.resnet101(pretrained=True)
        self.PSPScales = [1, 1 / 2, 1 / 4, 1 / 8]

        self.PSPLayers = nn.ModuleList()
        for Ps in self.PSPScales:
            self.PSPLayers.append(nn.Sequential(
                nn.Conv2d(2048, 1024, stride=1, kernel_size=3, padding=1, bias=True)))
        self.PSPSqueeze = nn.Sequential(
            nn.Conv2d(4096, 512, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, stride=1, kernel_size=3, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

        self.SkipConnections = nn.ModuleList()
        self.SkipConnections.append(nn.Sequential(
            nn.Conv2d(1024, 512, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU()))
        self.SkipConnections.append(nn.Sequential(
            nn.Conv2d(512, 256, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()))
        self.SkipConnections.append(nn.Sequential(
            nn.Conv2d(256, 256, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()))
        # Skip squeeze, applied to the concatenation of the upsampled features and the skip-connection layers
        self.SqueezeUpsample = nn.ModuleList()
        self.SqueezeUpsample.append(nn.Sequential(
            nn.Conv2d(1024, 512, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU()))
        self.SqueezeUpsample.append(nn.Sequential(
            nn.Conv2d(256 + 512, 256, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()))
        self.SqueezeUpsample.append(nn.Sequential(
            nn.Conv2d(256 + 256, 256, stride=1, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU()))

        self.OutLayersList = nn.ModuleList()
        self.OutLayersDict = {}
        for f, nm in enumerate(CatDict):
            self.OutLayersDict[nm] = nn.Conv2d(256, 2, stride=1, kernel_size=3, padding=1, bias=False)
            self.OutLayersList.append(self.OutLayersDict[nm])

    def forward(self, Images, UseGPU=True, TrainMode=False, FreezeBatchNormStatistics=False):
        RGBMean = [123.68, 116.779, 103.939]
        RGBStd = [65, 65, 65]
        if TrainMode:
            tp = torch.FloatTensor
        else:
            self.half()
            tp = torch.HalfTensor
            self.eval()
        # InpImages = torch.autograd.Variable(torch.from_numpy(Images), requires_grad=False).transpose(2, 3).transpose(1, 2).type(torch.FloatTensor)

        InpImages = torch.autograd.Variable(Images, requires_grad=False).transpose(2, 3).transpose(1, 2).type(tp)
        if FreezeBatchNormStatistics:
            self.eval()
        if UseGPU:
            InpImages = InpImages.cuda()
            self.cuda()
        else:
            self = self.cpu()
            self.float()
            InpImages = InpImages.type(torch.float).cpu()
        for i in range(len(RGBMean)):  # normalize image values
            InpImages[:, i, :, :] = (InpImages[:, i, :, :] - RGBMean[i]) / RGBStd[i]
        x = InpImages
        SkipConFeatures = []  # store the feature maps of the layers used for skip connections
        x = self.Encoder.conv1(x)
        x = self.Encoder.bn1(x)
        x = self.Encoder.relu(x)
        x = self.Encoder.maxpool(x)
        x = self.Encoder.layer1(x)
        SkipConFeatures.append(x)
        x = self.Encoder.layer2(x)
        SkipConFeatures.append(x)
        x = self.Encoder.layer3(x)
        SkipConFeatures.append(x)
        x = self.Encoder.layer4(x)
        PSPSize = (x.shape[2], x.shape[3])  # size of the original feature map

        PSPFeatures = []  # results of the various scales of processing
        for i, PSPLayer in enumerate(self.PSPLayers):  # PSP layers: rescale the feature map to various sizes, apply a convolution, and concatenate the results
            NewSize = (np.array(PSPSize) * self.PSPScales[i]).astype(int)
            if NewSize[0] < 1: NewSize[0] = 1
            if NewSize[1] < 1: NewSize[1] = 1

            y = nn.functional.interpolate(x, tuple(NewSize), mode='bilinear')
            y = PSPLayer(y)
            y = nn.functional.interpolate(y, PSPSize, mode='bilinear')
            PSPFeatures.append(y)
        x = torch.cat(PSPFeatures, dim=1)
        x = self.PSPSqueeze(x)
        for i in range(len(self.SkipConnections)):
            sp = (SkipConFeatures[-1 - i].shape[2], SkipConFeatures[-1 - i].shape[3])
            x = nn.functional.interpolate(x, size=sp, mode='bilinear')  # resize
            x = torch.cat((self.SkipConnections[i](SkipConFeatures[-1 - i]), x), dim=1)
            x = self.SqueezeUpsample[i](x)

        self.OutLbDict = {}

        ret_arr = np.eye(640, 640)
        for nm in self.OutLayersDict:
            l = self.OutLayersDict[nm](x)
            l = nn.functional.interpolate(l, size=InpImages.shape[2:4], mode='bilinear')  # resize to the original image size
            tt, Labels = l.max(1)  # find the label of each pixel
            self.OutLbDict[nm] = Labels
            array = np.asarray(self.OutLbDict[nm].cpu())
            resx = np.reshape(array, (array.shape[2], array.shape[1]))
            ret_arr = list(ret_arr + resx * 10)

        return ret_arr

I get an error:

RuntimeError: Tracer cannot infer type of [array([..])]
:Could not infer type of list element: Only tensors and (possibly nested) tuples of tensors, lists, or dicts are supported as inputs or outputs of traced functions, but instead got value of type ndarray.

If I remove all the numpy arrays from the code, I get a different error. The change was roughly the following (a sketch of the kind of edit I made, not my exact code):
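
ret = torch.eye(640, 640).cuda()  # torch replacement for the np.eye(640, 640) accumulator
for nm in self.OutLayersDict:
    l = self.OutLayersDict[nm](x)
    l = nn.functional.interpolate(l, size=InpImages.shape[2:4], mode='bilinear')
    tt, Labels = l.max(1)  # Labels stays a torch tensor; no .cpu() / np.asarray
    ret = ret + Labels[0].float() * 10
return ret

The new error: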

C:\anaconda3\lib\site-packages\torch\jit\_trace.py in _check_trace(check_inputs, func, traced_func, check_tolerance, strict, force_outplace, is_trace_module, _module_class)
    517         diag_info = graph_diagnostic_info()
    518         if any(info is not None for info in diag_info):
--> 519             raise TracingCheckError(*diag_info)
    520
    521

TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!
    Graph diff:
          graph(%self.1 : __torch__.FCN_NetModel.Net,
                %Images : Tensor):
            %2 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="SqueezeUpsample"](%self.1)
            %3 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="2"](%2)
            %4 : __torch__.torch.nn.modules.container.ModuleList = prim::GetAttr[name="SkipConnections"](%self.1)
            %5 : __torch__.torch.nn.modules.container.Sequential = prim::GetAttr[name="2"](%4)

........

The interesting thing is that if I run c = torch.jit.trace(Net, torch.randn(1, 640, 640, 3).cuda()) a second time, the last error does not occur and the tracing succeeds. But this traced model doesn't work.
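
My guess is that the second trace succeeds because forward mutates the module on the first call (it calls self.half(), self.eval() and self.cuda() internally), so by the second invocation the module state no longer changes between runs. I could suppress the sanity check with the check_trace flag (untested sketch):

c = torch.jit.trace(Net, torch.randn(1, 640, 640, 3).cuda(), check_trace=False)

But I assume that would only hide the graph mismatch rather than make the traced model usable. I would be grateful for any help.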



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
