Merge pull request #83 from ALRhub/78_fix_render_mode
Fix render_mode not working for many fancy envs
This commit is contained in:
commit
216a6f215d
@ -75,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.observation_space = self._get_observation_space()
|
self.observation_space = self._get_observation_space()
|
||||||
|
|
||||||
# rendering
|
# rendering
|
||||||
self.render_kwargs = {}
|
self.do_render = False
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
# condition value
|
# condition value
|
||||||
@ -164,7 +164,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
dtype=self.env.observation_space.dtype)
|
dtype=self.env.observation_space.dtype)
|
||||||
|
|
||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
terminated, truncated = False, False
|
||||||
|
|
||||||
if not traj_is_valid:
|
if not traj_is_valid:
|
||||||
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
|
obs, trajectory_return, terminated, truncated, infos = self.env.invalid_traj_callback(action, position, velocity,
|
||||||
@ -190,8 +190,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
elems[t] = v
|
elems[t] = v
|
||||||
infos[k] = elems
|
infos[k] = elems
|
||||||
|
|
||||||
if self.render_kwargs:
|
if self.do_render:
|
||||||
self.env.render(**self.render_kwargs)
|
self.env.render()
|
||||||
|
|
||||||
|
|
||||||
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
|
if terminated or truncated or (self.replanning_schedule(self.env.get_wrapper_attr('current_pos'), self.env.get_wrapper_attr('current_vel'), obs, c_action, t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times):
|
||||||
|
|
||||||
@ -215,10 +216,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
trajectory_return = self.reward_aggregation(rewards[:t + 1])
|
||||||
return self.observation(obs), trajectory_return, terminated, truncated, infos
|
return self.observation(obs), trajectory_return, terminated, truncated, infos
|
||||||
|
|
||||||
def render(self, **kwargs):
|
def render(self):
|
||||||
"""Only set render options here, such that they can be used during the rollout.
|
self.do_render = True
|
||||||
This only needs to be called once"""
|
|
||||||
self.render_kwargs = kwargs
|
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||||
-> Tuple[ObsType, Dict[str, Any]]:
|
-> Tuple[ObsType, Dict[str, Any]]:
|
||||||
|
@ -14,12 +14,14 @@ class BaseReacherEnv(gym.Env):
|
|||||||
Base class for all reaching environments.
|
Base class for all reaching environments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False):
|
def __init__(self, n_links: int, random_start: bool = True, allow_self_collision: bool = False, render_mode: str = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.link_lengths = np.ones(n_links)
|
self.link_lengths = np.ones(n_links)
|
||||||
self.n_links = n_links
|
self.n_links = n_links
|
||||||
self._dt = 0.01
|
self._dt = 0.01
|
||||||
|
|
||||||
|
self.render_mode = render_mode
|
||||||
|
|
||||||
self.random_start = random_start
|
self.random_start = random_start
|
||||||
|
|
||||||
self.allow_self_collision = allow_self_collision
|
self.allow_self_collision = allow_self_collision
|
||||||
|
@ -10,8 +10,8 @@ class BaseReacherDirectEnv(BaseReacherEnv):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, random_start: bool = True,
|
def __init__(self, n_links: int, random_start: bool = True,
|
||||||
allow_self_collision: bool = False):
|
allow_self_collision: bool = False, **kwargs):
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
self.max_vel = 2 * np.pi
|
self.max_vel = 2 * np.pi
|
||||||
action_bound = np.ones((self.n_links,)) * self.max_vel
|
action_bound = np.ones((self.n_links,)) * self.max_vel
|
||||||
|
@ -10,8 +10,8 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, random_start: bool = True,
|
def __init__(self, n_links: int, random_start: bool = True,
|
||||||
allow_self_collision: bool = False):
|
allow_self_collision: bool = False, **kwargs):
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
self.max_torque = 1000
|
self.max_torque = 1000
|
||||||
action_bound = np.ones((self.n_links,)) * self.max_torque
|
action_bound = np.ones((self.n_links,)) * self.max_torque
|
||||||
|
@ -17,9 +17,9 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
|
||||||
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
|
||||||
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple"):
|
allow_wall_collision: bool = False, collision_penalty: float = 1000, rew_fct: str = "simple", **kwargs):
|
||||||
|
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.initial_x = hole_x # x-position of center of hole
|
self.initial_x = hole_x # x-position of center of hole
|
||||||
@ -178,7 +178,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self):
|
||||||
if self.fig is None:
|
if self.fig is None:
|
||||||
# Create base figure once on the beginning. Afterwards only update
|
# Create base figure once on the beginning. Afterwards only update
|
||||||
plt.ion()
|
plt.ion()
|
||||||
@ -197,7 +197,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
self.fig.gca().set_title(
|
self.fig.gca().set_title(
|
||||||
f"Iteration: {self._steps}, distance: {np.linalg.norm(self.end_effector - self._goal) ** 2}")
|
f"Iteration: {self._steps}, distance: {np.linalg.norm(self.end_effector - self._goal) ** 2}")
|
||||||
|
|
||||||
if mode == "human":
|
if self.render_mode == "human":
|
||||||
|
|
||||||
# arm
|
# arm
|
||||||
self.line.set_data(self._joints[:, 0], self._joints[:, 1])
|
self.line.set_data(self._joints[:, 0], self._joints[:, 1])
|
||||||
@ -205,7 +205,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
self.fig.canvas.draw()
|
self.fig.canvas.draw()
|
||||||
self.fig.canvas.flush_events()
|
self.fig.canvas.flush_events()
|
||||||
|
|
||||||
elif mode == "partial":
|
elif self.render_mode == "partial":
|
||||||
if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided:
|
if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided:
|
||||||
# Arm
|
# Arm
|
||||||
plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k',
|
plt.plot(self._joints[:, 0], self._joints[:, 1], 'ro-', markerfacecolor='k',
|
||||||
|
@ -17,8 +17,8 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
|
def __init__(self, n_links: int, target: Union[None, Iterable] = None, random_start: bool = True,
|
||||||
allow_self_collision: bool = False, ):
|
allow_self_collision: bool = False, **kwargs):
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.inital_target = target
|
self.inital_target = target
|
||||||
@ -98,7 +98,7 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
|
|||||||
def _check_collisions(self) -> bool:
|
def _check_collisions(self) -> bool:
|
||||||
return self._check_self_collision()
|
return self._check_self_collision()
|
||||||
|
|
||||||
def render(self, mode='human'): # pragma: no cover
|
def render(self): # pragma: no cover
|
||||||
if self.fig is None:
|
if self.fig is None:
|
||||||
# Create base figure once on the beginning. Afterwards only update
|
# Create base figure once on the beginning. Afterwards only update
|
||||||
plt.ion()
|
plt.ion()
|
||||||
|
@ -13,9 +13,9 @@ from . import MPWrapper
|
|||||||
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
class ViaPointReacherEnv(BaseReacherDirectEnv):
|
||||||
|
|
||||||
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
||||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000, **kwargs):
|
||||||
|
|
||||||
super().__init__(n_links, random_start, allow_self_collision)
|
super().__init__(n_links, random_start, allow_self_collision, **kwargs)
|
||||||
|
|
||||||
# provided initial parameters
|
# provided initial parameters
|
||||||
self.intitial_target = target # provided target value
|
self.intitial_target = target # provided target value
|
||||||
@ -123,7 +123,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
def _check_collisions(self) -> bool:
|
def _check_collisions(self) -> bool:
|
||||||
return self._check_self_collision()
|
return self._check_self_collision()
|
||||||
|
|
||||||
def render(self, mode='human'):
|
def render(self):
|
||||||
goal_pos = self._goal.T
|
goal_pos = self._goal.T
|
||||||
via_pos = self._via_point.T
|
via_pos = self._via_point.T
|
||||||
|
|
||||||
@ -146,7 +146,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
self.fig.gca().set_title(f"Iteration: {self._steps}, distance: {self.end_effector - self._goal}")
|
self.fig.gca().set_title(f"Iteration: {self._steps}, distance: {self.end_effector - self._goal}")
|
||||||
|
|
||||||
if mode == "human":
|
if self.render_mode == "human":
|
||||||
# goal
|
# goal
|
||||||
if self._steps == 1:
|
if self._steps == 1:
|
||||||
self.goal_point_plot.set_data(goal_pos[0], goal_pos[1])
|
self.goal_point_plot.set_data(goal_pos[0], goal_pos[1])
|
||||||
@ -158,7 +158,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
self.fig.canvas.draw()
|
self.fig.canvas.draw()
|
||||||
self.fig.canvas.flush_events()
|
self.fig.canvas.flush_events()
|
||||||
|
|
||||||
elif mode == "partial":
|
elif self.render_mode == "partial":
|
||||||
if self._steps == 1:
|
if self._steps == 1:
|
||||||
# fig, ax = plt.subplots()
|
# fig, ax = plt.subplots()
|
||||||
# Add the patch to the Axes
|
# Add the patch to the Axes
|
||||||
@ -178,7 +178,7 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
|
|||||||
plt.ylim([-1.1, lim])
|
plt.ylim([-1.1, lim])
|
||||||
plt.pause(0.01)
|
plt.pause(0.01)
|
||||||
|
|
||||||
elif mode == "final":
|
elif self.render_mode == "final":
|
||||||
if self._steps == 199 or self._is_collided:
|
if self._steps == 199 or self._is_collided:
|
||||||
# fig, ax = plt.subplots()
|
# fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
@ -101,6 +101,7 @@ class AntJumpEnv(AntEnvCustomXML):
|
|||||||
contact_force_range=(-1.0, 1.0),
|
contact_force_range=(-1.0, 1.0),
|
||||||
reset_noise_scale=0.1,
|
reset_noise_scale=0.1,
|
||||||
exclude_current_positions_from_observation=True,
|
exclude_current_positions_from_observation=True,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_height = 0
|
self.max_height = 0
|
||||||
@ -113,7 +114,7 @@ class AntJumpEnv(AntEnvCustomXML):
|
|||||||
healthy_z_range=healthy_z_range,
|
healthy_z_range=healthy_z_range,
|
||||||
contact_force_range=contact_force_range,
|
contact_force_range=contact_force_range,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation, **kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
|
@ -36,7 +36,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
"render_fps": 50
|
"render_fps": 50
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, frame_skip: int = 10, random_init: bool = False):
|
def __init__(self, frame_skip: int = 10, random_init: bool = False, **kwargs):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.])
|
self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.])
|
||||||
@ -58,7 +58,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
|
||||||
frame_skip=self.frame_skip,
|
frame_skip=self.frame_skip,
|
||||||
observation_space=self.observation_space)
|
observation_space=self.observation_space, **kwargs)
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
|
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
@ -305,8 +305,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
|
|
||||||
class BoxPushingDense(BoxPushingEnvBase):
|
class BoxPushingDense(BoxPushingEnvBase):
|
||||||
def __init__(self, frame_skip: int = 10, random_init: bool = False):
|
def __init__(self, **kwargs):
|
||||||
super(BoxPushingDense, self).__init__(frame_skip=frame_skip, random_init=random_init)
|
super(BoxPushingDense, self).__init__(**kwargs)
|
||||||
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
||||||
rod_tip_pos, rod_quat, qpos, qvel, action):
|
rod_tip_pos, rod_quat, qpos, qvel, action):
|
||||||
joint_penalty = self._joint_limit_violate_penalty(qpos,
|
joint_penalty = self._joint_limit_violate_penalty(qpos,
|
||||||
@ -329,8 +329,8 @@ class BoxPushingDense(BoxPushingEnvBase):
|
|||||||
|
|
||||||
|
|
||||||
class BoxPushingTemporalSparse(BoxPushingEnvBase):
|
class BoxPushingTemporalSparse(BoxPushingEnvBase):
|
||||||
def __init__(self, frame_skip: int = 10, random_init: bool = False):
|
def __init__(self, **kwargs):
|
||||||
super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip, random_init=random_init)
|
super(BoxPushingTemporalSparse, self).__init__(**kwargs)
|
||||||
|
|
||||||
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
||||||
rod_tip_pos, rod_quat, qpos, qvel, action):
|
rod_tip_pos, rod_quat, qpos, qvel, action):
|
||||||
@ -361,8 +361,8 @@ class BoxPushingTemporalSparse(BoxPushingEnvBase):
|
|||||||
|
|
||||||
class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
|
class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
|
||||||
|
|
||||||
def __init__(self, frame_skip: int = 10, random_init: bool = False):
|
def __init__(self, **kwargs):
|
||||||
super(BoxPushingTemporalSpatialSparse, self).__init__(frame_skip=frame_skip, random_init=random_init)
|
super(BoxPushingTemporalSpatialSparse, self).__init__(**kwargs)
|
||||||
|
|
||||||
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
||||||
rod_tip_pos, rod_quat, qpos, qvel, action):
|
rod_tip_pos, rod_quat, qpos, qvel, action):
|
||||||
@ -392,8 +392,8 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
|
|||||||
|
|
||||||
class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase):
|
class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase):
|
||||||
|
|
||||||
def __init__(self, frame_skip: int = 10, random_init: bool = False):
|
def __init__(self, **kwargs):
|
||||||
super(BoxPushingTemporalSpatialSparse2, self).__init__(frame_skip=frame_skip, random_init=random_init)
|
super(BoxPushingTemporalSpatialSparse2, self).__init__(**kwargs)
|
||||||
|
|
||||||
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
||||||
rod_tip_pos, rod_quat, qpos, qvel, action):
|
rod_tip_pos, rod_quat, qpos, qvel, action):
|
||||||
@ -428,8 +428,8 @@ class BoxPushingTemporalSpatialSparse2(BoxPushingEnvBase):
|
|||||||
|
|
||||||
|
|
||||||
class BoxPushingNoConstraintSparse(BoxPushingEnvBase):
|
class BoxPushingNoConstraintSparse(BoxPushingEnvBase):
|
||||||
def __init__(self, frame_skip: int = 10, random_init: bool = False):
|
def __init__(self, **kwargs):
|
||||||
super(BoxPushingNoConstraintSparse, self).__init__(frame_skip=frame_skip, random_init=random_init)
|
super(BoxPushingNoConstraintSparse, self).__init__(**kwargs)
|
||||||
|
|
||||||
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
|
||||||
rod_tip_pos, rod_quat, qpos, qvel, action):
|
rod_tip_pos, rod_quat, qpos, qvel, action):
|
||||||
|
@ -74,7 +74,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
|
|||||||
reset_noise_scale=0.1,
|
reset_noise_scale=0.1,
|
||||||
context=True,
|
context=True,
|
||||||
exclude_current_positions_from_observation=True,
|
exclude_current_positions_from_observation=True,
|
||||||
max_episode_steps=100):
|
max_episode_steps=100,
|
||||||
|
**kwargs):
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_height = 0
|
self.max_height = 0
|
||||||
# self.max_episode_steps = max_episode_steps
|
# self.max_episode_steps = max_episode_steps
|
||||||
@ -85,7 +86,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
|
|||||||
forward_reward_weight=forward_reward_weight,
|
forward_reward_weight=forward_reward_weight,
|
||||||
ctrl_cost_weight=ctrl_cost_weight,
|
ctrl_cost_weight=ctrl_cost_weight,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
|
||||||
|
@ -115,6 +115,7 @@ class HopperJumpEnv(HopperEnvCustomXML):
|
|||||||
reset_noise_scale=5e-3,
|
reset_noise_scale=5e-3,
|
||||||
exclude_current_positions_from_observation=False,
|
exclude_current_positions_from_observation=False,
|
||||||
sparse=False,
|
sparse=False,
|
||||||
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
self.sparse = sparse
|
self.sparse = sparse
|
||||||
@ -141,7 +142,8 @@ class HopperJumpEnv(HopperEnvCustomXML):
|
|||||||
healthy_z_range=healthy_z_range,
|
healthy_z_range=healthy_z_range,
|
||||||
healthy_angle_range=healthy_angle_range,
|
healthy_angle_range=healthy_angle_range,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
# increase initial height
|
# increase initial height
|
||||||
self.init_qpos[1] = 1.5
|
self.init_qpos[1] = 1.5
|
||||||
|
@ -29,7 +29,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
|||||||
reset_noise_scale=5e-3,
|
reset_noise_scale=5e-3,
|
||||||
context=True,
|
context=True,
|
||||||
exclude_current_positions_from_observation=True,
|
exclude_current_positions_from_observation=True,
|
||||||
max_episode_steps=250):
|
max_episode_steps=250,
|
||||||
|
**kwargs):
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_height = 0
|
self.max_height = 0
|
||||||
self.max_episode_steps = max_episode_steps
|
self.max_episode_steps = max_episode_steps
|
||||||
@ -50,7 +51,8 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
|||||||
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||||
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
|
super().__init__(xml_file, forward_reward_weight, ctrl_cost_weight, healthy_reward, terminate_when_unhealthy,
|
||||||
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
|
healthy_state_range, healthy_z_range, healthy_angle_range, reset_noise_scale,
|
||||||
exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
|
||||||
|
@ -32,7 +32,8 @@ class HopperThrowEnv(HopperEnvCustomXML):
|
|||||||
reset_noise_scale=5e-3,
|
reset_noise_scale=5e-3,
|
||||||
context=True,
|
context=True,
|
||||||
exclude_current_positions_from_observation=True,
|
exclude_current_positions_from_observation=True,
|
||||||
max_episode_steps=250):
|
max_episode_steps=250,
|
||||||
|
**kwargs):
|
||||||
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
xml_file = os.path.join(os.path.dirname(__file__), "assets", xml_file)
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_episode_steps = max_episode_steps
|
self.max_episode_steps = max_episode_steps
|
||||||
@ -57,7 +58,8 @@ class HopperThrowEnv(HopperEnvCustomXML):
|
|||||||
healthy_z_range=healthy_z_range,
|
healthy_z_range=healthy_z_range,
|
||||||
healthy_state_range=healthy_angle_range,
|
healthy_state_range=healthy_angle_range,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
|
@ -36,7 +36,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
|
|||||||
context=True,
|
context=True,
|
||||||
penalty=0.0,
|
penalty=0.0,
|
||||||
exclude_current_positions_from_observation=True,
|
exclude_current_positions_from_observation=True,
|
||||||
max_episode_steps=250):
|
max_episode_steps=250,
|
||||||
|
**kwargs):
|
||||||
self.hit_basket_reward = hit_basket_reward
|
self.hit_basket_reward = hit_basket_reward
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_episode_steps = max_episode_steps
|
self.max_episode_steps = max_episode_steps
|
||||||
@ -65,7 +66,8 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
|
|||||||
healthy_z_range=healthy_z_range,
|
healthy_z_range=healthy_z_range,
|
||||||
healthy_angle_range=healthy_angle_range,
|
healthy_angle_range=healthy_angle_range,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4,
|
||||||
goal_switching_step: int = None,
|
goal_switching_step: int = None,
|
||||||
enable_artificial_wind: bool = False):
|
enable_artificial_wind: bool = False, **kwargs):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
|
|
||||||
@ -68,7 +68,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", "xml", "table_tennis_env.xml"),
|
||||||
frame_skip=frame_skip,
|
frame_skip=frame_skip,
|
||||||
observation_space=self.observation_space)
|
observation_space=self.observation_space,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
if ctxt_dim == 2:
|
if ctxt_dim == 2:
|
||||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||||
@ -275,11 +276,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
|
|
||||||
class TableTennisWind(TableTennisEnv):
|
class TableTennisWind(TableTennisEnv):
|
||||||
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4):
|
def __init__(self, ctxt_dim: int = 4, frame_skip: int = 4, **kwargs):
|
||||||
self.observation_space = spaces.Box(
|
self.observation_space = spaces.Box(
|
||||||
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
|
low=-np.inf, high=np.inf, shape=(22,), dtype=np.float64
|
||||||
)
|
)
|
||||||
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True)
|
super().__init__(ctxt_dim=ctxt_dim, frame_skip=frame_skip, enable_artificial_wind=True, **kwargs)
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
obs = np.concatenate([
|
obs = np.concatenate([
|
||||||
@ -297,5 +298,5 @@ class TableTennisWind(TableTennisEnv):
|
|||||||
|
|
||||||
|
|
||||||
class TableTennisGoalSwitching(TableTennisEnv):
|
class TableTennisGoalSwitching(TableTennisEnv):
|
||||||
def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99):
|
def __init__(self, frame_skip: int = 4, goal_switching_step: int = 99, **kwargs):
|
||||||
super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step)
|
super().__init__(frame_skip=frame_skip, goal_switching_step=goal_switching_step, **kwargs)
|
||||||
|
@ -97,7 +97,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML):
|
|||||||
reset_noise_scale=5e-3,
|
reset_noise_scale=5e-3,
|
||||||
penalty=0,
|
penalty=0,
|
||||||
exclude_current_positions_from_observation=True,
|
exclude_current_positions_from_observation=True,
|
||||||
max_episode_steps=300):
|
max_episode_steps=300,
|
||||||
|
**kwargs):
|
||||||
self.current_step = 0
|
self.current_step = 0
|
||||||
self.max_episode_steps = max_episode_steps
|
self.max_episode_steps = max_episode_steps
|
||||||
self.max_height = 0
|
self.max_height = 0
|
||||||
@ -112,7 +113,8 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML):
|
|||||||
healthy_z_range=healthy_z_range,
|
healthy_z_range=healthy_z_range,
|
||||||
healthy_angle_range=healthy_angle_range,
|
healthy_angle_range=healthy_angle_range,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
|
@ -41,7 +41,7 @@ class ToyEnv(gym.Env):
|
|||||||
obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
|
obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,7 +33,7 @@ class ToyEnv(gym.Env):
|
|||||||
obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
|
obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ class ToyEnv(gym.Env):
|
|||||||
obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
|
obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
def render(self, mode="human"):
|
def render(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user