fix rendering
This commit is contained in:
parent
2c8335f632
commit
c307383873
@ -13,7 +13,7 @@ class DmpWrapper(MPWrapper):
|
||||
final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = None,
|
||||
learn_goal: bool = False, return_to_start: bool = False, post_traj_time: float = 0.,
|
||||
weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
|
||||
policy_type: str = None):
|
||||
policy_type: str = None, render_mode: str = None):
|
||||
|
||||
"""
|
||||
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
|
||||
@ -45,7 +45,7 @@ class DmpWrapper(MPWrapper):
|
||||
self.t = np.linspace(0, duration, int(duration / dt))
|
||||
self.goal_scale = goal_scale
|
||||
|
||||
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale,
|
||||
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale, render_mode,
|
||||
num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase,
|
||||
bandwidth_factor=bandwidth_factor)
|
||||
|
||||
|
@ -38,7 +38,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
||||
|
||||
# rendering
|
||||
self.render_mode = render_mode
|
||||
self.render_kwargs = None
|
||||
self.render_kwargs = {}
|
||||
|
||||
# TODO: not yet final
|
||||
def __call__(self, params, contexts=None):
|
||||
|
Loading…
Reference in New Issue
Block a user