diff --git a/sb3_trl/trl_pg/trl_pg.py b/sb3_trl/trl_pg/trl_pg.py index 34b2c83..d217f89 100644 --- a/sb3_trl/trl_pg/trl_pg.py +++ b/sb3_trl/trl_pg/trl_pg.py @@ -439,7 +439,7 @@ class TRL_PG(OnPolicyAlgorithm): # Rescale and perform action clipped_actions = actions # Clip the actions to avoid out of bound error - if isinstance(self.action_space, gym.spaces.Box): + if isinstance(self.action_space, spaces.Box): clipped_actions = np.clip( actions, self.action_space.low, self.action_space.high) @@ -455,7 +455,7 @@ class TRL_PG(OnPolicyAlgorithm): self._update_info_buffer(infos) n_steps += 1 - if isinstance(self.action_space, gym.spaces.Discrete): + if isinstance(self.action_space, spaces.Discrete): # Reshape in case of discrete action actions = actions.reshape(-1, 1)