updated dm_control envs to use shimmy

This commit is contained in:
Fabian 2023-01-12 17:23:56 +01:00
parent ed724046f3
commit 9ebc021ae0
6 changed files with 34 additions and 32 deletions

View File

@ -1,14 +1,16 @@
from copy import deepcopy from copy import deepcopy
from gymnasium.wrappers import FlattenObservation
from . import manipulation, suite from . import manipulation, suite
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
from gym.envs.registration import register from gymnasium.envs.registration import register
DEFAULT_BB_DICT_ProMP = { DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName', "name": 'EnvName',
"wrappers": [], "wrappers": [FlattenObservation],
"trajectory_generator_kwargs": { "trajectory_generator_kwargs": {
'trajectory_generator_type': 'promp' 'trajectory_generator_type': 'promp'
}, },
@ -29,7 +31,7 @@ DEFAULT_BB_DICT_ProMP = {
DEFAULT_BB_DICT_DMP = { DEFAULT_BB_DICT_DMP = {
"name": 'EnvName', "name": 'EnvName',
"wrappers": [], "wrappers": [FlattenObservation],
"trajectory_generator_kwargs": { "trajectory_generator_kwargs": {
'trajectory_generator_type': 'dmp' 'trajectory_generator_type': 'dmp'
}, },
@ -49,7 +51,7 @@ DEFAULT_BB_DICT_DMP = {
# DeepMind Control Suite (DMC) # DeepMind Control Suite (DMC)
kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_bic_dmp['name'] = f"dmc:ball_in_cup-catch" kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0"
kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper) kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
# bandwidth_factor=2 # bandwidth_factor=2
kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -62,7 +64,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch" kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0"
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper) kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
register( register(
id=f'dmc_ball_in_cup-catch_promp-v0', id=f'dmc_ball_in_cup-catch_promp-v0',
@ -72,7 +74,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_easy_dmp['name'] = f"dmc:reacher-easy" kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0"
kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
# bandwidth_factor=2 # bandwidth_factor=2
kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -86,7 +88,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy" kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0"
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
register( register(
@ -97,7 +99,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_hard_dmp['name'] = f"dmc:reacher-hard" kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0"
kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
# bandwidth_factor = 2 # bandwidth_factor = 2
kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -111,7 +113,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard" kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0"
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper) kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
register( register(
@ -126,7 +128,7 @@ _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
for _task in _dmc_cartpole_tasks: for _task in _dmc_cartpole_tasks:
_env_id = f'dmc_cartpole-{_task}_dmp-v0' _env_id = f'dmc_cartpole-{_task}_dmp-v0'
kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole_dmp['name'] = f"dmc:cartpole-{_task}" kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0"
kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper) kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
# bandwidth_factor = 2 # bandwidth_factor = 2
kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -143,7 +145,7 @@ for _task in _dmc_cartpole_tasks:
_env_id = f'dmc_cartpole-{_task}_promp-v0' _env_id = f'dmc_cartpole-{_task}_promp-v0'
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}" kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0"
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper) kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10 kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
@ -156,7 +158,7 @@ for _task in _dmc_cartpole_tasks:
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole2poles_dmp['name'] = f"dmc:cartpole-two_poles" kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0"
kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
# bandwidth_factor = 2 # bandwidth_factor = 2
kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -173,7 +175,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles" kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0"
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper) kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10 kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
@ -187,7 +189,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole3poles_dmp['name'] = f"dmc:cartpole-three_poles" kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0"
kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
# bandwidth_factor = 2 # bandwidth_factor = 2
kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -204,7 +206,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles" kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0"
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper) kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10 kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10 kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
@ -219,7 +221,7 @@ ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# DeepMind Manipulation # DeepMind Manipulation
kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dmc:manipulation-reach_site_features" kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0"
kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper) kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2 kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
# TODO: weight scale 50, but goal scale 0.1 # TODO: weight scale 50, but goal scale 0.1
@ -233,7 +235,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0") ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features" kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0"
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper) kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2 kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity' kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'

View File

@ -3,15 +3,15 @@
# Copyright (c) 2020 Denis Yarats # Copyright (c) 2020 Denis Yarats
import collections import collections
from collections.abc import MutableMapping from collections.abc import MutableMapping
from typing import Any, Dict, Tuple, Optional, Union, Callable from typing import Any, Dict, Tuple, Optional, Union, Callable, SupportsFloat
import gym import gymnasium as gym
import numpy as np import numpy as np
from dm_control import composer from dm_control import composer
from dm_control.rl import control from dm_control.rl import control
from dm_env import specs from dm_env import specs
from gym import spaces from gymnasium import spaces
from gym.core import ObsType from gymnasium.core import ObsType, ActType
def _spec_to_box(spec): def _spec_to_box(spec):
@ -100,23 +100,23 @@ class DMCWrapper(gym.Env):
self._action_space.seed(seed) self._action_space.seed(seed)
self._observation_space.seed(seed) self._observation_space.seed(seed)
def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
assert self._action_space.contains(action) assert self._action_space.contains(action)
extra = {'internal_state': self._env.physics.get_state().copy()} extra = {'internal_state': self._env.physics.get_state().copy()}
time_step = self._env.step(action) time_step = self._env.step(action)
reward = time_step.reward or 0. reward = time_step.reward or 0.
done = time_step.last() terminated = False
truncated = time_step.last() and time_step.discount > 0
obs = self._get_obs(time_step) obs = self._get_obs(time_step)
extra['discount'] = time_step.discount extra['discount'] = time_step.discount
return obs, reward, done, extra return obs, reward, terminated, truncated, extra
def reset(self, *, seed: Optional[int] = None, return_info: bool = False, def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]: -> Tuple[ObsType, Dict[str, Any]]:
time_step = self._env.reset() time_step = self._env.reset()
obs = self._get_obs(time_step) obs = self._get_obs(time_step)
return obs return obs, {}
def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False, def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False,
segmentation=False, scene_option=None, render_flag_overrides=None): segmentation=False, scene_option=None, render_flag_overrides=None):

View File

@ -35,4 +35,4 @@ class MPWrapper(RawInterfaceWrapper):
@property @property
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self.env.dt return self.env.control_timestep()

View File

@ -31,4 +31,4 @@ class MPWrapper(RawInterfaceWrapper):
@property @property
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self.env.dt return self.env.control_timestep()

View File

@ -35,7 +35,7 @@ class MPWrapper(RawInterfaceWrapper):
@property @property
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self.env.dt return self.env.control_timestep()
class TwoPolesMPWrapper(MPWrapper): class TwoPolesMPWrapper(MPWrapper):

View File

@ -30,4 +30,4 @@ class MPWrapper(RawInterfaceWrapper):
@property @property
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self.env.dt return self.env.control_timestep()