fix rendering
This commit is contained in:
parent
2c8335f632
commit
c307383873
@ -13,7 +13,7 @@ class DmpWrapper(MPWrapper):
|
|||||||
final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = None,
|
final_pos: np.ndarray = None, duration: int = 1, alpha_phase: float = 2., dt: float = None,
|
||||||
learn_goal: bool = False, return_to_start: bool = False, post_traj_time: float = 0.,
|
learn_goal: bool = False, return_to_start: bool = False, post_traj_time: float = 0.,
|
||||||
weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
|
weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
|
||||||
policy_type: str = None):
|
policy_type: str = None, render_mode: str = None):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
|
This Wrapper generates a trajectory based on a DMP and will only return episodic performances.
|
||||||
@ -45,7 +45,7 @@ class DmpWrapper(MPWrapper):
|
|||||||
self.t = np.linspace(0, duration, int(duration / dt))
|
self.t = np.linspace(0, duration, int(duration / dt))
|
||||||
self.goal_scale = goal_scale
|
self.goal_scale = goal_scale
|
||||||
|
|
||||||
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale,
|
super().__init__(env, num_dof, duration, dt, post_traj_time, policy_type, weights_scale, render_mode,
|
||||||
num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase,
|
num_basis=num_basis, start_pos=start_pos, final_pos=final_pos, alpha_phase=alpha_phase,
|
||||||
bandwidth_factor=bandwidth_factor)
|
bandwidth_factor=bandwidth_factor)
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
|||||||
|
|
||||||
# rendering
|
# rendering
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
self.render_kwargs = None
|
self.render_kwargs = {}
|
||||||
|
|
||||||
# TODO: not yet final
|
# TODO: not yet final
|
||||||
def __call__(self, params, contexts=None):
|
def __call__(self, params, contexts=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user