'Type error while tracing for exporting pytorch RL agents to ONNX

While trying to export the trained agent network in onnx format using inbuilt tracing function torch.export, facing the following error.

--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) /tmp/ipykernel_19027/3297995674.py in 38 file_path = os.path.join('tmp/onnx', agent+'_exported.onnx') 39 print(torch.flatten(torch.tensor(observations[agent]))) ---> 40 torch.onnx.export(agents[agent], tuple(observations[agent]), file_path, verbose=True)

~/.local/lib/python3.8/site-packages/torch/onnx/init.py in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format) 314 315 from torch.onnx import utils --> 316 return utils.export(model, args, f, export_params, verbose, training, 317 input_names, output_names, operator_export_type, opset_version, 318 _retain_param_name, do_constant_folding, example_outputs,

~/.local/lib/python3.8/site-packages/torch/onnx/utils.py in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, _retain_param_name, do_constant_folding, example_outputs, strip_doc_string, dynamic_axes, keep_initializers_as_inputs, custom_opsets, enable_onnx_checker, use_external_data_format) 105 "Otherwise set to False because of size limits imposed by Protocol Buffers.") 106 --> 107 _export(model, args, f, export_params, verbose, training, input_names, output_names, 108 operator_export_type=operator_export_type, opset_version=opset_version, 109 do_constant_folding=do_constant_folding, example_outputs=example_outputs,

~/.local/lib/python3.8/site-packages/torch/onnx/utils.py in _export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, fixed_batch_size, custom_opsets, add_node_names, use_external_data_format, onnx_shape_inference) 722 723 graph, params_dict, torch_out =
--> 724 _model_to_graph(model, args, verbose, input_names, 725 output_names, operator_export_type, 726 example_outputs, val_do_constant_folding,

~/.local/lib/python3.8/site-packages/torch/onnx/utils.py in _model_to_graph(model, args, verbose, input_names, output_names, operator_export_type, example_outputs, do_constant_folding, _disable_torch_constant_prop, fixed_batch_size, training, dynamic_axes) 491 args = (args, ) 492 --> 493 graph, params, torch_out, module = _create_jit_graph(model, args) 494 495 params_dict = _get_named_param_dict(graph, params)

~/.local/lib/python3.8/site-packages/torch/onnx/utils.py in _create_jit_graph(model, args) 435 return graph, params, torch_out, None 436 else: --> 437 graph, torch_out = _trace_and_get_graph_from_model(model, args) 438 state_dict = _unique_state_dict(model) 439 params = list(state_dict.values())

~/.local/lib/python3.8/site-packages/torch/onnx/utils.py in _trace_and_get_graph_from_model(model, args) 386 387 trace_graph, torch_out, inputs_states =
--> 388 torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True) 389 warn_on_static_input_change(inputs_states) 390

~/.local/lib/python3.8/site-packages/torch/jit/_trace.py in _get_trace_graph(f, args, kwargs, strict, _force_outplace, return_inputs, _return_inputs_states) 1164 if not isinstance(args, tuple): 1165 args = (args,) -> 1166 outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs) 1167
return outs

~/.local/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1101
or _global_forward_hooks or _global_forward_pre_hooks): -> 1102 return forward_call(*input, **kwargs) 1103 # Do not call functions when jit is used 1104 full_backward_hooks, non_full_backward_hooks = [], []

~/.local/lib/python3.8/site-packages/torch/jit/_trace.py in forward(self, *args) 93 94 def forward(self, *args: torch.Tensor): ---> 95 in_vars, in_desc = _flatten(args) 96 # NOTE: use full state, because we need it for BatchNorm export 97 # This differs from the compiler path, which doesn't support it at the moment.

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: numpy.float32

code:

import torch.onnx

from ddpg_agent import Agent

from ACVEnv import *

import os
attacker_agent = Agent(agent_name = 'attacker',alpha=0.0001, beta=0.001, 
            input_dims=a_env.observation_space('attacker').shape, tau=0.001,
            batch_size=64, fc1_dims=400, fc2_dims=300, 
            n_actions=a_env.action_space('attacker').shape[0])

detector_agent = Agent(agent_name='detector' ,
                alpha=0.0001, beta=0.001, 
                input_dims=a_env.observation_space('detector').shape, tau=0.001,
                batch_size=64, fc1_dims=400, fc2_dims=300, 
                n_actions=a_env.action_space('detector').shape[0])

# attacker_agent(a_env).load_models()
attacker_agent.load_models()
detector_agent.load_models()

# attacker_agent(a_env).eval()
attacker_agent.eval()
detector_agent.eval()

a_env.reset()

observations = {'attacker':a_env._observation_spaces["attacker"].sample(), 'detector':a_env._observation_spaces["detector"].sample()}

agent_names = ["attacker", "detector"]

agents = {'attacker':attacker_agent, 'detector':detector_agent}

for agent in agent_names:
    
    file_path = os.path.join('tmp/onnx', agent+'_exported.onnx')
    print(torch.flatten(torch.tensor(observations[agent])))
    torch.onnx.export(agents[agent],  tuple(observations[agent]),  file_path,  verbose=True)

My PyTorch version is 1.10.0+cpu. The corresponding issue in pytorch forum is closed but found no solid solution! https://github.com/pytorch/pytorch/issues/29551



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source