Merge pull request #99 from D-o-d-o-x/fix_automatic_render
Fix: Some fancy envs not following gym spec regarding rendering behavior
This commit is contained in:
commit
fa72c3791c
@ -115,6 +115,7 @@ class AntJumpEnv(AntEnvCustomXML):
|
|||||||
contact_force_range=contact_force_range,
|
contact_force_range=contact_force_range,
|
||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation, **kwargs)
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation, **kwargs)
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
@ -153,8 +154,15 @@ class AntJumpEnv(AntEnvCustomXML):
|
|||||||
}
|
}
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return obs, reward, terminated, truncated, info
|
return obs, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return np.append(super()._get_obs(), self.goal)
|
return np.append(super()._get_obs(), self.goal)
|
||||||
|
|
||||||
|
@ -44,6 +44,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
utils.EzPickle.__init__(self)
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
# Small Context -> Easier. Todo: Should we do different versions?
|
# Small Context -> Easier. Todo: Should we do different versions?
|
||||||
# self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "beerpong_wo_cup.xml")
|
# self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "beerpong_wo_cup.xml")
|
||||||
@ -89,7 +90,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
observation_space=self.observation_space,
|
observation_space=self.observation_space,
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
utils.EzPickle.__init__(self)
|
self.render_active = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def start_pos(self):
|
||||||
@ -169,8 +170,15 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return ob, reward, terminated, truncated, infos
|
return ob, reward, terminated, truncated, infos
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
theta = self.data.qpos.flat[:7].copy()
|
theta = self.data.qpos.flat[:7].copy()
|
||||||
theta_dot = self.data.qvel.flat[:7].copy()
|
theta_dot = self.data.qvel.flat[:7].copy()
|
||||||
|
@ -4,6 +4,7 @@ import numpy as np
|
|||||||
from gymnasium import utils, spaces
|
from gymnasium import utils, spaces
|
||||||
from gymnasium.envs.mujoco import MujocoEnv
|
from gymnasium.envs.mujoco import MujocoEnv
|
||||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
|
||||||
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
|
||||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
|
||||||
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
|
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
|
||||||
|
|
||||||
@ -60,6 +61,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
frame_skip=self.frame_skip,
|
frame_skip=self.frame_skip,
|
||||||
observation_space=self.observation_space, **kwargs)
|
observation_space=self.observation_space, **kwargs)
|
||||||
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
|
self.action_space = spaces.Box(low=-1, high=1, shape=(7,))
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
action = 10 * np.clip(action, self.action_space.low, self.action_space.high)
|
action = 10 * np.clip(action, self.action_space.low, self.action_space.high)
|
||||||
@ -108,8 +110,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
terminated = episode_end and infos['is_success']
|
terminated = episode_end and infos['is_success']
|
||||||
truncated = episode_end and not infos['is_success']
|
truncated = episode_end and not infos['is_success']
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return obs, reward, terminated, truncated, infos
|
return obs, reward, terminated, truncated, infos
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
# rest box to initial position
|
# rest box to initial position
|
||||||
self.set_state(self.init_qpos_box_pushing, self.init_qvel_box_pushing)
|
self.set_state(self.init_qpos_box_pushing, self.init_qvel_box_pushing)
|
||||||
|
@ -60,7 +60,11 @@ class HalfCheetahEnvCustomXML(HalfCheetahEnv):
|
|||||||
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
default_camera_config=DEFAULT_CAMERA_CONFIG,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
|
class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
@ -120,6 +124,9 @@ class HalfCheetahJumpEnv(HalfCheetahEnvCustomXML):
|
|||||||
'max_height': self.max_height
|
'max_height': self.max_height
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
|
@ -88,6 +88,12 @@ class HopperEnvCustomXML(HopperEnv):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
|
|
||||||
class HopperJumpEnv(HopperEnvCustomXML):
|
class HopperJumpEnv(HopperEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
@ -201,6 +207,10 @@ class HopperJumpEnv(HopperEnvCustomXML):
|
|||||||
healthy=self.is_healthy,
|
healthy=self.is_healthy,
|
||||||
contact_dist=self.contact_dist or 0
|
contact_dist=self.contact_dist or 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
|
@ -140,6 +140,9 @@ class HopperJumpOnBoxEnv(HopperEnvCustomXML):
|
|||||||
|
|
||||||
truncated = self.current_step >= self.max_episode_steps and not terminated
|
truncated = self.current_step >= self.max_episode_steps and not terminated
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
|
@ -61,6 +61,8 @@ class HopperThrowEnv(HopperEnvCustomXML):
|
|||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self.current_step += 1
|
self.current_step += 1
|
||||||
self.do_simulation(action, self.frame_skip)
|
self.do_simulation(action, self.frame_skip)
|
||||||
@ -94,8 +96,15 @@ class HopperThrowEnv(HopperEnvCustomXML):
|
|||||||
}
|
}
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return np.append(super()._get_obs(), self.goal)
|
return np.append(super()._get_obs(), self.goal)
|
||||||
|
|
||||||
|
@ -68,6 +68,7 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
|
|||||||
reset_noise_scale=reset_noise_scale,
|
reset_noise_scale=reset_noise_scale,
|
||||||
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
exclude_current_positions_from_observation=exclude_current_positions_from_observation,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
|
|
||||||
@ -118,8 +119,15 @@ class HopperThrowInBasketEnv(HopperEnvCustomXML):
|
|||||||
}
|
}
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return np.append(super()._get_obs(), self.basket_x)
|
return np.append(super()._get_obs(), self.basket_x)
|
||||||
|
|
||||||
|
@ -47,6 +47,8 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
|
|
||||||
@ -77,8 +79,15 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
goal=self.goal if hasattr(self, "goal") else None
|
goal=self.goal if hasattr(self, "goal") else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return ob, reward, terminated, truncated, info
|
return ob, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def distance_reward(self):
|
def distance_reward(self):
|
||||||
vec = self.get_body_com("fingertip") - self.get_body_com("target")
|
vec = self.get_body_com("fingertip") - self.get_body_com("target")
|
||||||
return -self._reward_weight * np.linalg.norm(vec)
|
return -self._reward_weight * np.linalg.norm(vec)
|
||||||
|
@ -71,6 +71,8 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
observation_space=self.observation_space,
|
observation_space=self.observation_space,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
if ctxt_dim == 2:
|
if ctxt_dim == 2:
|
||||||
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
self.context_bounds = CONTEXT_BOUNDS_2DIMS
|
||||||
elif ctxt_dim == 4:
|
elif ctxt_dim == 4:
|
||||||
@ -158,8 +160,15 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
terminated, truncated = self._terminated, False
|
terminated, truncated = self._terminated, False
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return self._get_obs(), reward, terminated, truncated, info
|
return self._get_obs(), reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def _contact_checker(self, id_1, id_2):
|
def _contact_checker(self, id_1, id_2):
|
||||||
for coni in range(0, self.data.ncon):
|
for coni in range(0, self.data.ncon):
|
||||||
con = self.data.contact[coni]
|
con = self.data.contact[coni]
|
||||||
|
@ -79,6 +79,8 @@ class Walker2dEnvCustomXML(Walker2dEnv):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
|
|
||||||
class Walker2dJumpEnv(Walker2dEnvCustomXML):
|
class Walker2dJumpEnv(Walker2dEnvCustomXML):
|
||||||
"""
|
"""
|
||||||
@ -145,8 +147,15 @@ class Walker2dJumpEnv(Walker2dEnvCustomXML):
|
|||||||
}
|
}
|
||||||
truncated = False
|
truncated = False
|
||||||
|
|
||||||
|
if self.render_active and self.render_mode=='human':
|
||||||
|
self.render()
|
||||||
|
|
||||||
return observation, reward, terminated, truncated, info
|
return observation, reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def render(self):
|
||||||
|
self.render_active = True
|
||||||
|
return super().render()
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
return np.append(super()._get_obs(), self.goal)
|
return np.append(super()._get_obs(), self.goal)
|
||||||
|
|
||||||
|
@ -3,14 +3,14 @@ import fancy_gym
|
|||||||
|
|
||||||
|
|
||||||
def example_run_replanning_env(env_name="fancy_ProDMP/BoxPushingDenseReplan-v0", seed=1, iterations=1, render=False):
|
def example_run_replanning_env(env_name="fancy_ProDMP/BoxPushingDenseReplan-v0", seed=1, iterations=1, render=False):
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name, render_mode='human' if render else None)
|
||||||
env.reset(seed=seed)
|
env.reset(seed=seed)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
while True:
|
while True:
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
obs, reward, terminated, truncated, info = env.step(ac)
|
obs, reward, terminated, truncated, info = env.step(ac)
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render()
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
env.reset()
|
env.reset()
|
||||||
break
|
break
|
||||||
@ -38,13 +38,13 @@ def example_custom_replanning_envs(seed=0, iteration=100, render=True):
|
|||||||
'replanning_schedule': lambda pos, vel, obs, action, t: t % 25 == 0,
|
'replanning_schedule': lambda pos, vel, obs, action, t: t % 25 == 0,
|
||||||
'condition_on_desired': True}
|
'condition_on_desired': True}
|
||||||
|
|
||||||
base_env = gym.make(base_env_id)
|
base_env = gym.make(base_env_id, render_mode='human' if render else None)
|
||||||
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs=black_box_kwargs,
|
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs=black_box_kwargs,
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render()
|
||||||
|
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ def example_dmc(env_id="dm_control/fish-swim", seed=1, iterations=1000, render=T
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id, render_mode='human' if render else None)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
print("observation shape:", env.observation_space.shape)
|
print("observation shape:", env.observation_space.shape)
|
||||||
@ -26,7 +26,7 @@ def example_dmc(env_id="dm_control/fish-swim", seed=1, iterations=1000, render=T
|
|||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render()
|
||||||
obs, reward, terminated, truncated, info = env.step(ac)
|
obs, reward, terminated, truncated, info = env.step(ac)
|
||||||
rewards += reward
|
rewards += reward
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||||
# 'num_basis': 5
|
# 'num_basis': 5
|
||||||
# }
|
# }
|
||||||
base_env = gym.make(base_env_id)
|
base_env = gym.make(base_env_id, render_mode='human' if render else None)
|
||||||
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs={},
|
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs={},
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
@ -96,7 +96,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
# It is also possible to change them mode multiple times when
|
# It is also possible to change them mode multiple times when
|
||||||
# e.g. only every nth trajectory should be displayed.
|
# e.g. only every nth trajectory should be displayed.
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render()
|
||||||
|
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
@ -115,7 +115,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
|
|
||||||
def main(render = True):
|
def main(render = False):
|
||||||
# # Standard DMC Suite tasks
|
# # Standard DMC Suite tasks
|
||||||
example_dmc("dm_control/fish-swim", seed=10, iterations=1000, render=render)
|
example_dmc("dm_control/fish-swim", seed=10, iterations=1000, render=render)
|
||||||
#
|
#
|
||||||
|
@ -21,7 +21,7 @@ def example_general(env_id="Pendulum-v1", seed=1, iterations=1000, render=True):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id, render_mode='human' if render else None)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
print("Observation shape: ", env.observation_space.shape)
|
print("Observation shape: ", env.observation_space.shape)
|
||||||
@ -85,7 +85,7 @@ def example_async(env_id="fancy/HoleReacher-v0", n_cpu=4, seed=int('533D', 16),
|
|||||||
# do not return values above threshold
|
# do not return values above threshold
|
||||||
return *map(lambda v: np.stack(v)[:n_samples], buffer.values()),
|
return *map(lambda v: np.stack(v)[:n_samples], buffer.values()),
|
||||||
|
|
||||||
def main(render = True):
|
def main(render = False):
|
||||||
# Basic gym task
|
# Basic gym task
|
||||||
example_general("Pendulum-v1", seed=10, iterations=200, render=render)
|
example_general("Pendulum-v1", seed=10, iterations=200, render=render)
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import gymnasium as gym
|
|||||||
import fancy_gym
|
import fancy_gym
|
||||||
|
|
||||||
|
|
||||||
def example_meta(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
def example_meta(env_id="metaworld/button-press-v2", seed=1, iterations=1000, render=True):
|
||||||
"""
|
"""
|
||||||
Example for running a MetaWorld based env in the step based setting.
|
Example for running a MetaWorld based env in the step based setting.
|
||||||
The env_id has to be specified as `task_name-v2`. V1 versions are not supported and we always
|
The env_id has to be specified as `task_name-v2`. V1 versions are not supported and we always
|
||||||
@ -18,7 +18,7 @@ def example_meta(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
env = gym.make(env_id)
|
env = gym.make(env_id, render_mode='human' if render else None)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
print("observation shape:", env.observation_space.shape)
|
print("observation shape:", env.observation_space.shape)
|
||||||
@ -27,9 +27,7 @@ def example_meta(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
|||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
if render:
|
if render:
|
||||||
# THIS NEEDS TO BE SET TO FALSE FOR NOW, BECAUSE THE INTERFACE FOR RENDERING IS DIFFERENT TO BASIC GYM
|
env.render()
|
||||||
# TODO: Remove this, when Metaworld fixes its interface.
|
|
||||||
env.render(False)
|
|
||||||
obs, reward, terminated, truncated, info = env.step(ac)
|
obs, reward, terminated, truncated, info = env.step(ac)
|
||||||
rewards += reward
|
rewards += reward
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
@ -81,7 +79,7 @@ def example_custom_meta_and_mp(seed=1, iterations=1, render=True):
|
|||||||
basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||||
'num_basis': 5
|
'num_basis': 5
|
||||||
}
|
}
|
||||||
base_env = gym.make(base_env_id)
|
base_env = gym.make(base_env_id, render_mode='human' if render else None)
|
||||||
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs={},
|
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs={},
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
@ -93,7 +91,7 @@ def example_custom_meta_and_mp(seed=1, iterations=1, render=True):
|
|||||||
# It is also possible to change them mode multiple times when
|
# It is also possible to change them mode multiple times when
|
||||||
# e.g. only every nth trajectory should be displayed.
|
# e.g. only every nth trajectory should be displayed.
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render()
|
||||||
|
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
|
@ -13,15 +13,13 @@ def example_mp(env_name, seed=1, render=True):
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name, render_mode='human' if render else None)
|
||||||
|
|
||||||
returns = 0
|
returns = 0
|
||||||
obs = env.reset(seed=seed)
|
obs = env.reset(seed=seed)
|
||||||
# number of samples/full trajectories (multiple environment steps)
|
# number of samples/full trajectories (multiple environment steps)
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
if render and i % 2 == 0:
|
if render and i % 2 == 0:
|
||||||
env.render(mode="human")
|
|
||||||
else:
|
|
||||||
env.render()
|
env.render()
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
obs, reward, terminated, truncated, info = env.step(ac)
|
obs, reward, terminated, truncated, info = env.step(ac)
|
||||||
|
Loading…
Reference in New Issue
Block a user