How to get the input size for an operator in a PyTorch script model?

I use this code to convert the model to a script model:

scripted_model = torch.jit.trace(detector.model, images).eval()

Then I print scripted_model. Part of the output is as follows:

 (base): DLA(
    original_name=DLA
    (base_layer): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level0): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level1): Sequential(
      original_name=Sequential
      (0): Conv2d(original_name=Conv2d)
      (1): BatchNorm2d(original_name=BatchNorm2d)
      (2): ReLU(original_name=ReLU)
    )
    (level2): Tree(
      original_name=Tree
      (tree1): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (tree2): BasicBlock(
        original_name=BasicBlock
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
      )
      (root): Root(
        original_name=Root
        (conv): Conv2d(original_name=Conv2d)
        (bn): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
      )
      (downsample): MaxPool2d(original_name=MaxPool2d)
      (project): Sequential(
        original_name=Sequential
        (0): Conv2d(original_name=Conv2d)
        (1): BatchNorm2d(original_name=BatchNorm2d)
      )
    ) 
...

I want to get the input size for an operator, for example how many inputs the operator (0): Conv2d(original_name=Conv2d) takes. I printed the graph of this script model, and the output is as follows:

  %4770 : __torch__.torch.nn.modules.module.___torch_mangle_11.Module = prim::GetAttr[name="wh"](%self.1)
  %4762 : __torch__.torch.nn.modules.module.___torch_mangle_15.Module = prim::GetAttr[name="tracking"](%self.1)
  %4754 : __torch__.torch.nn.modules.module.___torch_mangle_23.Module = prim::GetAttr[name="rot"](%self.1)
  %4746 : __torch__.torch.nn.modules.module.___torch_mangle_7.Module = prim::GetAttr[name="reg"](%self.1)
  %4738 : __torch__.torch.nn.modules.module.___torch_mangle_3.Module = prim::GetAttr[name="hm"](%self.1)
  %4730 : __torch__.torch.nn.modules.module.___torch_mangle_27.Module = prim::GetAttr[name="dim"](%self.1)
  %4722 : __torch__.torch.nn.modules.module.___torch_mangle_19.Module = prim::GetAttr[name="dep"](%self.1)
  %4714 : __torch__.torch.nn.modules.module.___torch_mangle_31.Module = prim::GetAttr[name="amodel_offset"](%self.1)
  %4706 : __torch__.torch.nn.modules.module.___torch_mangle_289.Module = prim::GetAttr[name="ida_up"](%self.1)
  %4645 : __torch__.torch.nn.modules.module.___torch_mangle_262.Module = prim::GetAttr[name="dla_up"](%self.1)
  %4461 : __torch__.torch.nn.modules.module.___torch_mangle_180.Module = prim::GetAttr[name="base"](%self.1)
  %5100 : (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4461, %input.1)
  %5082 : Tensor, %5083 : Tensor, %5084 : Tensor, %5085 : Tensor, %5086 : Tensor, %5087 : Tensor, %5088 : Tensor, %5089 : Tensor = prim::TupleUnpack(%5100)
  %5101 : (Tensor, Tensor, Tensor) = prim::CallMethod[name="forward"](%4645, %5082, %5083, %5084, %5085, %5086, %5087, %5088, %5089)
  %5097 : Tensor, %5098 : Tensor, %5099 : Tensor = prim::TupleUnpack(%5101)
  %3158 : None = prim::Constant()

I can't even find the operator name in this graph. How can I get the input size for a specific operator in the script model?
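
One possible reason the operator names are missing is that the top-level graph only shows prim::CallMethod calls into submodules. A minimal sketch of listing the underlying operators and how many inputs each takes, using the inlined_graph property of the traced module. The small Sequential model and the random images tensor are hypothetical stand-ins for detector.model and the real input; the operator names printed may differ between PyTorch versions.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for detector.model; any nn.Module is traced the same way.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()
images = torch.randn(1, 3, 16, 16)  # assumed example input

traced = torch.jit.trace(model, images)

# inlined_graph expands the prim::CallMethod calls into the operators they run.
ops = [
    (node.kind(), len(list(node.inputs())))
    for node in traced.inlined_graph.nodes()
    if node.kind().startswith("aten::")
]

# Print each operator kind with its number of graph-level inputs.
for kind, n_inputs in ops:
    print(kind, n_inputs)
```

Note that the input count here is the number of arguments of the low-level operator (weights, bias, stride and so on are included), not the number of tensors passed to the module's forward.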

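If what is needed is the number and shape of the tensors each layer receives, one alternative (not confirmed by the original post) is to record them on the original eager model before tracing, using forward pre-hooks. This is a sketch under the same assumptions as before: the model and images below are hypothetical stand-ins, and the hooks are attached to leaf modules only.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for detector.model (the eager model, before tracing).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()
images = torch.randn(1, 3, 16, 16)  # assumed example input

records = []  # (module name, class name, number of inputs, input shapes)

def make_hook(name):
    # A forward pre-hook receives the module and the tuple of positional inputs.
    def hook(module, inputs):
        shapes = [tuple(x.shape) for x in inputs if isinstance(x, torch.Tensor)]
        records.append((name, type(module).__name__, len(inputs), shapes))
    return hook

# Attach hooks to leaf modules only (modules without children).
handles = [
    module.register_forward_pre_hook(make_hook(name))
    for name, module in model.named_modules()
    if len(list(module.children())) == 0
]

with torch.no_grad():
    model(images)  # one forward pass fills `records`

for handle in handles:
    handle.remove()  # clean up so tracing is not affected by the hooks

for record in records:
    print(record)
```

The module names recorded this way ("0", "1", "2" here, or "base.base_layer.0" in a nested model) match the names shown when printing the scripted model, so the shapes can be looked up per operator.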


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
