updated dm_control envs to use shimmy

This commit is contained in:
Fabian 2023-01-12 17:23:56 +01:00
parent ed724046f3
commit 9ebc021ae0
6 changed files with 34 additions and 32 deletions

View File

@ -1,14 +1,16 @@
from copy import deepcopy
from gymnasium.wrappers import FlattenObservation
from . import manipulation, suite
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
from gym.envs.registration import register
from gymnasium.envs.registration import register
DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName',
"wrappers": [],
"wrappers": [FlattenObservation],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'promp'
},
@ -29,7 +31,7 @@ DEFAULT_BB_DICT_ProMP = {
DEFAULT_BB_DICT_DMP = {
"name": 'EnvName',
"wrappers": [],
"wrappers": [FlattenObservation],
"trajectory_generator_kwargs": {
'trajectory_generator_type': 'dmp'
},
@ -49,7 +51,7 @@ DEFAULT_BB_DICT_DMP = {
# DeepMind Control Suite (DMC)
kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_bic_dmp['name'] = f"dmc:ball_in_cup-catch"
kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0"
kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
# bandwidth_factor=2
kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -62,7 +64,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0"
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
register(
id=f'dmc_ball_in_cup-catch_promp-v0',
@ -72,7 +74,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_easy_dmp['name'] = f"dmc:reacher-easy"
kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0"
kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
# bandwidth_factor=2
kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -86,7 +88,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0"
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
register(
@ -97,7 +99,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_hard_dmp['name'] = f"dmc:reacher-hard"
kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0"
kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
# bandwidth_factor = 2
kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -111,7 +113,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0"
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
register(
@ -126,7 +128,7 @@ _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
for _task in _dmc_cartpole_tasks:
_env_id = f'dmc_cartpole-{_task}_dmp-v0'
kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole_dmp['name'] = f"dmc:cartpole-{_task}"
kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0"
kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
# bandwidth_factor = 2
kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -143,7 +145,7 @@ for _task in _dmc_cartpole_tasks:
_env_id = f'dmc_cartpole-{_task}_promp-v0'
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0"
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
@ -156,7 +158,7 @@ for _task in _dmc_cartpole_tasks:
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole2poles_dmp['name'] = f"dmc:cartpole-two_poles"
kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0"
kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
# bandwidth_factor = 2
kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -173,7 +175,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0"
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
@ -187,7 +189,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole3poles_dmp['name'] = f"dmc:cartpole-three_poles"
kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0"
kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
# bandwidth_factor = 2
kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
@ -204,7 +206,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0"
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
@ -219,7 +221,7 @@ ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# DeepMind Manipulation
kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dmc:manipulation-reach_site_features"
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0"
kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
# TODO: weight scale 50, but goal scale 0.1
@ -233,7 +235,7 @@ register(
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0"
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'

View File

@ -3,15 +3,15 @@
# Copyright (c) 2020 Denis Yarats
import collections
from collections.abc import MutableMapping
from typing import Any, Dict, Tuple, Optional, Union, Callable
from typing import Any, Dict, Tuple, Optional, Union, Callable, SupportsFloat
import gym
import gymnasium as gym
import numpy as np
from dm_control import composer
from dm_control.rl import control
from dm_env import specs
from gym import spaces
from gym.core import ObsType
from gymnasium import spaces
from gymnasium.core import ObsType, ActType
def _spec_to_box(spec):
@ -100,23 +100,23 @@ class DMCWrapper(gym.Env):
self._action_space.seed(seed)
self._observation_space.seed(seed)
def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
assert self._action_space.contains(action)
extra = {'internal_state': self._env.physics.get_state().copy()}
time_step = self._env.step(action)
reward = time_step.reward or 0.
done = time_step.last()
terminated = False
truncated = time_step.last() and time_step.discount > 0
obs = self._get_obs(time_step)
extra['discount'] = time_step.discount
return obs, reward, done, extra
return obs, reward, terminated, truncated, extra
def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]:
time_step = self._env.reset()
obs = self._get_obs(time_step)
return obs
return obs, {}
def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False,
segmentation=False, scene_option=None, render_flag_overrides=None):

View File

@ -35,4 +35,4 @@ class MPWrapper(RawInterfaceWrapper):
@property
def dt(self) -> Union[float, int]:
return self.env.dt
return self.env.control_timestep()

View File

@ -31,4 +31,4 @@ class MPWrapper(RawInterfaceWrapper):
@property
def dt(self) -> Union[float, int]:
return self.env.dt
return self.env.control_timestep()

View File

@ -35,7 +35,7 @@ class MPWrapper(RawInterfaceWrapper):
@property
def dt(self) -> Union[float, int]:
return self.env.dt
return self.env.control_timestep()
class TwoPolesMPWrapper(MPWrapper):

View File

@ -30,4 +30,4 @@ class MPWrapper(RawInterfaceWrapper):
@property
def dt(self) -> Union[float, int]:
return self.env.dt
return self.env.control_timestep()