diff --git a/alr_envs/utils/wrapper/mp_wrapper.py b/alr_envs/utils/wrapper/mp_wrapper.py index e9ab5f1..3d8e86a 100644 --- a/alr_envs/utils/wrapper/mp_wrapper.py +++ b/alr_envs/utils/wrapper/mp_wrapper.py @@ -17,6 +17,7 @@ class MPWrapper(gym.Wrapper, ABC): post_traj_time: float = 0., policy_type: str = None, weights_scale: float = 1., + render_mode: str = None, **mp_kwargs ): super().__init__(env) @@ -36,7 +37,7 @@ class MPWrapper(gym.Wrapper, ABC): self.policy = policy_class(env) # rendering - self.render_mode = None + self.render_mode = render_mode self.render_kwargs = None # TODO: not yet final