updated dm_control envs to use shimmy
This commit is contained in:
parent
ed724046f3
commit
9ebc021ae0
@ -1,14 +1,16 @@
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from gymnasium.wrappers import FlattenObservation
|
||||||
|
|
||||||
from . import manipulation, suite
|
from . import manipulation, suite
|
||||||
|
|
||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||||
|
|
||||||
from gym.envs.registration import register
|
from gymnasium.envs.registration import register
|
||||||
|
|
||||||
DEFAULT_BB_DICT_ProMP = {
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
"wrappers": [],
|
"wrappers": [FlattenObservation],
|
||||||
"trajectory_generator_kwargs": {
|
"trajectory_generator_kwargs": {
|
||||||
'trajectory_generator_type': 'promp'
|
'trajectory_generator_type': 'promp'
|
||||||
},
|
},
|
||||||
@ -29,7 +31,7 @@ DEFAULT_BB_DICT_ProMP = {
|
|||||||
|
|
||||||
DEFAULT_BB_DICT_DMP = {
|
DEFAULT_BB_DICT_DMP = {
|
||||||
"name": 'EnvName',
|
"name": 'EnvName',
|
||||||
"wrappers": [],
|
"wrappers": [FlattenObservation],
|
||||||
"trajectory_generator_kwargs": {
|
"trajectory_generator_kwargs": {
|
||||||
'trajectory_generator_type': 'dmp'
|
'trajectory_generator_type': 'dmp'
|
||||||
},
|
},
|
||||||
@ -49,7 +51,7 @@ DEFAULT_BB_DICT_DMP = {
|
|||||||
|
|
||||||
# DeepMind Control Suite (DMC)
|
# DeepMind Control Suite (DMC)
|
||||||
kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_bic_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_bic_dmp['name'] = f"dmc:ball_in_cup-catch"
|
kwargs_dict_bic_dmp['name'] = f"dm_control/ball_in_cup-catch-v0"
|
||||||
kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
kwargs_dict_bic_dmp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||||
# bandwidth_factor=2
|
# bandwidth_factor=2
|
||||||
kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_bic_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
@ -62,7 +64,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_ball_in_cup-catch_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_bic_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_bic_promp['name'] = f"dmc:ball_in_cup-catch"
|
kwargs_dict_bic_promp['name'] = f"dm_control/ball_in_cup-catch-v0"
|
||||||
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
kwargs_dict_bic_promp['wrappers'].append(suite.ball_in_cup.MPWrapper)
|
||||||
register(
|
register(
|
||||||
id=f'dmc_ball_in_cup-catch_promp-v0',
|
id=f'dmc_ball_in_cup-catch_promp-v0',
|
||||||
@ -72,7 +74,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_ball_in_cup-catch_promp-v0")
|
||||||
|
|
||||||
kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_reacher_easy_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_reacher_easy_dmp['name'] = f"dmc:reacher-easy"
|
kwargs_dict_reacher_easy_dmp['name'] = f"dm_control/reacher-easy-v0"
|
||||||
kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
kwargs_dict_reacher_easy_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
||||||
# bandwidth_factor=2
|
# bandwidth_factor=2
|
||||||
kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_reacher_easy_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
@ -86,7 +88,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-easy_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_reacher_easy_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_reacher_easy_promp['name'] = f"dmc:reacher-easy"
|
kwargs_dict_reacher_easy_promp['name'] = f"dm_control/reacher-easy-v0"
|
||||||
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
|
kwargs_dict_reacher_easy_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||||
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
kwargs_dict_reacher_easy_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||||
register(
|
register(
|
||||||
@ -97,7 +99,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("dmc_reacher-easy_promp-v0")
|
||||||
|
|
||||||
kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_reacher_hard_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_reacher_hard_dmp['name'] = f"dmc:reacher-hard"
|
kwargs_dict_reacher_hard_dmp['name'] = f"dm_control/reacher-hard-v0"
|
||||||
kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
kwargs_dict_reacher_hard_dmp['wrappers'].append(suite.reacher.MPWrapper)
|
||||||
# bandwidth_factor = 2
|
# bandwidth_factor = 2
|
||||||
kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_reacher_hard_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
@ -111,7 +113,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_reacher-hard_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_reacher_hard_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_reacher_hard_promp['name'] = f"dmc:reacher-hard"
|
kwargs_dict_reacher_hard_promp['name'] = f"dm_control/reacher-hard-v0"
|
||||||
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
|
kwargs_dict_reacher_hard_promp['wrappers'].append(suite.reacher.MPWrapper)
|
||||||
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
kwargs_dict_reacher_hard_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||||
register(
|
register(
|
||||||
@ -126,7 +128,7 @@ _dmc_cartpole_tasks = ["balance", "balance_sparse", "swingup", "swingup_sparse"]
|
|||||||
for _task in _dmc_cartpole_tasks:
|
for _task in _dmc_cartpole_tasks:
|
||||||
_env_id = f'dmc_cartpole-{_task}_dmp-v0'
|
_env_id = f'dmc_cartpole-{_task}_dmp-v0'
|
||||||
kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_cartpole_dmp['name'] = f"dmc:cartpole-{_task}"
|
kwargs_dict_cartpole_dmp['name'] = f"dm_control/cartpole-{_task}-v0"
|
||||||
kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
|
kwargs_dict_cartpole_dmp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||||
# bandwidth_factor = 2
|
# bandwidth_factor = 2
|
||||||
kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_cartpole_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
@ -143,7 +145,7 @@ for _task in _dmc_cartpole_tasks:
|
|||||||
|
|
||||||
_env_id = f'dmc_cartpole-{_task}_promp-v0'
|
_env_id = f'dmc_cartpole-{_task}_promp-v0'
|
||||||
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_cartpole_promp['name'] = f"dmc:cartpole-{_task}"
|
kwargs_dict_cartpole_promp['name'] = f"dm_control/cartpole-{_task}-v0"
|
||||||
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
|
kwargs_dict_cartpole_promp['wrappers'].append(suite.cartpole.MPWrapper)
|
||||||
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
|
kwargs_dict_cartpole_promp['controller_kwargs']['p_gains'] = 10
|
||||||
kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
|
kwargs_dict_cartpole_promp['controller_kwargs']['d_gains'] = 10
|
||||||
@ -156,7 +158,7 @@ for _task in _dmc_cartpole_tasks:
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole2poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_cartpole2poles_dmp['name'] = f"dmc:cartpole-two_poles"
|
kwargs_dict_cartpole2poles_dmp['name'] = f"dm_control/cartpole-two_poles-v0"
|
||||||
kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
kwargs_dict_cartpole2poles_dmp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||||
# bandwidth_factor = 2
|
# bandwidth_factor = 2
|
||||||
kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_cartpole2poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
@ -173,7 +175,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole2poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_cartpole2poles_promp['name'] = f"dmc:cartpole-two_poles"
|
kwargs_dict_cartpole2poles_promp['name'] = f"dm_control/cartpole-two_poles-v0"
|
||||||
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
kwargs_dict_cartpole2poles_promp['wrappers'].append(suite.cartpole.TwoPolesMPWrapper)
|
||||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
|
kwargs_dict_cartpole2poles_promp['controller_kwargs']['p_gains'] = 10
|
||||||
kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
|
kwargs_dict_cartpole2poles_promp['controller_kwargs']['d_gains'] = 10
|
||||||
@ -187,7 +189,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole3poles_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_cartpole3poles_dmp['name'] = f"dmc:cartpole-three_poles"
|
kwargs_dict_cartpole3poles_dmp['name'] = f"dm_control/cartpole-three_poles-v0"
|
||||||
kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
kwargs_dict_cartpole3poles_dmp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||||
# bandwidth_factor = 2
|
# bandwidth_factor = 2
|
||||||
kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_cartpole3poles_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
@ -204,7 +206,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
|
||||||
|
|
||||||
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_cartpole3poles_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_cartpole3poles_promp['name'] = f"dmc:cartpole-three_poles"
|
kwargs_dict_cartpole3poles_promp['name'] = f"dm_control/cartpole-three_poles-v0"
|
||||||
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
kwargs_dict_cartpole3poles_promp['wrappers'].append(suite.cartpole.ThreePolesMPWrapper)
|
||||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
|
kwargs_dict_cartpole3poles_promp['controller_kwargs']['p_gains'] = 10
|
||||||
kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
|
kwargs_dict_cartpole3poles_promp['controller_kwargs']['d_gains'] = 10
|
||||||
@ -219,7 +221,7 @@ ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
|||||||
|
|
||||||
# DeepMind Manipulation
|
# DeepMind Manipulation
|
||||||
kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_mani_reach_site_features_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dmc:manipulation-reach_site_features"
|
kwargs_dict_mani_reach_site_features_dmp['name'] = f"dm_control/reach_site_features-v0"
|
||||||
kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
kwargs_dict_mani_reach_site_features_dmp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||||
kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
kwargs_dict_mani_reach_site_features_dmp['phase_generator_kwargs']['alpha_phase'] = 2
|
||||||
# TODO: weight scale 50, but goal scale 0.1
|
# TODO: weight scale 50, but goal scale 0.1
|
||||||
@ -233,7 +235,7 @@ register(
|
|||||||
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
|
ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("dmc_manipulation-reach_site_dmp-v0")
|
||||||
|
|
||||||
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
kwargs_dict_mani_reach_site_features_promp = deepcopy(DEFAULT_BB_DICT_DMP)
|
||||||
kwargs_dict_mani_reach_site_features_promp['name'] = f"dmc:manipulation-reach_site_features"
|
kwargs_dict_mani_reach_site_features_promp['name'] = f"dm_control/reach_site_features-v0"
|
||||||
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
kwargs_dict_mani_reach_site_features_promp['wrappers'].append(manipulation.reach_site.MPWrapper)
|
||||||
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
kwargs_dict_mani_reach_site_features_promp['trajectory_generator_kwargs']['weight_scale'] = 0.2
|
||||||
kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'
|
kwargs_dict_mani_reach_site_features_promp['controller_kwargs']['controller_type'] = 'velocity'
|
||||||
|
@ -3,15 +3,15 @@
|
|||||||
# Copyright (c) 2020 Denis Yarats
|
# Copyright (c) 2020 Denis Yarats
|
||||||
import collections
|
import collections
|
||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from typing import Any, Dict, Tuple, Optional, Union, Callable
|
from typing import Any, Dict, Tuple, Optional, Union, Callable, SupportsFloat
|
||||||
|
|
||||||
import gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dm_control import composer
|
from dm_control import composer
|
||||||
from dm_control.rl import control
|
from dm_control.rl import control
|
||||||
from dm_env import specs
|
from dm_env import specs
|
||||||
from gym import spaces
|
from gymnasium import spaces
|
||||||
from gym.core import ObsType
|
from gymnasium.core import ObsType, ActType
|
||||||
|
|
||||||
|
|
||||||
def _spec_to_box(spec):
|
def _spec_to_box(spec):
|
||||||
@ -100,23 +100,23 @@ class DMCWrapper(gym.Env):
|
|||||||
self._action_space.seed(seed)
|
self._action_space.seed(seed)
|
||||||
self._observation_space.seed(seed)
|
self._observation_space.seed(seed)
|
||||||
|
|
||||||
def step(self, action) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
|
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||||
assert self._action_space.contains(action)
|
assert self._action_space.contains(action)
|
||||||
extra = {'internal_state': self._env.physics.get_state().copy()}
|
extra = {'internal_state': self._env.physics.get_state().copy()}
|
||||||
|
|
||||||
time_step = self._env.step(action)
|
time_step = self._env.step(action)
|
||||||
reward = time_step.reward or 0.
|
reward = time_step.reward or 0.
|
||||||
done = time_step.last()
|
terminated = False
|
||||||
|
truncated = time_step.last() and time_step.discount > 0
|
||||||
obs = self._get_obs(time_step)
|
obs = self._get_obs(time_step)
|
||||||
extra['discount'] = time_step.discount
|
extra['discount'] = time_step.discount
|
||||||
|
|
||||||
return obs, reward, done, extra
|
return obs, reward, terminated, truncated, extra
|
||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||||
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
|
-> Tuple[ObsType, Dict[str, Any]]:
|
||||||
time_step = self._env.reset()
|
time_step = self._env.reset()
|
||||||
obs = self._get_obs(time_step)
|
obs = self._get_obs(time_step)
|
||||||
return obs
|
return obs, {}
|
||||||
|
|
||||||
def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False,
|
def render(self, mode='rgb_array', height=240, width=320, camera_id=-1, overlays=(), depth=False,
|
||||||
segmentation=False, scene_option=None, render_flag_overrides=None):
|
segmentation=False, scene_option=None, render_flag_overrides=None):
|
||||||
|
@ -35,4 +35,4 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.control_timestep()
|
||||||
|
@ -31,4 +31,4 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.control_timestep()
|
||||||
|
@ -35,7 +35,7 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.control_timestep()
|
||||||
|
|
||||||
|
|
||||||
class TwoPolesMPWrapper(MPWrapper):
|
class TwoPolesMPWrapper(MPWrapper):
|
||||||
|
@ -30,4 +30,4 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
return self.env.dt
|
return self.env.control_timestep()
|
||||||
|
Loading…
Reference in New Issue
Block a user