Merge branch '47-update-to-new-gym-api' into gym_upgrade
commit 228e343a1b
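The changes below follow the OpenAI Gym to Gymnasium migration: step() now returns a five-tuple with separate terminated and truncated flags instead of a single done flag, and reset() returns (observation, info) while taking seed/options as keyword-only arguments. A minimal sketch of the new calling convention (not part of the commit; it only assumes Gymnasium's documented Env API and uses Pendulum-v1, which also appears in the examples below, as the environment id):

    import gymnasium as gym

    env = gym.make("Pendulum-v1")

    # reset() now returns (observation, info) and takes seed/options keyword-only.
    obs, info = env.reset(seed=1)

    done = False
    while not done:
        action = env.action_space.sample()
        # step() splits the old `done` flag into `terminated` (true terminal state)
        # and `truncated` (time limit or other external stop).
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

    env.close()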
@@ -1,8 +1,9 @@
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Dict, Any
 
-import gym
+import gymnasium as gym
 import numpy as np
-from gym import spaces
+from gymnasium import spaces
+from gymnasium.core import ObsType
 from mp_pytorch.mp.mp_interfaces import MPInterface
 
 from fancy_gym.black_box.controller.base_controller import BaseController
@@ -59,7 +60,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.reward_aggregation = reward_aggregation
 
         # spaces
-        self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
+        self.return_context_observation = not (
+            learn_sub_trajectories or self.do_replanning)
         self.traj_gen_action_space = self._get_traj_gen_action_space()
         self.action_space = self._get_action_space()
         self.observation_space = self._get_observation_space()
@@ -91,14 +93,17 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         # If we do not do this, the traj_gen assumes we are continuing the trajectory.
         self.traj_gen.reset()
 
-        clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
+        clipped_params = np.clip(
+            action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
         self.traj_gen.set_params(clipped_params)
-        init_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
+        init_time = np.array(
+            0 if not self.do_replanning else self.current_traj_steps * self.dt)
 
         condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
         condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
 
-        self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel)
+        self.traj_gen.set_initial_conditions(
+            init_time, condition_pos, condition_vel)
         self.traj_gen.set_duration(duration, self.dt)
 
         position = get_numpy(self.traj_gen.get_traj_pos())
@@ -138,13 +143,15 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
 
         # TODO remove this part, right now only needed for beer pong
-        mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen)
+        mp_params, env_spec_params = self.env.episode_callback(
+            action, self.traj_gen)
         position, velocity = self.get_trajectory(mp_params)
 
         trajectory_length = len(position)
         rewards = np.zeros(shape=(trajectory_length,))
         if self.verbose >= 2:
-            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
+            actions = np.zeros(shape=(trajectory_length,) +
+                               self.env.action_space.shape)
             observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
                                     dtype=self.env.observation_space.dtype)
 
@@ -153,9 +160,12 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
         self.plan_steps += 1
         for t, (pos, vel) in enumerate(zip(position, velocity)):
-            step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
-            c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
-            obs, c_reward, done, info = self.env.step(c_action)
+            step_action = self.tracking_controller.get_action(
+                pos, vel, self.current_pos, self.current_vel)
+            c_action = np.clip(
+                step_action, self.env.action_space.low, self.env.action_space.high)
+            obs, c_reward, terminated, truncated, info = self.env.step(
+                c_action)
             rewards[t] = c_reward
 
             if self.verbose >= 2:
@@ -170,9 +180,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             if self.render_kwargs:
                 self.env.render(**self.render_kwargs)
 
-            if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
-                                                 t + 1 + self.current_traj_steps)
-                        and self.plan_steps < self.max_planning_times):
+            if terminated or truncated or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
+                                                                    t + 1 + self.current_traj_steps)
+                                           and self.plan_steps < self.max_planning_times):
 
                 if self.condition_on_desired:
                     self.condition_pos = pos
@@ -192,17 +202,18 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
         infos['trajectory_length'] = t + 1
         trajectory_return = self.reward_aggregation(rewards[:t + 1])
-        return self.observation(obs), trajectory_return, done, infos
+        return self.observation(obs), trajectory_return, terminated, truncated, infos
 
     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
         This only needs to be called once"""
         self.render_kwargs = kwargs
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self.current_traj_steps = 0
         self.plan_steps = 0
         self.traj_gen.reset()
         self.condition_pos = None
         self.condition_vel = None
-        return super(BlackBoxWrapper, self).reset()
+        return super(BlackBoxWrapper, self).reset(seed=seed, options=options)
@@ -1,6 +1,6 @@
 from typing import Union, Tuple
 
-import gym
+import gymnasium as gym
 import numpy as np
 from mp_pytorch.mp.mp_interfaces import MPInterface
 
@@ -1,14 +1,16 @@
 from copy import deepcopy
 
+from gymnasium.wrappers import FlattenObservation
+
 from . import manipulation, suite
 
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}
 
-from gym.envs.registration import register
+from gymnasium.envs.registration import register
 
 DEFAULT_BB_DICT_ProMP = {
     "name": 'EnvName',
-    "wrappers": [],
+    "wrappers": [FlattenObservation],
     "trajectory_generator_kwargs": {
         'trajectory_generator_type': 'promp'
     },
@@ -29,7 +31,7 @@ DEFAULT_BB_DICT_ProMP = {
 
 DEFAULT_BB_DICT_DMP = {
     "name": 'EnvName',
-    "wrappers": [],
+    "wrappers": [FlattenObservation],
     "trajectory_generator_kwargs": {
         'trajectory_generator_type': 'dmp'
     },
@@ -49,7 +51,7 @@ DEFAULT_BB_DICT_DMP = {
 
 # DeepMind Control Suite (DMC)
 kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_bic_dmp['name'] = f"dmc:ball_in_cup-catch"
+kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0"
 kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
 # bandwidth_factor=2
 kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@@ -62,7 +64,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
 
 kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
+kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0"
 kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
 register(
     id=f'dmc_ball_in_cup-catch_promp-v0',
@@ -72,7 +74,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
 
 kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_reacher_easy_dmp['name'] = f"dmc:reacher-easy"
+kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0"
 kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
 # bandwidth_factor=2
 kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@@ -86,7 +88,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
 
 kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
+kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0"
 kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
 kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
 register(
@@ -97,7 +99,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
 
 kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_reacher_hard_dmp['name'] = f"dmc:reacher-hard"
+kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0"
 kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
 # bandwidth_factor = 2
 kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@@ -111,7 +113,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
 
 kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
+kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0"
 kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
 kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
 register(
@@ -126,7 +128,7 @@ _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
 for _task in _dmc_cartpole_tasks:
     _env_id = f'dmc_cartpole-{_task}_dmp-v0'
     kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-    kwargs_dict_cartpole_dmp['name'] = f"dmc:cartpole-{_task}"
+    kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0"
     kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
     # bandwidth_factor = 2
     kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@@ -143,7 +145,7 @@ for _task in _dmc_cartpole_tasks:
 
     _env_id = f'dmc_cartpole-{_task}_promp-v0'
     kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-    kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
+    kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0"
     kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
     kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
     kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
@@ -156,7 +158,7 @@ for _task in _dmc_cartpole_tasks:
     ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_cartpole2poles_dmp['name'] = f"dmc:cartpole-two_poles"
+kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0"
 kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
 # bandwidth_factor = 2
 kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@@ -173,7 +175,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
 kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
+kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0"
 kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
 kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
 kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
@@ -187,7 +189,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_cartpole3poles_dmp['name'] = f"dmc:cartpole-three_poles"
+kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0"
 kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
 # bandwidth_factor = 2
 kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@@ -204,7 +206,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
 
 kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
+kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0"
 kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
 kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
 kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
@@ -219,7 +221,7 @@ ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 # DeepMind Manipulation
 kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_mani_reach_site_features_dmp['name'] = f"dmc:manipulation-reach_site_features"
+kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0"
 kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
 kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
 # TODO: weight scale 50, but goal scale 0.1
@@ -233,7 +235,7 @@ register(
 ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
 
 kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
-kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
+kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0"
 kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
 kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
 kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'
@@ -3,15 +3,15 @@
 # Copyright (c) 2020 Denis Yarats
 import collections
 from collections.abc import MutableMapping
-from typing import Any, Dict, Tuple, Optional, Union, Callable
+from typing import Any, Dict, Tuple, Optional, Union, Callable, SupportsFloat
 
-import gym
+import gymnasium as gym
 import numpy as np
 from dm_control import composer
 from dm_control.rl import control
 from dm_env import specs
-from gym import spaces
-from gym.core import ObsType
+from gymnasium import spaces
+from gymnasium.core import ObsType, ActType
 
 
 def _spec_to_box(spec):
@@ -100,23 +100,23 @@ class DMCWrapper(gym.Env):
         self._action_space.seed(seed)
         self._observation_space.seed(seed)
 
-    def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
+    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
         assert self._action_space.contains(action)
         extra = {'internal_state': self._env.physics.get_state().copy()}
 
         time_step = self._env.step(action)
         reward = time_step.reward or 0.
-        done = time_step.last()
+        terminated = False
+        truncated = time_step.last() and time_step.discount > 0
         obs = self._get_obs(time_step)
        extra['discount'] = time_step.discount
 
-        return obs, reward, done, extra
+        return obs, reward, terminated, truncated, extra
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         time_step = self._env.reset()
         obs = self._get_obs(time_step)
-        return obs
+        return obs, {}
 
     def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False,
                segmentation=False, scene_option=None, render_flag_overrides=None):
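Note (not part of the diff): in dm_env, time_step.last() marks the end of an episode and, by convention, a discount of 0 on that final step signals a true terminal state, while a non-zero discount indicates a time or step limit. That is why the wrapper above reports truncated = time_step.last() and time_step.discount > 0. A minimal sketch of that mapping, under this convention (the helper name is illustrative, not from the repository):

    def split_episode_end(time_step):
        # True terminal state: episode ended and the discount dropped to zero.
        terminated = time_step.last() and time_step.discount == 0
        # Time/step limit: episode ended but the discount stayed positive.
        truncated = time_step.last() and time_step.discount > 0
        return terminated, truncated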
@@ -35,4 +35,4 @@ class MPWrapper(RawInterfaceWrapper):
 
     @property
     def dt(self) -> Union[float, int]:
-        return self.env.dt
+        return self.env.control_timestep()
@@ -31,4 +31,4 @@ class MPWrapper(RawInterfaceWrapper):
 
     @property
     def dt(self) -> Union[float, int]:
-        return self.env.dt
+        return self.env.control_timestep()
@@ -35,7 +35,7 @@ class MPWrapper(RawInterfaceWrapper):
 
     @property
     def dt(self) -> Union[float, int]:
-        return self.env.dt
+        return self.env.control_timestep()
 
 
 class TwoPolesMPWrapper(MPWrapper):
@@ -30,4 +30,4 @@ class MPWrapper(RawInterfaceWrapper):
 
     @property
     def dt(self) -> Union[float, int]:
-        return self.env.dt
+        return self.env.control_timestep()
@@ -1,7 +1,7 @@
 from copy import deepcopy
 
 import numpy as np
-from gym import register
+from gymnasium import register
 
 from . import classic_control, mujoco
 from .classic_control.hole_reacher.hole_reacher import HoleReacherEnv
@@ -1,10 +1,10 @@
-from typing import Union, Tuple, Optional
+from typing import Union, Tuple, Optional, Any, Dict
 
-import gym
+import gymnasium as gym
 import numpy as np
-from gym import spaces
-from gym.core import ObsType
-from gym.utils import seeding
+from gymnasium import spaces
+from gymnasium.core import ObsType
+from gymnasium.utils import seeding
 
 from fancy_gym.envs.classic_control.utils import intersect
 
@@ -55,7 +55,6 @@ class BaseReacherEnv(gym.Env):
         self.fig = None
 
         self._steps = 0
-        self.seed()
 
     @property
     def dt(self) -> Union[float, int]:
@@ -69,10 +68,15 @@ class BaseReacherEnv(gym.Env):
     def current_vel(self):
         return self._angle_velocity.copy()
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         # Sample only orientation of first link, i.e. the arm is always straight.
-        if self.random_start:
+        super(BaseReacherEnv, self).reset(seed=seed, options=options)
+        try:
+            random_start = options.get('random_start', self.random_start)
+        except AttributeError:
+            random_start = self.random_start
+        if random_start:
             first_joint = self.np_random.uniform(np.pi / 4, 3 * np.pi / 4)
             self._joint_angles = np.hstack([[first_joint], np.zeros(self.n_links - 1)])
             self._start_pos = self._joint_angles.copy()
@@ -84,7 +88,7 @@ class BaseReacherEnv(gym.Env):
         self._update_joints()
         self._steps = 0
 
-        return self._get_obs().copy()
+        return self._get_obs().copy(), {}
 
     def _update_joints(self):
         """
@@ -124,10 +128,6 @@ class BaseReacherEnv(gym.Env):
     def _terminate(self, info) -> bool:
         raise NotImplementedError
 
-    def seed(self, seed=None):
-        self.np_random, seed = seeding.np_random(seed)
-        return [seed]
-
     def close(self):
         super(BaseReacherEnv, self).close()
         del self.fig
@@ -1,5 +1,5 @@
 import numpy as np
-from gym import spaces
+from gymnasium import spaces
 
 from fancy_gym.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
 
@@ -32,6 +32,7 @@ class BaseReacherDirectEnv(BaseReacherEnv):
         reward, info = self._get_reward(action)
 
         self._steps += 1
-        done = self._terminate(info)
+        terminated = self._terminate(info)
+        truncated = False
 
-        return self._get_obs().copy(), reward, done, info
+        return self._get_obs().copy(), reward, terminated, truncated, info
@@ -1,5 +1,5 @@
 import numpy as np
-from gym import spaces
+from gymnasium import spaces
 
 from fancy_gym.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
 
@@ -31,6 +31,7 @@ class BaseReacherTorqueEnv(BaseReacherEnv):
         reward, info = self._get_reward(action)
 
         self._steps += 1
-        done = False
+        terminated = False
+        truncated = False
 
-        return self._get_obs().copy(), reward, done, info
+        return self._get_obs().copy(), reward, terminated, truncated, info
@@ -1,9 +1,10 @@
-from typing import Union, Optional, Tuple
+from typing import Union, Optional, Tuple, Any, Dict
 
-import gym
+import gymnasium as gym
 import matplotlib.pyplot as plt
 import numpy as np
-from gym.core import ObsType
+from gymnasium import spaces
+from gymnasium.core import ObsType
 from matplotlib import patches
 
 from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
@@ -40,7 +41,7 @@ class HoleReacherEnv(BaseReacherDirectEnv):
             [np.inf]  # env steps, because reward start after n steps TODO: Maybe
         ])
         # self.action_space = gym.spaces.Box(low=-action_bound, high=action_bound, shape=action_bound.shape)
-        self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
+        self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
 
         if rew_fct == "simple":
             from fancy_gym.envs.classic_control.hole_reacher.hr_simple_reward import HolereacherReward
@@ -54,13 +55,18 @@ class HoleReacherEnv(BaseReacherDirectEnv):
         else:
             raise ValueError("Unknown reward function {}".format(rew_fct))
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
+
+        # initialize seed here as the random goal needs to be generated before the super reset()
+        gym.Env.reset(self, seed=seed, options=options)
+
         self._generate_hole()
         self._set_patches()
         self.reward_function.reset()
 
-        return super().reset()
+        # do not provide seed to avoid setting it twice
+        return super(HoleReacherEnv, self).reset(options=options)
 
     def _get_reward(self, action: np.ndarray) -> (float, dict):
         return self.reward_function.get_reward(self)
@@ -223,16 +229,3 @@ class HoleReacherEnv(BaseReacherDirectEnv):
         self.fig.gca().add_patch(left_block)
         self.fig.gca().add_patch(right_block)
         self.fig.gca().add_patch(hole_floor)
-
-
-if __name__ == "__main__":
-
-    env = HoleReacherEnv(5)
-    env.reset()
-
-    for i in range(10000):
-        ac = env.action_space.sample()
-        obs, rew, done, info = env.step(ac)
-        env.render()
-        if done:
-            env.reset()
@@ -1,9 +1,9 @@
-from typing import Iterable, Union, Optional, Tuple
+from typing import Iterable, Union, Optional, Tuple, Any, Dict
 
 import matplotlib.pyplot as plt
 import numpy as np
-from gym import spaces
-from gym.core import ObsType
+from gymnasium import spaces
+from gymnasium.core import ObsType
 
 from fancy_gym.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
 
@@ -42,11 +42,10 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
     # def start_pos(self):
     #     return self._start_pos
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self._generate_goal()
 
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     def _get_reward(self, action: np.ndarray):
         diff = self.end_effector - self._goal
@@ -128,14 +127,3 @@ class SimpleReacherEnv(BaseReacherTorqueEnv):
         self.fig.canvas.draw()
         self.fig.canvas.flush_events()
-
-
-if __name__ == "__main__":
-    env = SimpleReacherEnv(5)
-    env.reset()
-    for i in range(200):
-        ac = env.action_space.sample()
-        obs, rew, done, info = env.step(ac)
-
-        env.render()
-        if done:
-            break
@@ -1,9 +1,10 @@
-from typing import Iterable, Union, Tuple, Optional
+from typing import Iterable, Union, Tuple, Optional, Any, Dict
 
-import gym
+import gymnasium as gym
 import matplotlib.pyplot as plt
 import numpy as np
-from gym.core import ObsType
+from gymnasium import spaces
+from gymnasium.core import ObsType
 
 from fancy_gym.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
 
@@ -34,16 +35,16 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
             [np.inf] * 2,  # x-y coordinates of target distance
             [np.inf]  # env steps, because reward start after n steps
         ])
-        self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
+        self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
 
     # @property
     # def start_pos(self):
     #     return self._start_pos
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self._generate_goal()
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     def _generate_goal(self):
         # TODO: Maybe improve this later, this can yield quite a lot of invalid settings
@@ -185,14 +186,3 @@ class ViaPointReacherEnv(BaseReacherDirectEnv):
             plt.pause(0.01)
-
-
-if __name__ == "__main__":
-
-    env = ViaPointReacherEnv(5)
-    env.reset()
-
-    for i in range(10000):
-        ac = env.action_space.sample()
-        obs, rew, done, info = env.step(ac)
-        env.render()
-        if done:
-            env.reset()
@@ -1,8 +1,8 @@
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union, Optional, Any, Dict
 
 import numpy as np
-from gym.core import ObsType
-from gym.envs.mujoco.ant_v4 import AntEnv
+from gymnasium.core import ObsType
+from gymnasium.envs.mujoco.ant_v4 import AntEnv
 
 MAX_EPISODE_STEPS_ANTJUMP = 200
 
@@ -61,9 +61,10 @@ class AntJumpEnv(AntEnv):
 
         costs = ctrl_cost + contact_cost
 
-        done = bool(height < 0.3)  # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
+        terminated = bool(
+            height < 0.3)  # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
 
-        if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or done:
+        if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or terminated:
            # -10 for scaling the value of the distance between the max_height and the goal height; only used when context is enabled
            # height_reward = -10 * (np.linalg.norm(self.max_height - self.goal))
            height_reward = -10 * np.linalg.norm(self.max_height - self.goal)
@@ -80,19 +81,20 @@ class AntJumpEnv(AntEnv):
             'max_height': self.max_height,
             'goal': self.goal
         }
+        truncated = False
 
-        return obs, reward, done, info
+        return obs, reward, terminated, truncated, info
 
     def _get_obs(self):
         return np.append(super()._get_obs(), self.goal)
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self.current_step = 0
         self.max_height = 0
         # goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE
         self.goal = self.np_random.uniform(1.0, 2.5, 1)
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     # reset_model had to be implemented in every env to make it deterministic
     def reset_model(self):
@@ -1,9 +1,10 @@
 import os
-from typing import Optional
+from typing import Optional, Any, Dict, Tuple
 
 import numpy as np
-from gym import utils
-from gym.envs.mujoco import MujocoEnv
+from gymnasium import utils
+from gymnasium.core import ObsType
+from gymnasium.envs.mujoco import MujocoEnv
 
 MAX_EPISODE_STEPS_BEERPONG = 300
 FIXED_RELEASE_STEP = 62  # empirically evaluated for frame_skip=2!
@@ -30,7 +31,7 @@ CUP_COLLISION_OBJ = ["cup_geom_table3", "cup_geom_table4", "cup_geom_table5", "c
 
 
 class BeerPongEnv(MujocoEnv, utils.EzPickle):
-    def __init__(self):
+    def __init__(self, **kwargs):
         self._steps = 0
         # Small Context -> Easier. Todo: Should we do different versions?
         # self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "beerpong_wo_cup.xml")
@@ -65,7 +66,13 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
         self.ball_in_cup = False
         self.dist_ground_cup = -1  # distance floor to cup if first floor contact
 
-        MujocoEnv.__init__(self, model_path=self.xml_path, frame_skip=1, mujoco_bindings="mujoco")
+        MujocoEnv.__init__(
+            self,
+            self.xml_path,
+            frame_skip=1,
+            observation_space=self.observation_space,
+            **kwargs
+        )
         utils.EzPickle.__init__(self)
 
     @property
@@ -76,7 +83,8 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
     def start_vel(self):
         return self._start_vel
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self.dists = []
         self.dists_final = []
         self.action_costs = []
@@ -86,7 +94,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
         self.ball_cup_contact = False
         self.ball_in_cup = False
         self.dist_ground_cup = -1  # distance floor to cup if first floor contact
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     def reset_model(self):
         init_pos_all = self.init_qpos.copy()
@@ -128,11 +136,11 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
         if not crash:
             reward, reward_infos = self._get_reward(applied_action)
             is_collided = reward_infos['is_collided']  # TODO: Remove if self collision does not make a difference
-            done = is_collided
+            terminated = is_collided
             self._steps += 1
         else:
             reward = -30
-            done = True
+            terminated = True
             reward_infos = {"success": False, "ball_pos": np.zeros(3), "ball_vel": np.zeros(3), "is_collided": False}
 
         infos = dict(
@@ -142,7 +150,10 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
             q_vel=self.data.qvel[0:7].ravel().copy(), sim_crash=crash,
         )
         infos.update(reward_infos)
-        return ob, reward, done, infos
+
+        truncated = False
+
+        return ob, reward, terminated, truncated, infos
 
     def _get_obs(self):
         theta = self.data.qpos.flat[:7].copy()
@@ -258,9 +269,9 @@ class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
             return super(BeerPongEnvStepBasedEpisodicReward, self).step(a)
         else:
             reward = 0
-            done = True
+            terminated, truncated = True, False
             while self._steps < MAX_EPISODE_STEPS_BEERPONG:
-                obs, sub_reward, done, infos = super(BeerPongEnvStepBasedEpisodicReward, self).step(
+                obs, sub_reward, terminated, truncated, infos = super(BeerPongEnvStepBasedEpisodicReward, self).step(
                     np.zeros(a.shape))
                 reward += sub_reward
-            return obs, reward, done, infos
+            return obs, reward, terminated, truncated, infos
@@ -2,8 +2,8 @@ import os
 
 import mujoco_py.builder
 import numpy as np
-from gym import utils
-from gym.envs.mujoco import MujocoEnv
+from gymnasium import utils
+from gymnasium.envs.mujoco import MujocoEnv
 
 from fancy_gym.envs.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward
 
@@ -90,11 +90,11 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
         if not crash:
             reward, reward_infos = self.reward_function.compute_reward(self, applied_action)
             is_collided = reward_infos['is_collided']
-            done = is_collided or self._steps == self.ep_length - 1
+            terminated = is_collided or self._steps == self.ep_length - 1
             self._steps += 1
         else:
             reward = -30
-            done = True
+            terminated = True
             reward_infos = {"success": False, "ball_pos": np.zeros(3), "ball_vel": np.zeros(3), "is_collided": False}
 
         infos = dict(
@@ -104,7 +104,7 @@ class BeerPongEnv(MujocoEnv, utils.EzPickle):
             q_vel=self.sim.data.qvel[0:7].ravel().copy(), sim_crash=crash,
         )
         infos.update(reward_infos)
-        return ob, reward, done, infos
+        return ob, reward, terminated, infos
 
     def _get_obs(self):
         theta = self.sim.data.qpos.flat[:7]
@@ -143,16 +143,16 @@ class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
             return super(BeerPongEnvStepBasedEpisodicReward, self).step(a)
         else:
             reward = 0
-            done = False
-            while not done:
-                sub_ob, sub_reward, done, sub_infos = super(BeerPongEnvStepBasedEpisodicReward, self).step(
-                    np.zeros(a.shape))
+            terminated, truncated = False, False
+            while not (terminated or truncated):
+                sub_ob, sub_reward, terminated, truncated, sub_infos = super(BeerPongEnvStepBasedEpisodicReward,
+                                                                             self).step(np.zeros(a.shape))
                 reward += sub_reward
                 infos = sub_infos
                 ob = sub_ob
             ob[-1] = self.release_step + 1  # Since we simulate until the end of the episode, PPO does not see the
             # internal steps and thus, the observation also needs to be set correctly
-            return ob, reward, done, infos
+            return ob, reward, terminated, truncated, infos
 
 
 # class BeerBongEnvStepBased(BeerBongEnv):
@@ -186,27 +186,3 @@ class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
 #             ob[-1] = self.release_step + 1  # Since we simulate until the end of the episode, PPO does not see the
 #             # internal steps and thus, the observation also needs to be set correctly
 #             return ob, reward, done, infos
-
-
-if __name__ == "__main__":
-    env = BeerPongEnv(frame_skip=2)
-    env.seed(0)
-    # env = BeerBongEnvStepBased(frame_skip=2)
-    # env = BeerBongEnvStepBasedEpisodicReward(frame_skip=2)
-    # env = BeerBongEnvFixedReleaseStep(frame_skip=2)
-    import time
-
-    env.reset()
-    env.render("human")
-    for i in range(600):
-        # ac = 10 * env.action_space.sample()
-        ac = 0.05 * np.ones(7)
-        obs, rew, d, info = env.step(ac)
-        env.render("human")
-
-        if d:
-            print('reward:', rew)
-            print('RESETTING')
-            env.reset()
-            time.sleep(1)
-    env.close()
@@ -1,9 +1,9 @@
 import os
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union, Optional, Any, Dict
 
 import numpy as np
-from gym.core import ObsType
-from gym.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv
+from gymnasium.core import ObsType
+from gymnasium.envs.mujoco.half_cheetah_v4 import HalfCheetahEnv
 
 MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100
 
@@ -44,7 +44,8 @@ class HalfCheetahJumpEnv(HalfCheetahEnv):
         ## Didnt use fell_over, because base env also has no done condition - Paul and Marc
         # fell_over = abs(self.sim.data.qpos[2]) > 2.5 # how to figure out if the cheetah fell over? -> 2.5 oke?
         # TODO: Should a fall over be checked here?
-        done = False
+        terminated = False
+        truncated = False
 
         ctrl_cost = self.control_cost(action)
         costs = ctrl_cost
@@ -63,17 +64,17 @@ class HalfCheetahJumpEnv(HalfCheetahEnv):
             'max_height': self.max_height
         }
 
-        return observation, reward, done, info
+        return observation, reward, terminated, truncated, info
 
     def _get_obs(self):
         return np.append(super()._get_obs(), self.goal)
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self.max_height = 0
         self.current_step = 0
         self.goal = self.np_random.uniform(1.1, 1.6, 1)  # 1.1 1.6
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     # overwrite reset_model to make it deterministic
     def reset_model(self):
@@ -1,7 +1,7 @@
 import os
 
 import numpy as np
-from gym.envs.mujoco.hopper_v4 import HopperEnv
+from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
 
 MAX_EPISODE_STEPS_HOPPERJUMP = 250
 
@@ -73,7 +73,7 @@ class HopperJumpEnv(HopperEnv):
         self.do_simulation(action, self.frame_skip)
 
         height_after = self.get_body_com("torso")[2]
-        #site_pos_after = self.data.get_site_xpos('foot_site')
+        # site_pos_after = self.data.get_site_xpos('foot_site')
         site_pos_after = self.data.site('foot_site').xpos
         self.max_height = max(height_after, self.max_height)
 
@@ -88,7 +88,8 @@ class HopperJumpEnv(HopperEnv):
 
         ctrl_cost = self.control_cost(action)
         costs = ctrl_cost
-        done = False
+        terminated = False
+        truncated = False
 
         goal_dist = np.linalg.norm(site_pos_after - self.goal)
         if self.contact_dist is None and self.contact_with_floor:
@@ -115,7 +116,7 @@ class HopperJumpEnv(HopperEnv):
             healthy=self.is_healthy,
             contact_dist=self.contact_dist or 0
         )
-        return observation, reward, done, info
+        return observation, reward, terminated, truncated, info
 
     def _get_obs(self):
         # goal_dist = self.data.get_site_xpos('foot_site') - self.goal
@@ -1,7 +1,9 @@
 import os
+from typing import Optional, Dict, Any, Tuple
 
 import numpy as np
-from gym.envs.mujoco.hopper_v4 import HopperEnv
+from gymnasium.core import ObsType
+from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
 
 MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
 
@@ -74,10 +76,10 @@ class HopperJumpOnBoxEnv(HopperEnv):
 
         costs = ctrl_cost
 
-        done = fell_over or self.hopper_on_box
+        terminated = fell_over or self.hopper_on_box
 
-        if self.current_step >= self.max_episode_steps or done:
-            done = False
+        if self.current_step >= self.max_episode_steps or terminated:
+            done = False  # TODO why are we doing this???
 
             max_height = self.max_height.copy()
             min_distance = self.min_distance.copy()
@@ -122,12 +124,13 @@ class HopperJumpOnBoxEnv(HopperEnv):
             'goal': self.box_x,
         }
 
-        return observation, reward, done, info
+        return observation, reward, terminated, info
 
     def _get_obs(self):
         return np.append(super()._get_obs(), self.box_x)
 
-    def reset(self):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
 
         self.max_height = 0
         self.min_distance = 5000
@@ -136,7 +139,7 @@ class HopperJumpOnBoxEnv(HopperEnv):
         if self.context:
             self.box_x = self.np_random.uniform(1, 3, 1)
             self.model.body("box").pos = [self.box_x[0], 0, 0]
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     # overwrite reset_model to make it deterministic
     def reset_model(self):
@@ -151,20 +154,5 @@ class HopperJumpOnBoxEnv(HopperEnv):
         observation = self._get_obs()
         return observation
-
-if __name__ == '__main__':
-    render_mode = "human"  # "human" or "partial" or "final"
-    env = HopperJumpOnBoxEnv()
-    obs = env.reset()
-
-    for i in range(2000):
-        # objective.load_result("/tmp/cma")
-        # test with random actions
-        ac = env.action_space.sample()
-        obs, rew, d, info = env.step(ac)
-        if i % 10 == 0:
-            env.render(mode=render_mode)
-        if d:
-            print('After ', i, ' steps, done: ', d)
-            env.reset()
-
-    env.close()
@@ -1,8 +1,9 @@
 import os
-from typing import Optional
+from typing import Optional, Any, Dict, Tuple
 
 import numpy as np
-from gym.envs.mujoco.hopper_v4 import HopperEnv
+from gymnasium.core import ObsType
+from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
 
 MAX_EPISODE_STEPS_HOPPERTHROW = 250
 
@@ -56,14 +57,14 @@ class HopperThrowEnv(HopperEnv):
 
         # done = self.done TODO We should use this, not sure why there is no other termination; ball_landed should be enough, because we only look at the throw itself? - Paul and Marc
         ball_landed = bool(self.get_body_com("ball")[2] <= 0.05)
-        done = ball_landed
+        terminated = ball_landed
 
         ctrl_cost = self.control_cost(action)
         costs = ctrl_cost
 
         rewards = 0
 
-        if self.current_step >= self.max_episode_steps or done:
+        if self.current_step >= self.max_episode_steps or terminated:
             distance_reward = -np.linalg.norm(ball_pos_after - self.goal) if self.context else \
                 self._forward_reward_weight * ball_pos_after
             healthy_reward = 0 if self.context else self.healthy_reward * self.current_step
@@ -78,16 +79,18 @@ class HopperThrowEnv(HopperEnv):
             '_steps': self.current_step,
             'goal': self.goal,
         }
+        truncated = False
 
-        return observation, reward, done, info
+        return observation, reward, terminated, truncated, info
 
     def _get_obs(self):
         return np.append(super()._get_obs(), self.goal)
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self.current_step = 0
         self.goal = self.goal = self.np_random.uniform(2.0, 6.0, 1)  # 0.5 8.0
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     # overwrite reset_model to make it deterministic
     def reset_model(self):
@@ -103,20 +106,3 @@ class HopperThrowEnv(HopperEnv):
         return observation
-
-
-if __name__ == '__main__':
-    render_mode = "human"  # "human" or "partial" or "final"
-    env = HopperThrowEnv()
-    obs = env.reset()
-
-    for i in range(2000):
-        # objective.load_result("/tmp/cma")
-        # test with random actions
-        ac = env.action_space.sample()
-        obs, rew, d, info = env.step(ac)
-        if i % 10 == 0:
-            env.render(mode=render_mode)
-        if d:
-            print('After ', i, ' steps, done: ', d)
-            env.reset()
-
-    env.close()
@@ -1,8 +1,9 @@
 import os
-from typing import Optional
+from typing import Optional, Any, Dict, Tuple
 
 import numpy as np
-from gym.envs.mujoco.hopper_v4 import HopperEnv
+from gymnasium.envs.mujoco.hopper_v4 import HopperEnv
+from gymnasium.core import ObsType
 
 MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
 
@@ -72,7 +73,7 @@ class HopperThrowInBasketEnv(HopperEnv):
                 self.ball_in_basket = True
 
         ball_landed = self.get_body_com("ball")[2] <= 0.05
-        done = bool(ball_landed or is_in_basket)
+        terminated = bool(ball_landed or is_in_basket)
 
         rewards = 0
 
@@ -80,7 +81,7 @@ class HopperThrowInBasketEnv(HopperEnv):
 
         costs = ctrl_cost
 
-        if self.current_step >= self.max_episode_steps or done:
+        if self.current_step >= self.max_episode_steps or terminated:
 
             if is_in_basket:
                 if not self.context:
@@ -101,13 +102,16 @@ class HopperThrowInBasketEnv(HopperEnv):
         info = {
             'ball_pos': ball_pos[0],
         }
+        truncated = False
 
-        return observation, reward, done, info
+        return observation, reward, terminated, truncated, info
 
     def _get_obs(self):
         return np.append(super()._get_obs(), self.basket_x)
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
+
         if self.max_episode_steps == 10:
             # We have to initialize this here, because the spec is only added after creating the env.
             self.max_episode_steps = self.spec.max_episode_steps
@@ -117,7 +121,7 @@ class HopperThrowInBasketEnv(HopperEnv):
         if self.context:
             self.basket_x = self.np_random.uniform(low=3, high=7, size=1)
             self.model.body("basket_ground").pos[:] = [self.basket_x[0], 0, 0]
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     # overwrite reset_model to make it deterministic
     def reset_model(self):
@@ -134,20 +138,4 @@ class HopperThrowInBasketEnv(HopperEnv):
         return observation
-
-
-if __name__ == '__main__':
-    render_mode = "human"  # "human" or "partial" or "final"
-    env = HopperThrowInBasketEnv()
-    obs = env.reset()
-
-    for i in range(2000):
-        # objective.load_result("/tmp/cma")
-        # test with random actions
-        ac = env.action_space.sample()
-        obs, rew, d, info = env.step(ac)
-        if i % 10 == 0:
-            env.render(mode=render_mode)
-        if d:
-            print('After ', i, ' steps, done: ', d)
-            env.reset()
-
-    env.close()
@@ -1,8 +1,9 @@
 import os
 
 import numpy as np
-from gym import utils
-from gym.envs.mujoco import MujocoEnv
+from gymnasium import utils
+from gymnasium.envs.mujoco import MujocoEnv
+from gymnasium.spaces import Box
 
 MAX_EPISODE_STEPS_REACHER = 200
 
@@ -12,7 +13,17 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
     More general version of the gym mujoco Reacher environment
     """
 
-    def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1):
+    metadata = {
+        "render_modes": [
+            "human",
+            "rgb_array",
+            "depth_array",
+        ],
+        "render_fps": 50,
+    }
+
+    def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1.,
+                 **kwargs):
         utils.EzPickle.__init__(**locals())
 
         self._steps = 0
@@ -25,10 +36,16 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
 
         file_name = f'reacher_{n_links}links.xml'
 
+        # sin, cos, velocity * n_Links + goal position (2) and goal distance (3)
+        shape = (self.n_links * 3 + 5,)
+        observation_space = Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float64)
+
         MujocoEnv.__init__(self,
                            model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
                            frame_skip=2,
-                           mujoco_bindings="mujoco")
+                           observation_space=observation_space,
+                           **kwargs
+                           )
 
     def step(self, action):
         self._steps += 1
@@ -45,10 +62,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
 
         reward = reward_dist + reward_ctrl + angular_vel
         self.do_simulation(action, self.frame_skip)
-        ob = self._get_obs()
-        done = False
+        if self.render_mode == "human":
+            self.render()
 
-        infos = dict(
+        ob = self._get_obs()
+        terminated = False
+        truncated = False
+
+        info = dict(
             reward_dist=reward_dist,
             reward_ctrl=reward_ctrl,
             velocity=angular_vel,
@@ -56,7 +77,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
             goal=self.goal if hasattr(self, "goal") else None
         )
 
-        return ob, reward, done, infos
+        return ob, reward, terminated, truncated, info
 
     def distance_reward(self):
         vec = self.get_body_com("fingertip") - self.get_body_com("target")
@@ -66,6 +87,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
         return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0
 
     def viewer_setup(self):
+        assert self.viewer is not None
         self.viewer.cam.trackbodyid = 0
 
     def reset_model(self):
@@ -1,8 +1,9 @@
 import os
-from typing import Optional
+from typing import Optional, Any, Dict, Tuple
 
 import numpy as np
-from gym.envs.mujoco.walker2d_v4 import Walker2dEnv
+from gymnasium.envs.mujoco.walker2d_v4 import Walker2dEnv
+from gymnasium.core import ObsType
 
 MAX_EPISODE_STEPS_WALKERJUMP = 300
 
@@ -54,13 +55,13 @@ class Walker2dJumpEnv(Walker2dEnv):
 
         self.max_height = max(height, self.max_height)
 
-        done = bool(height < 0.2)
+        terminated = bool(height < 0.2)
 
         ctrl_cost = self.control_cost(action)
         costs = ctrl_cost
         rewards = 0
-        if self.current_step >= self.max_episode_steps or done:
-            done = True
+        if self.current_step >= self.max_episode_steps or terminated:
+            terminated = True
             height_goal_distance = -10 * (np.linalg.norm(self.max_height - self.goal))
             healthy_reward = self.healthy_reward * self.current_step
 
@@ -73,17 +74,19 @@ class Walker2dJumpEnv(Walker2dEnv):
             'max_height': self.max_height,
             'goal': self.goal,
         }
+        truncated = False
 
-        return observation, reward, done, info
+        return observation, reward, terminated, truncated, info
 
     def _get_obs(self):
         return np.append(super()._get_obs(), self.goal)
 
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
+    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
+            -> Tuple[ObsType, Dict[str, Any]]:
         self.current_step = 0
         self.max_height = 0
         self.goal = self.np_random.uniform(1.5, 2.5, 1)  # 1.5 3.0
-        return super().reset()
+        return super().reset(seed=seed, options=options)
 
     # overwrite reset_model to make it deterministic
     def reset_model(self):
@@ -98,20 +101,3 @@ class Walker2dJumpEnv(Walker2dEnv):
         observation = self._get_obs()
         return observation
-
-
-if __name__ == '__main__':
-    render_mode = "human"  # "human" or "partial" or "final"
-    env = Walker2dJumpEnv()
-    obs = env.reset()
-
-    for i in range(6000):
-        # test with random actions
-        ac = env.action_space.sample()
-        obs, rew, d, info = env.step(ac)
-        if i % 10 == 0:
-            env.render(mode=render_mode)
-        if d:
-            print('After ', i, ' steps, done: ', d)
-            env.reset()
-
-    env.close()
@@ -26,10 +26,10 @@ def example_dmc(env_id="dmc:fish-swim", seed=1, iterations=1000, render=True):
         ac = env.action_space.sample()
         if render:
             env.render(mode="human")
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward
 
-        if done:
+        if terminated or truncated:
             print(env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -102,10 +102,10 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward
 
-        if done:
+        if terminated or truncated:
             print(base_env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -1,6 +1,6 @@
 from collections import defaultdict
 
-import gym
+import gymnasium as gym
 import numpy as np
 
 import fancy_gym
@@ -29,13 +29,13 @@ def example_general(env_id="Pendulum-v1", seed=1, iterations=1000, render=True):
 
     # number of environment steps
     for i in range(iterations):
-        obs, reward, done, info = env.step(env.action_space.sample())
+        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
         rewards += reward
 
         if render:
            env.render()
 
-        if done:
+        if terminated or truncated:
             print(rewards)
             rewards = 0
             obs = env.reset()
@@ -69,12 +69,15 @@ def example_async(env_id="HoleReacher-v0", n_cpu=4, seed=int('533D', 16), n_samples=800):
     # this would generate more samples than requested if n_samples % num_envs != 0
     repeat = int(np.ceil(n_samples / env.num_envs))
     for i in range(repeat):
-        obs, reward, done, info = env.step(env.action_space.sample())
+        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
         buffer['obs'].append(obs)
         buffer['reward'].append(reward)
-        buffer['done'].append(done)
+        buffer['terminated'].append(terminated)
+        buffer['truncated'].append(truncated)
         buffer['info'].append(info)
        rewards += reward
 
+        done = terminated or truncated
         if np.any(done):
             print(f"Reward at iteration {i}: {rewards[done]}")
             rewards[done] = 0
@@ -29,9 +29,9 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
            # THIS NEEDS TO BE SET TO FALSE FOR NOW, BECAUSE THE INTERFACE FOR RENDERING IS DIFFERENT TO BASIC GYM
            # TODO: Remove this, when Metaworld fixes its interface.
            env.render(False)
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward
-        if done:
+        if terminated or truncated:
             print(env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -103,10 +103,10 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward
 
-        if done:
+        if terminated or truncated:
             print(base_env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -131,4 +131,3 @@ if __name__ == '__main__':
     #
     # # Custom MetaWorld task
     example_custom_dmc_and_mp(seed=10, iterations=1, render=render)
-
@ -41,11 +41,11 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
# This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
# full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
# to the return of a trajectory. Default is the sum over the step-wise rewards.
obs, reward, done, info = env.step(ac)
obs, reward, terminated, truncated, info = env.step(ac)
# Aggregated returns
returns += reward

if done:
if terminated or truncated:
print(reward)
obs = env.reset()

@ -79,10 +79,10 @@ def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render
# number of samples/full trajectories (multiple environment steps)
for i in range(iterations):
ac = env.action_space.sample()
obs, reward, done, info = env.step(ac)
obs, reward, terminated, truncated, info = env.step(ac)
returns += reward

if done:
if terminated or truncated:
print(i, reward)
obs = env.reset()

@ -145,10 +145,10 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
# number of samples/full trajectories (multiple environment steps)
for i in range(iterations):
ac = env.action_space.sample()
obs, reward, done, info = env.step(ac)
obs, reward, terminated, truncated, info = env.step(ac)
rewards += reward

if done:
if terminated or truncated:
print(rewards)
rewards = 0
obs = env.reset()

@ -24,10 +24,10 @@ def example_mp(env_name, seed=1, render=True):
else:
env.render()
ac = env.action_space.sample()
obs, reward, done, info = env.step(ac)
obs, reward, terminated, truncated, info = env.step(ac)
returns += reward

if done:
if terminated or truncated:
print(returns)
obs = env.reset()
@ -34,7 +34,7 @@ fig.show()
for t, pos_vel in enumerate(zip(pos, vel)):
actions = env.tracking_controller.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos)
actions = np.clip(actions, env.env.action_space.low, env.env.action_space.high)
_, _, _, _ = env.env.step(actions)
env.env.step(actions)
if t % 15 == 0:
img.set_data(env.env.render(mode="rgb_array"))
fig.canvas.draw()
@ -1,6 +1,6 @@
from copy import deepcopy

from gym import register
from gymnasium import register

from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
object_change_mp_wrapper

@ -1,6 +1,6 @@
from copy import deepcopy

from gym import register
from gymnasium import register

from . import mujoco
from .deprecated_needs_gym_robotics import robotics
11 fancy_gym/utils/env_compatibility.py Normal file
@ -0,0 +1,11 @@
import gymnasium as gym


class EnvCompatibility(gym.wrappers.EnvCompatibility):
def __getattr__(self, item):
"""Propagate only non-existent properties to wrapped env."""
if item.startswith('_'):
raise AttributeError("attempted to get missing private attribute '{}'".format(item))
if item in self.__dict__:
return getattr(self, item)
return getattr(self.env, item)
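For context, not part of the diff: gymnasium's EnvCompatibility wrapper converts an environment that still implements the old gym API into the new 5-tuple API, and the subclass above only adds attribute forwarding to the wrapped env. A hedged usage sketch, where OldStyleEnv is a hypothetical legacy environment:

from fancy_gym.utils.env_compatibility import EnvCompatibility

old_env = OldStyleEnv()  # hypothetical legacy env: step() returns (obs, reward, done, info)
env = EnvCompatibility(old_env, render_mode=None)

obs, info = env.reset(seed=0)  # new-style reset
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())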
@ -1,17 +1,19 @@
import logging
import re
import uuid
from collections.abc import MutableMapping
from copy import deepcopy
from math import ceil
from typing import Iterable, Type, Union
from typing import Iterable, Type, Union, Optional

import gym
import gymnasium as gym
import numpy as np
from gym.envs.registration import register, registry
from gymnasium.envs.registration import register, registry

from fancy_gym.utils.env_compatibility import EnvCompatibility

try:
from dm_control import suite, manipulation
from shimmy.dm_control_compatibility import EnvType
except ImportError:
pass

@ -82,13 +84,20 @@ def make(env_id: str, seed: int, **kwargs):
if framework == 'metaworld':
# MetaWorld environment
env = make_metaworld(env_id, seed, **kwargs)
elif framework == 'dmc':
# DeepMind Control environment
env = make_dmc(env_id, seed, **kwargs)
# elif framework == 'dmc':
# Deprecated: With shimmy gym now has native support for deepmind envs
# # DeepMind Control environment
# env = make_dmc(env_id, seed, **kwargs)
else:
env = make_gym(env_id, seed, **kwargs)

env.seed(seed)
# try:
env.reset(seed=seed)
# except TypeError:
# # Support for older gym envs that do not have seeding
# # env.seed(seed)
# np_random, _ = seeding.np_random(seed)
# env.np_random = np_random
env.action_space.seed(seed)
env.observation_space.seed(seed)
@ -158,7 +167,7 @@ def make_bb(
traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(env.action_space.shape).item())

if black_box_kwargs.get('duration') is None:
black_box_kwargs['duration'] = env.spec.max_episode_steps * env.dt
black_box_kwargs['duration'] = get_env_duration(env)
if phase_kwargs.get('tau') is None:
phase_kwargs['tau'] = black_box_kwargs['duration']

@ -186,6 +195,24 @@ def make_bb(
return bb_env


def get_env_duration(env: gym.Env):
try:
duration = env.spec.max_episode_steps * env.dt
except (AttributeError, TypeError) as e:
# TODO Remove if this information is in the compatibility class
logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. '
f'Assuming you are using dm_control. Please make sure you have ran '
f'"pip install shimmy[dm_control]" for that.')
if env.env_type is EnvType.COMPOSER:
max_episode_steps = ceil(env.unwrapped._time_limit / env.dt)
elif env.env_type is EnvType.RL_CONTROL:
max_episode_steps = int(env.unwrapped._step_limit)
else:
raise e
duration = max_episode_steps * env.control_timestep()
return duration
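The duration used for the trajectory generator is simply max_episode_steps * dt; for dm_control environments it falls back to the step limit and control timestep. A worked example with assumed, illustrative numbers (not taken from the diff):

# Illustrative only: a dm_control RL_CONTROL task with an assumed step limit and timestep
max_episode_steps = 1000      # e.g. env.unwrapped._step_limit
control_timestep = 0.01       # e.g. env.control_timestep(), in seconds
duration = max_episode_steps * control_timestep   # -> 10.0 seconds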
def make_bb_env_helper(**kwargs):
"""
Helper function for registering a black box gym environment.
@ -235,55 +262,56 @@ def make_bb_env_helper(**kwargs):
basis_kwargs=basis_kwargs, **kwargs, seed=seed)


def make_dmc(
env_id: str,
seed: int = None,
visualize_reward: bool = True,
time_limit: Union[None, float] = None,
**kwargs
):
if not re.match(r"\w+-\w+", env_id):
raise ValueError("env_id does not have the following structure: 'domain_name-task_name'")
domain_name, task_name = env_id.split("-")

if task_name.endswith("_vision"):
# TODO
raise ValueError("The vision interface for manipulation tasks is currently not supported.")

if (domain_name, task_name) not in suite.ALL_TASKS and task_name not in manipulation.ALL:
raise ValueError(f'Specified domain "{domain_name}" and task "{task_name}" combination does not exist.')

# env_id = f'dmc_{domain_name}_{task_name}_{seed}-v1'
gym_id = uuid.uuid4().hex + '-v1'

task_kwargs = {'random': seed}
if time_limit is not None:
task_kwargs['time_limit'] = time_limit

# create task
# Accessing private attribute because DMC does not expose time_limit or step_limit.
# Only the current time_step/time as well as the control_timestep can be accessed.
if domain_name == "manipulation":
env = manipulation.load(environment_name=task_name, seed=seed)
max_episode_steps = ceil(env._time_limit / env.control_timestep())
else:
env = suite.load(domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs,
visualize_reward=visualize_reward, environment_kwargs=kwargs)
max_episode_steps = int(env._step_limit)

register(
id=gym_id,
entry_point='fancy_gym.dmc.dmc_wrapper:DMCWrapper',
kwargs={'env': lambda: env},
max_episode_steps=max_episode_steps,
)

env = gym.make(gym_id)
env.seed(seed)
return env
# Deprecated: With shimmy gym now has native support for deepmind envs
# def make_dmc(
# env_id: str,
# seed: int = None,
# visualize_reward: bool = True,
# time_limit: Union[None, float] = None,
# **kwargs
# ):
# if not re.match(r"\w+-\w+", env_id):
# raise ValueError("env_id does not have the following structure: 'domain_name-task_name'")
# domain_name, task_name = env_id.split("-")
#
# if task_name.endswith("_vision"):
# # TODO
# raise ValueError("The vision interface for manipulation tasks is currently not supported.")
#
# if (domain_name, task_name) not in suite.ALL_TASKS and task_name not in manipulation.ALL:
# raise ValueError(f'Specified domain "{domain_name}" and task "{task_name}" combination does not exist.')
#
# # env_id = f'dmc_{domain_name}_{task_name}_{seed}-v1'
# gym_id = uuid.uuid4().hex + '-v1'
#
# task_kwargs = {'random': seed}
# if time_limit is not None:
# task_kwargs['time_limit'] = time_limit
#
# # create task
# # Accessing private attribute because DMC does not expose time_limit or step_limit.
# # Only the current time_step/time as well as the control_timestep can be accessed.
# if domain_name == "manipulation":
# env = manipulation.load(environment_name=task_name, seed=seed)
# max_episode_steps = ceil(env._time_limit / env.control_timestep())
# else:
# env = suite.load(domain_name=domain_name, task_name=task_name, task_kwargs=task_kwargs,
# visualize_reward=visualize_reward, environment_kwargs=kwargs)
# max_episode_steps = int(env._step_limit)
#
# register(
# id=gym_id,
# entry_point='fancy_gym.dmc.dmc_wrapper:DMCWrapper',
# kwargs={'env': lambda: env},
# max_episode_steps=max_episode_steps,
# )
#
# env = gym.make(gym_id)
# env.seed(seed)
# return env
def make_metaworld(env_id: str, seed: int, **kwargs):
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
if env_id not in metaworld.ML1.ENV_NAMES:
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')

@ -294,12 +322,17 @@ def make_metaworld(env_id: str, seed: int, **kwargs):
# New argument to use global seeding
_env.seeded_rand_vec = True

max_episode_steps = _env.max_path_length

# TODO remove this as soon as there is support for the new API
_env = EnvCompatibility(_env, render_mode)

gym_id = uuid.uuid4().hex + '-v1'

register(
id=gym_id,
entry_point=lambda: _env,
max_episode_steps=_env.max_path_length,
max_episode_steps=max_episode_steps,
)

# TODO enable checker when the incorrect dtype of obs and observation space are fixed by metaworld
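Not shown in the diff: with the compatibility wrapper in place, a Metaworld task is still created through fancy_gym.make using the metaworld: prefix. A hedged sketch; the task name is only an assumed example:

import fancy_gym

env = fancy_gym.make('metaworld:button-press-v2', seed=1)  # task name is an assumed example
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())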
@ -1,45 +1,11 @@
"""
Adapted from: https://github.com/openai/gym/blob/907b1b20dd9ac0cba5803225059b9c6673702467/gym/wrappers/time_aware_observation.py
License: MIT
Copyright (c) 2016 OpenAI (https://openai.com)

Wrapper for adding time aware observations to environment observation.
"""
import gym
import gymnasium as gym
import numpy as np
from gym.spaces import Box


class TimeAwareObservation(gym.ObservationWrapper):
"""Augment the observation with the current time step in the episode.

The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
In particular, pixel observations are not supported. This wrapper will append the current timestep
within the current episode to the observation.

Example:
>>> import gym
>>> env = gym.make('CartPole-v1')
>>> env = TimeAwareObservation(env)
>>> env.reset()
array([ 0.03810719, 0.03522411, 0.02231044, -0.01088205, 0. ])
>>> env.step(env.action_space.sample())[0]
array([ 0.03881167, -0.16021058, 0.0220928 , 0.28875574, 1. ])
"""
class TimeAwareObservation(gym.wrappers.TimeAwareObservation):

def __init__(self, env: gym.Env):
"""Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box`
observation space.

Args:
env: The environment to apply the wrapper
"""
super().__init__(env)
assert isinstance(env.observation_space, Box)
low = np.append(self.observation_space.low, 0.0)
high = np.append(self.observation_space.high, 1.0)
self.observation_space = Box(low, high, dtype=self.observation_space.dtype)
self.t = 0
self._max_episode_steps = env.spec.max_episode_steps

def observation(self, observation):
@ -52,27 +18,3 @@ class TimeAwareObservation(gym.ObservationWrapper):
The observation with the time step appended to
"""
return np.append(observation, self.t / self._max_episode_steps)

def step(self, action):
"""Steps through the environment, incrementing the time step.

Args:
action: The action to take

Returns:
The environment's step using the action.
"""
self.t += 1
return super().step(action)

def reset(self, **kwargs):
"""Reset the environment setting the time to zero.

Args:
**kwargs: Kwargs to apply to env.reset()

Returns:
The reset environment
"""
self.t = 0
return super().reset(**kwargs)
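Not part of the diff: the trimmed-down class above now delegates the bookkeeping to gymnasium's built-in gym.wrappers.TimeAwareObservation and only keeps the normalized-time observation. A hedged usage sketch; the import path is an assumption, and a flat Box observation space is required:

import gymnasium as gym
from fancy_gym.utils.time_aware_observation import TimeAwareObservation  # assumed module path

env = TimeAwareObservation(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
# obs carries one extra trailing entry encoding the elapsed fraction of the episode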
1 setup.py
@ -12,6 +12,7 @@ extras = {
],
'box2d': ['gymnasium[box2d]>=0.26.0'],
'testing': ['pytest'],
"mujoco": ["gymnasium[mujoco]"],
}

# All dependencies
@ -1,48 +1,54 @@
from itertools import chain
from typing import Callable

import gymnasium as gym
import pytest
from dm_control import suite, manipulation

import fancy_gym
from test.utils import run_env, run_env_determinism

SUITE_IDS = [f'dmc:{env}-{task}' for env, task in suite.ALL_TASKS if env != "lqr"]
MANIPULATION_IDS = [f'dmc:manipulation-{task}' for task in manipulation.ALL if task.endswith('_features')]
DMC_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
# SUITE_IDS = [f'dmc:{env}-{task}' for env, task in suite.ALL_TASKS if env != "lqr"]
# MANIPULATION_IDS = [f'dmc:manipulation-{task}' for task in manipulation.ALL if task.endswith('_features')]
DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if
spec.id.startswith('dm_control/')
and 'compatibility-env-v0' not in spec.id
and 'lqr-lqr' not in spec.id]
DM_control_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
SEED = 1


@pytest.mark.parametrize('env_id', SUITE_IDS)
def test_step_suite_functionality(env_id: str):
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
def test_step_dm_control_functionality(env_id: str):
"""Tests that suite step environments run without errors using random actions."""
run_env(env_id)
run_env(env_id, 5000, wrappers=[gym.wrappers.FlattenObservation])


@pytest.mark.parametrize('env_id', SUITE_IDS)
def test_step_suite_determinism(env_id: str):
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
def test_step_dm_control_determinism(env_id: str):
"""Tests that for step environments identical seeds produce identical trajectories."""
run_env_determinism(env_id, SEED)
run_env_determinism(env_id, SEED, 5000, wrappers=[gym.wrappers.FlattenObservation])


@pytest.mark.parametrize('env_id', MANIPULATION_IDS)
def test_step_manipulation_functionality(env_id: str):
"""Tests that manipulation step environments run without errors using random actions."""
run_env(env_id)
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
# def test_step_manipulation_functionality(env_id: str):
# """Tests that manipulation step environments run without errors using random actions."""
# run_env(env_id)
#
#
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
# def test_step_manipulation_determinism(env_id: str):
# """Tests that for step environments identical seeds produce identical trajectories."""
# run_env_determinism(env_id, SEED)


@pytest.mark.parametrize('env_id', MANIPULATION_IDS)
def test_step_manipulation_determinism(env_id: str):
"""Tests that for step environments identical seeds produce identical trajectories."""
run_env_determinism(env_id, SEED)


@pytest.mark.parametrize('env_id', DMC_MP_IDS)
@pytest.mark.parametrize('env_id', DM_control_MP_IDS)
def test_bb_dmc_functionality(env_id: str):
"""Tests that black box environments run without errors using random actions."""
run_env(env_id)


@pytest.mark.parametrize('env_id', DMC_MP_IDS)
@pytest.mark.parametrize('env_id', DM_control_MP_IDS)
def test_bb_dmc_determinism(env_id: str):
"""Tests that for black box environment identical seeds produce identical trajectories."""
run_env_determinism(env_id, SEED)
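Background for the new DM_CONTROL_IDS collection, not stated in the diff: once shimmy[dm_control] is installed, gymnasium exposes dm_control tasks in its own registry under the dm_control/ namespace, so they can be enumerated and created directly. A hedged sketch; the concrete id below is an assumed example:

import gymnasium as gym  # assumes shimmy[dm_control] is installed, which registers the dm_control/ ids

dm_ids = [spec.id for spec in gym.envs.registry.values() if spec.id.startswith('dm_control/')]
env = gym.make('dm_control/fish-swim-v0')  # assumed example id
obs, info = env.reset(seed=1)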
@ -1,14 +1,16 @@
import itertools
from itertools import chain
from typing import Callable

import fancy_gym
import gym
import gymnasium as gym
import pytest

from test.utils import run_env, run_env_determinism

CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
not isinstance(spec.entry_point, Callable) and
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
CUSTOM_MP_IDS = list(chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
SEED = 1
|
||||
import re
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
|
||||
import fancy_gym
|
||||
from test.utils import run_env, run_env_determinism
|
||||
|
||||
GYM_IDS = [spec.id for spec in gym.envs.registry.all() if
|
||||
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
||||
GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
|
||||
not isinstance(spec.entry_point, Callable) and
|
||||
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point
|
||||
and 'jax' not in spec.id.lower()
|
||||
and not re.match(r'GymV2.Environment', spec.id)
|
||||
]
|
||||
GYM_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
@ -8,7 +8,11 @@ from test.utils import run_env, run_env_determinism
|
||||
|
||||
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
||||
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
||||
<<<<<<< HEAD
|
||||
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
=======
|
||||
METAWORLD_MP_IDS = list(chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||
>>>>>>> 47-update-to-new-gym-api
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
@ -1,9 +1,12 @@
import gym
from typing import List, Type

import gymnasium as gym
import numpy as np
from fancy_gym import make


def run_env(env_id, iterations=None, seed=0, render=False):
def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[Type[gym.Wrapper]] = [],
render: bool = False):
"""
Example for running a DMC based env in the step based setting.
The env_id has to be specified as `dmc:domain_name-task_name` or
@ -13,17 +16,21 @@ def run_env(env_id, iterations=None, seed=0, render=False):
env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name`
iterations: Number of rollout steps to run
seed: random seeding
wrappers: List of Wrappers to apply to the environment
render: Render the episode

Returns: observations, rewards, dones, actions
Returns: observations, rewards, terminations, truncations, actions

"""
env: gym.Env = make(env_id, seed=seed)
for w in wrappers:
env = w(env)
rewards = []
observations = []
actions = []
dones = []
obs = env.reset()
terminations = []
truncations = []
obs, _ = env.reset()
verify_observations(obs, env.observation_space, "reset()")

iterations = iterations or (env.spec.max_episode_steps or 1)
@ -35,38 +42,49 @@ def run_env(env_id, iterations=None, seed=0, render=False):
ac = env.action_space.sample()
actions.append(ac)
# ac = np.random.uniform(env.action_space.low, env.action_space.high, env.action_space.shape)
obs, reward, done, info = env.step(ac)
obs, reward, terminated, truncated, info = env.step(ac)

verify_observations(obs, env.observation_space, "step()")
verify_reward(reward)
verify_done(done)
verify_done(terminated)
verify_done(truncated)

rewards.append(reward)
dones.append(done)
terminations.append(terminated)
truncations.append(truncated)

if render:
env.render("human")

if done:
if terminated or truncated:
break
if not hasattr(env, "replanning_schedule"):
assert done, "Done flag is not True after end of episode."
assert terminated or truncated, f"Termination or truncation flag is not True after {i + 1} iterations."

observations.append(obs)
env.close()
del env
return np.array(observations), np.array(rewards), np.array(dones), np.array(actions)
return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)


def run_env_determinism(env_id: str, seed: int):
traj1 = run_env(env_id, seed=seed)
traj2 = run_env(env_id, seed=seed)
def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers: List[Type[gym.Wrapper]] = []):
traj1 = run_env(env_id, iterations=iterations,
seed=seed, wrappers=wrappers)
traj2 = run_env(env_id, iterations=iterations,
seed=seed, wrappers=wrappers)
# Iterate over two trajectories, which should have the same state and action sequence
for i, time_step in enumerate(zip(*traj1, *traj2)):
obs1, rwd1, done1, ac1, obs2, rwd2, done2, ac2 = time_step
assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
assert np.array_equal(done1, done2), f"Dones [{i}] {done1} and {done2} do not match."
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
assert np.allclose(
obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
assert np.array_equal(
ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
assert np.array_equal(
rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
assert np.array_equal(
term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
assert np.array_equal(
trunc1, trunc2), f"Truncateds [{i}] {trunc1} and {trunc2} do not match."


def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
@ -75,8 +93,10 @@ def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):


def verify_reward(reward):
assert isinstance(reward, (float, int)), f"Returned type {type(reward)} as reward, expected float or int."
assert isinstance(
reward, (float, int)), f"Returned type {type(reward)} as reward, expected float or int."


def verify_done(done):
assert isinstance(done, bool), f"Returned {done} as done flag, expected bool."
assert isinstance(
done, bool), f"Returned {done} as done flag, expected bool."