From 9ebc021ae0529b7481ff9e9bc1cbc368a587ba18 Mon Sep 17 00:00:00 2001 From: Fabian Date: Thu, 12 Jan 2023 17:23:56 +0100 Subject: [PATCH] updated dm_control envs to use shimmy --- fancy_gym/dmc/__init__.py | 36 ++++++++++--------- fancy_gym/dmc/dmc_wrapper.py | 22 ++++++------ .../dmc/manipulation/reach_site/mp_wrapper.py | 2 +- fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py | 2 +- fancy_gym/dmc/suite/cartpole/mp_wrapper.py | 2 +- fancy_gym/dmc/suite/reacher/mp_wrapper.py | 2 +- 6 files changed, 34 insertions(+), 32 deletions(-) diff --git a/fancy_gym/dmc/__init__.py b/fancy_gym/dmc/__init__.py index 22ae47f..29bd354 100644 --- a/fancy_gym/dmc/__init__.py +++ b/fancy_gym/dmc/__init__.py @@ -1,14 +1,16 @@ from copy import deepcopy +from gymnasium.wrappers import FlattenObservation + from . import manipulation, suite ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} -from gym.envs.registration import register +from gymnasium.envs.registration import register DEFAULT_BB_DICT_ProMP = { "name": 'EnvName', - "wrappers": [], + "wrappers": [FlattenObservation], "trajectory_generator_kwargs": { 'trajectory_generator_type': 'promp' }, @@ -29,7 +31,7 @@ DEFAULT_BB_DICT_ProMP = { DEFAULT_BB_DICT_DMP = { "name": 'EnvName', - "wrappers": [], + "wrappers": [FlattenObservation], "trajectory_generator_kwargs": { 'trajectory_generator_type': 'dmp' }, @@ -49,7 +51,7 @@ DEFAULT_BB_DICT_DMP = { # DeepMind Control Suite (DMC) kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_bic_dmp['name'] = f"dmc:ball_in_cup-catch" +kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0" kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper) # bandwidth_factor=2 kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2 @@ -62,7 +64,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0") kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch" +kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0" kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper) register( id=f'dmc_ball_in_cup-catch_promp-v0', @@ -72,7 +74,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0") kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_easy_dmp['name'] = f"dmc:reacher-easy" +kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0" kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper) # bandwidth_factor=2 kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2 @@ -86,7 +88,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0") kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy" +kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0" kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 register( @@ -97,7 +99,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0") kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_hard_dmp['name'] = f"dmc:reacher-hard" +kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0" kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper) # bandwidth_factor = 2 kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2 @@ -111,7 +113,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0") kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard" +kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0" kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 register( @@ -126,7 +128,7 @@ _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"] for _task in _dmc_cartpole_tasks: _env_id = f'dmc_cartpole-{_task}_dmp-v0' kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP) - kwargs_dict_cartpole_dmp['name'] = f"dmc:cartpole-{_task}" + kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0" kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper) # bandwidth_factor = 2 kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2 @@ -143,7 +145,7 @@ for _task in _dmc_cartpole_tasks: _env_id = f'dmc_cartpole-{_task}_promp-v0' kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP) - kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}" + kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0" kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper) kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10 @@ -156,7 +158,7 @@ for _task in _dmc_cartpole_tasks: ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole2poles_dmp['name'] = f"dmc:cartpole-two_poles" +kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0" kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) # bandwidth_factor = 2 kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2 @@ -173,7 +175,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles" +kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0" kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10 @@ -187,7 +189,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole3poles_dmp['name'] = f"dmc:cartpole-three_poles" +kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0" kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) # bandwidth_factor = 2 kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2 @@ -204,7 +206,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles" +kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0" kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10 @@ -219,7 +221,7 @@ ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # DeepMind Manipulation kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_mani_reach_site_features_dmp['name'] = f"dmc:manipulation-reach_site_features" +kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0" kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper) kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2 # TODO: weight scale 50, but goal scale 0.1 @@ -233,7 +235,7 @@ register( ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0") kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP) -kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features" +kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0" kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper) kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity' diff --git a/fancy_gym/dmc/dmc_wrapper.py b/fancy_gym/dmc/dmc_wrapper.py index b1522c3..d1e5f0d 100644 --- a/fancy_gym/dmc/dmc_wrapper.py +++ b/fancy_gym/dmc/dmc_wrapper.py @@ -3,15 +3,15 @@ # Copyright (c) 2020 Denis Yarats import collections from collections.abc import MutableMapping -from typing import Any, Dict, Tuple, Optional, Union, Callable +from typing import Any, Dict, Tuple, Optional, Union, Callable, SupportsFloat -import gym +import gymnasium as gym import numpy as np from dm_control import composer from dm_control.rl import control from dm_env import specs -from gym import spaces -from gym.core import ObsType +from gymnasium import spaces +from gymnasium.core import ObsType, ActType def _spec_to_box(spec): @@ -100,23 +100,23 @@ class DMCWrapper(gym.Env): self._action_space.seed(seed) self._observation_space.seed(seed) - def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: + def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: assert self._action_space.contains(action) extra = {'internal_state': self._env.physics.get_state().copy()} - time_step = self._env.step(action) reward = time_step.reward or 0. - done = time_step.last() + terminated = False + truncated = time_step.last() and time_step.discount > 0 obs = self._get_obs(time_step) extra['discount'] = time_step.discount - return obs, reward, done, extra + return obs, reward, terminated, truncated, extra - def reset(self, *, seed: Optional[int] = None, return_info: bool = False, - options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ + -> Tuple[ObsType, Dict[str, Any]]: time_step = self._env.reset() obs = self._get_obs(time_step) - return obs + return obs, {} def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False, segmentation=False, scene_option=None, render_flag_overrides=None): diff --git a/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py b/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py index f64ac4a..908cee1 100644 --- a/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py +++ b/fancy_gym/dmc/manipulation/reach_site/mp_wrapper.py @@ -35,4 +35,4 @@ class MPWrapper(RawInterfaceWrapper): @property def dt(self) -> Union[float, int]: - return self.env.dt + return self.env.control_timestep() diff --git a/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py b/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py index dc6a539..94f9041 100644 --- a/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py +++ b/fancy_gym/dmc/suite/ball_in_cup/mp_wrapper.py @@ -31,4 +31,4 @@ class MPWrapper(RawInterfaceWrapper): @property def dt(self) -> Union[float, int]: - return self.env.dt + return self.env.control_timestep() diff --git a/fancy_gym/dmc/suite/cartpole/mp_wrapper.py b/fancy_gym/dmc/suite/cartpole/mp_wrapper.py index 7edd51f..85afa83 100644 --- a/fancy_gym/dmc/suite/cartpole/mp_wrapper.py +++ b/fancy_gym/dmc/suite/cartpole/mp_wrapper.py @@ -35,7 +35,7 @@ class MPWrapper(RawInterfaceWrapper): @property def dt(self) -> Union[float, int]: - return self.env.dt + return self.env.control_timestep() class TwoPolesMPWrapper(MPWrapper): diff --git a/fancy_gym/dmc/suite/reacher/mp_wrapper.py b/fancy_gym/dmc/suite/reacher/mp_wrapper.py index 5ac52e5..2d0aee5 100644 --- a/fancy_gym/dmc/suite/reacher/mp_wrapper.py +++ b/fancy_gym/dmc/suite/reacher/mp_wrapper.py @@ -30,4 +30,4 @@ class MPWrapper(RawInterfaceWrapper): @property def dt(self) -> Union[float, int]: - return self.env.dt + return self.env.control_timestep()