Stable_Baselines: How to transfer trained reinforcement learning models between devices?

I've been training a reinforcement learning model in a Kaggle notebook with the GPU accelerator, since I only have a CPU laptop to work with. However, when I try to load the model in a local Jupyter notebook in order to render it, I get the following error:

AssertionError                            Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_13752/2941311152.py in <module>
      2 # device = torch.device('cpu')
      3 # model = PPO('CnnPolicy', env, verbose = 1, tensorboard_log = Log_Dir, learning_rate = 0.000001, n_steps = 512)
----> 4 PPO.load('./Training/Saved Models/Mario Models/best_model_100000', env)

c:\users\test\appdata\local\programs\python\python39\lib\site-packages\stable_baselines3\common\base_class.py in load(cls, path, env, device, custom_objects, print_system_info, force_reset, **kwargs)
    728         model.__dict__.update(data)
    729         model.__dict__.update(kwargs)
--> 730         model._setup_model()
    731 
    732         # put state_dicts back in place

c:\users\test\appdata\local\programs\python\python39\lib\site-packages\stable_baselines3\ppo\ppo.py in _setup_model(self)
    156 
    157         # Initialize schedules for policy/value clipping
--> 158         self.clip_range = get_schedule_fn(self.clip_range)
    159         if self.clip_range_vf is not None:
    160             if isinstance(self.clip_range_vf, (float, int)):

c:\users\test\appdata\local\programs\python\python39\lib\site-packages\stable_baselines3\common\utils.py in get_schedule_fn(value_schedule)
     89         value_schedule = constant_fn(float(value_schedule))
     90     else:
---> 91         assert callable(value_schedule)
     92     return value_schedule
     93 

AssertionError: 
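The assertion is raised by the schedule check at the bottom of the traceback. A number is wrapped into a constant schedule, and anything else must already be callable. The sketch below is a simplified re-implementation of that logic for illustration, mirroring the `constant_fn` and `get_schedule_fn` names from the traceback; it is not the library's exact source. It assumes the saved `clip_range` could not be deserialized locally and was left as a non-callable placeholder (a plain dict stands in for it here, which is hypothetical).

```python
from typing import Callable, Union

Schedule = Callable[[float], float]


def constant_fn(val: float) -> Schedule:
    """Return a schedule that ignores training progress and always yields val."""
    def func(_progress_remaining: float) -> float:
        return val
    return func


def get_schedule_fn(value_schedule: Union[Schedule, float, int]) -> Schedule:
    """Simplified mirror of the check shown in the traceback."""
    if isinstance(value_schedule, (float, int)):
        # Numbers are turned into a constant schedule.
        value_schedule = constant_fn(float(value_schedule))
    else:
        # Anything else must already be a function of the remaining progress.
        assert callable(value_schedule)
    return value_schedule


# A float or a function passes the check.
print(get_schedule_fn(0.2)(1.0))                       # 0.2
print(get_schedule_fn(lambda progress: 0.2 * progress)(1.0))  # 0.2

# Hypothetical stand-in for a clip_range that failed to unpickle: not callable.
broken_clip_range = {"placeholder": "could not be deserialized"}
try:
    get_schedule_fn(broken_clip_range)
    reproduced = False
except AssertionError:
    reproduced = True
print("AssertionError reproduced:", reproduced)
```

So the device itself is not what fails here; the error points at a saved hyperparameter that did not come back as a number or a function.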

I tried looking through the documentation to fix the error, but the guidelines for transferring a standard PyTorch model between devices do not work for a Stable Baselines model.
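A Stable Baselines3 checkpoint stores serialized Python objects such as `clip_range` and `lr_schedule` alongside the network weights, so mapping tensors to the CPU alone does not cover it. The `load` signature in the traceback already accepts `device` and `custom_objects`; the sketch below uses them to map the weights to the CPU and to supply fresh schedule functions instead of unpickling the saved ones. It assumes the failure comes from those functions not deserializing on the local Python 3.9 (for example, if the Kaggle kernel ran a different Python version); `print_system_info=True` prints the versions saved with the model so that can be checked. The helper names are hypothetical, and the override values (`1e-6` from the commented-out `PPO(...)` call, `0.2` as PPO's default `clip_range`) are assumptions to adjust to the values actually used on Kaggle.

```python
from typing import Any, Callable, Dict

Schedule = Callable[[float], float]


def constant_schedule(value: float) -> Schedule:
    """Schedule that ignores training progress and always returns value."""
    def schedule(_progress_remaining: float) -> float:
        return value
    return schedule


def make_custom_objects(learning_rate: float = 1e-6,
                        clip_range: float = 0.2) -> Dict[str, Any]:
    """Hypothetical helper: replacements for entries that may not unpickle.

    Assumed defaults: 1e-6 matches the commented-out PPO(...) call,
    0.2 is PPO's default clip_range.
    """
    return {
        "learning_rate": learning_rate,
        "lr_schedule": constant_schedule(learning_rate),
        "clip_range": constant_schedule(clip_range),
    }


def load_ppo_on_cpu(path: str, env=None):
    """Hypothetical helper: load a GPU-trained PPO checkpoint on a CPU-only machine."""
    from stable_baselines3 import PPO  # imported lazily; needs stable-baselines3

    return PPO.load(
        path,
        env=env,
        device="cpu",                          # map GPU tensors to the CPU
        custom_objects=make_custom_objects(),  # use these instead of unpickling
        print_system_info=True,                # show saved vs. local versions
    )


overrides = make_custom_objects()
print(sorted(overrides))  # ['clip_range', 'learning_rate', 'lr_schedule']
print(overrides["clip_range"](1.0), overrides["lr_schedule"](0.5))  # 0.2 1e-06
```

In the notebook, this would be called as `model = load_ppo_on_cpu('./Training/Saved Models/Mario Models/best_model_100000', env)`. Replacing the schedules is harmless for rendering or evaluation, since they are only used during training; if training is continued locally, pass the learning rate and clip range that were actually used on Kaggle.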



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
