diff --git a/alr_envs/utils/wrapper/dmp_wrapper.py b/alr_envs/utils/wrapper/dmp_wrapper.py index d6c8f92..8f94227 100644 --- a/alr_envs/utils/wrapper/dmp_wrapper.py +++ b/alr_envs/utils/wrapper/dmp_wrapper.py @@ -13,7 +13,7 @@ class DmpWrapper(MPWrapper): final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = None, learn_goal: bool = False, return_to_start: bool = False, post_traj_time: float = 0., weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3., - policy_type: str = None): + policy_type: str = None, render_mode: str = None): """ This Wrapper generates a trajectory based on a DMP and will only return episodic performances. @@ -45,7 +45,7 @@ class DmpWrapper(MPWrapper): self.t = np.linspace(0, duration, int(duration / dt)) self.goal_scale = goal_scale - super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale, + super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale, render_mode, num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase, bandwidth_factor=bandwidth_factor) diff --git a/alr_envs/utils/wrapper/mp_wrapper.py b/alr_envs/utils/wrapper/mp_wrapper.py index 3d8e86a..34f0440 100644 --- a/alr_envs/utils/wrapper/mp_wrapper.py +++ b/alr_envs/utils/wrapper/mp_wrapper.py @@ -38,7 +38,7 @@ class MPWrapper(gym.Wrapper, ABC): # rendering self.render_mode = render_mode - self.render_kwargs = None + self.render_kwargs = {} # TODO: not yet final def __call__(self, params, contexts=None):