updated dm_control envs to use shimmy
This commit is contained in:
parent
ed724046f3
commit
9ebc021ae0
@ -1,14 +1,16 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
|
||||
from . import manipulation, suite
|
||||
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||
|
||||
from gym.envs.registration import register
|
||||
from gymnasium.envs.registration import register
|
||||
|
||||
DEFAULT_BB_DICT_ProMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"wrappers": [FlattenObservation],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'promp'
|
||||
},
|
||||
@ -29,7 +31,7 @@ DEFAULT_BB_DICT_ProMP = {
|
||||
|
||||
DEFAULT_BB_DICT_DMP = {
|
||||
"name": 'EnvName',
|
||||
"wrappers": [],
|
||||
"wrappers": [FlattenObservation],
|
||||
"trajectory_generator_kwargs": {
|
||||
'trajectory_generator_type': 'dmp'
|
||||
},
|
||||
@ -49,7 +51,7 @@ DEFAULT_BB_DICT_DMP = {
|
||||
|
||||
# DeepMind Control Suite (DMC)
|
||||
kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_bic_dmp['name'] = f"dmc:ball_in_cup-catch"
|
||||
kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0"
|
||||
kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||
# bandwidth_factor=2
|
||||
kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
@ -62,7 +64,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
|
||||
|
||||
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
|
||||
kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0"
|
||||
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||
register(
|
||||
id=f'dmc_ball_in_cup-catch_promp-v0',
|
||||
@ -72,7 +74,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
|
||||
|
||||
kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_easy_dmp['name'] = f"dmc:reacher-easy"
|
||||
kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0"
|
||||
kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
# bandwidth_factor=2
|
||||
kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
@ -86,7 +88,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
|
||||
|
||||
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
|
||||
kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0"
|
||||
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
register(
|
||||
@ -97,7 +99,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
|
||||
|
||||
kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_hard_dmp['name'] = f"dmc:reacher-hard"
|
||||
kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0"
|
||||
kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
@ -111,7 +113,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
|
||||
|
||||
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
|
||||
kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0"
|
||||
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
register(
|
||||
@ -126,7 +128,7 @@ _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
|
||||
for _task in _dmc_cartpole_tasks:
|
||||
_env_id = f'dmc_cartpole-{_task}_dmp-v0'
|
||||
kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole_dmp['name'] = f"dmc:cartpole-{_task}"
|
||||
kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0"
|
||||
kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
@ -143,7 +145,7 @@ for _task in _dmc_cartpole_tasks:
|
||||
|
||||
_env_id = f'dmc_cartpole-{_task}_promp-v0'
|
||||
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
|
||||
kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0"
|
||||
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
|
||||
@ -156,7 +158,7 @@ for _task in _dmc_cartpole_tasks:
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole2poles_dmp['name'] = f"dmc:cartpole-two_poles"
|
||||
kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0"
|
||||
kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
@ -173,7 +175,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
|
||||
kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0"
|
||||
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
|
||||
@ -187,7 +189,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole3poles_dmp['name'] = f"dmc:cartpole-three_poles"
|
||||
kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0"
|
||||
kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||
# bandwidth_factor = 2
|
||||
kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
@ -204,7 +206,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||
|
||||
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
|
||||
kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0"
|
||||
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
|
||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
|
||||
@ -219,7 +221,7 @@ ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||
|
||||
# DeepMind Manipulation
|
||||
kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dmc:manipulation-reach_site_features"
|
||||
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0"
|
||||
kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||
kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||
# TODO: weight scale 50, but goal scale 0.1
|
||||
@ -233,7 +235,7 @@ register(
|
||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
|
||||
|
||||
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
|
||||
kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0"
|
||||
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||
kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'
|
||||
|
@ -3,15 +3,15 @@
|
||||
# Copyright (c) 2020 Denis Yarats
|
||||
import collections
|
||||
from collections.abc import MutableMapping
|
||||
from typing import Any, Dict, Tuple, Optional, Union, Callable
|
||||
from typing import Any, Dict, Tuple, Optional, Union, Callable, SupportsFloat
|
||||
|
||||
import gym
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from dm_control import composer
|
||||
from dm_control.rl import control
|
||||
from dm_env import specs
|
||||
from gym import spaces
|
||||
from gym.core import ObsType
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType, ActType
|
||||
|
||||
|
||||
def _spec_to_box(spec):
|
||||
@ -100,23 +100,23 @@ class DMCWrapper(gym.Env):
|
||||
self._action_space.seed(seed)
|
||||
self._observation_space.seed(seed)
|
||||
|
||||
def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
assert self._action_space.contains(action)
|
||||
extra = {'internal_state': self._env.physics.get_state().copy()}
|
||||
|
||||
time_step = self._env.step(action)
|
||||
reward = time_step.reward or 0.
|
||||
done = time_step.last()
|
||||
terminated = False
|
||||
truncated = time_step.last() and time_step.discount > 0
|
||||
obs = self._get_obs(time_step)
|
||||
extra['discount'] = time_step.discount
|
||||
|
||||
return obs, reward, done, extra
|
||||
return obs, reward, terminated, truncated, extra
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
|
||||
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||
-> Tuple[ObsType, Dict[str, Any]]:
|
||||
time_step = self._env.reset()
|
||||
obs = self._get_obs(time_step)
|
||||
return obs
|
||||
return obs, {}
|
||||
|
||||
def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False,
|
||||
segmentation=False, scene_option=None, render_flag_overrides=None):
|
||||
|
@ -35,4 +35,4 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
||||
return self.env.control_timestep()
|
||||
|
@ -31,4 +31,4 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
||||
return self.env.control_timestep()
|
||||
|
@ -35,7 +35,7 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
||||
return self.env.control_timestep()
|
||||
|
||||
|
||||
class TwoPolesMPWrapper(MPWrapper):
|
||||
|
@ -30,4 +30,4 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
||||
return self.env.control_timestep()
|
||||
|
Loading…
Reference in New Issue
Block a user