fix rendering

This commit is contained in:
Maximilian Huettenrauch 2021-04-30 16:22:33 +02:00
parent 2c8335f632
commit c307383873
2 changed files with 3 additions and 3 deletions

View File

@ -13,7 +13,7 @@ class DmpWrapper(MPWrapper):
final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = None,
learn_goal: bool = False, return_to_start: bool = False, post_traj_time: float = 0.,
weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
policy_type: str = None):
policy_type: str = None, render_mode: str = None):
"""
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
@ -45,7 +45,7 @@ class DmpWrapper(MPWrapper):
self.t = np.linspace(0, duration, int(duration / dt))
self.goal_scale = goal_scale
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale,
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale, render_mode,
num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase,
bandwidth_factor=bandwidth_factor)

View File

@ -38,7 +38,7 @@ class MPWrapper(gym.Wrapper, ABC):
# rendering
self.render_mode = render_mode
self.render_kwargs = None
self.render_kwargs = {}
# TODO: not yet final
def __call__(self, params, contexts=None):