added sanity check for computed trajectory duration/length and environment episode length
This commit is contained in:
parent
c6b4cff3a3
commit
7e2f5d664b
@ -216,7 +216,7 @@ register(
|
|||||||
"random_start": False,
|
"random_start": False,
|
||||||
"allow_self_collision": False,
|
"allow_self_collision": False,
|
||||||
"allow_wall_collision": False,
|
"allow_wall_collision": False,
|
||||||
"hole_width": None,
|
"hole_width": 0.25,
|
||||||
"hole_depth": 1,
|
"hole_depth": 1,
|
||||||
"hole_x": None,
|
"hole_x": None,
|
||||||
"collision_penalty": 100,
|
"collision_penalty": 100,
|
||||||
@ -525,11 +525,13 @@ register(
|
|||||||
# max_episode_steps=1,
|
# max_episode_steps=1,
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"ball_in_cup-catch",
|
"name": f"ball_in_cup-catch",
|
||||||
|
"time_limit": 1,
|
||||||
|
"episode_length": 50,
|
||||||
"wrappers": [DMCBallInCupMPWrapper],
|
"wrappers": [DMCBallInCupMPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 1,
|
||||||
"learn_goal": True,
|
"learn_goal": True,
|
||||||
"alpha_phase": 2,
|
"alpha_phase": 2,
|
||||||
"bandwidth_factor": 2,
|
"bandwidth_factor": 2,
|
||||||
@ -549,11 +551,13 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": f"ball_in_cup-catch",
|
"name": f"ball_in_cup-catch",
|
||||||
|
"time_limit": 1,
|
||||||
|
"episode_length": 50,
|
||||||
"wrappers": [DMCBallInCupMPWrapper],
|
"wrappers": [DMCBallInCupMPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 20,
|
"duration": 1,
|
||||||
"width": 0.025,
|
"width": 0.025,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 0.2,
|
"weights_scale": 0.2,
|
||||||
|
@ -38,6 +38,7 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
|||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
del env
|
||||||
|
|
||||||
|
|
||||||
def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||||
@ -78,7 +79,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
"d_gains": 0.05
|
"d_gains": 0.05
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
kwargs = {
|
||||||
|
"time_limit": 20,
|
||||||
|
"episode_length": 1000,
|
||||||
|
# "frame_skip": 1
|
||||||
|
}
|
||||||
|
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||||
# OR for a deterministic ProMP:
|
# OR for a deterministic ProMP:
|
||||||
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
||||||
|
|
||||||
@ -105,6 +111,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
del env
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -113,16 +120,18 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# For rendering DMC
|
# For rendering DMC
|
||||||
# export MUJOCO_GL="osmesa"
|
# export MUJOCO_GL="osmesa"
|
||||||
|
render = False
|
||||||
|
|
||||||
# Standard DMC Suite tasks
|
# # Standard DMC Suite tasks
|
||||||
example_dmc("fish-swim", seed=10, iterations=1000, render=True)
|
# example_dmc("fish-swim", seed=10, iterations=1000, render=render)
|
||||||
|
#
|
||||||
# Manipulation tasks
|
# # Manipulation tasks
|
||||||
# Disclaimer: The vision versions are currently not integrated and yield an error
|
# # Disclaimer: The vision versions are currently not integrated and yield an error
|
||||||
example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=True)
|
# example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
|
||||||
|
|
||||||
# Gym + DMC hybrid task provided in the MP framework
|
# Gym + DMC hybrid task provided in the MP framework
|
||||||
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=True)
|
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
# Custom DMC task
|
# Custom DMC task
|
||||||
example_custom_dmc_and_mp(seed=10, iterations=1, render=True)
|
# Different seed, because the episode is longer for this example and the name+seed combo is already registered above
|
||||||
|
example_custom_dmc_and_mp(seed=11, iterations=1, render=render)
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import collections
|
|
||||||
import re
|
import re
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
@ -125,6 +125,19 @@ class DMCWrapper(core.Env):
|
|||||||
def dt(self):
|
def dt(self):
|
||||||
return self._env.control_timestep() * self._frame_skip
|
return self._env.control_timestep() * self._frame_skip
|
||||||
|
|
||||||
|
@property
|
||||||
|
def base_step_limit(self):
|
||||||
|
"""
|
||||||
|
Returns: max_episode_steps of the underlying DMC env
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Accessing private attribute because DMC does not expose time_limit or step_limit.
|
||||||
|
# Only the current time_step/time as well as the control_timestep can be accessed.
|
||||||
|
try:
|
||||||
|
return (self._env._step_limit + self._frame_skip - 1) // self._frame_skip
|
||||||
|
except AttributeError as e:
|
||||||
|
return self._env._time_limit / self.dt
|
||||||
|
|
||||||
def seed(self, seed=None):
|
def seed(self, seed=None):
|
||||||
self._action_space.seed(seed)
|
self._action_space.seed(seed)
|
||||||
self._observation_space.seed(seed)
|
self._observation_space.seed(seed)
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Iterable, List, Type
|
from typing import Iterable, List, Type, Union
|
||||||
|
|
||||||
import gym
|
import gym
|
||||||
|
|
||||||
@ -8,7 +8,7 @@ from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
|||||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||||
|
|
||||||
|
|
||||||
def make_env_rank(env_id: str, seed: int, rank: int = 0):
|
def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
|
||||||
"""
|
"""
|
||||||
TODO: Do we need this?
|
TODO: Do we need this?
|
||||||
Generate a callable to create a new gym environment with a given seed.
|
Generate a callable to create a new gym environment with a given seed.
|
||||||
@ -26,7 +26,7 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0):
|
|||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return lambda: make_env(env_id, seed + rank)
|
return lambda: make_env(env_id, seed + rank, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def make_env(env_id: str, seed, **kwargs):
|
def make_env(env_id: str, seed, **kwargs):
|
||||||
@ -54,6 +54,10 @@ def make_env(env_id: str, seed, **kwargs):
|
|||||||
from alr_envs.utils import make
|
from alr_envs.utils import make
|
||||||
env = make(env_id, seed=seed, **kwargs)
|
env = make(env_id, seed=seed, **kwargs)
|
||||||
|
|
||||||
|
assert env.base_step_limit == env.spec.max_episode_steps, \
|
||||||
|
f"The specified 'episode_length' of {env.spec.max_episode_steps} steps for gym is different from " \
|
||||||
|
f"the DMC environment specification of {env.base_step_limit} steps."
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
@ -94,6 +98,7 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
|
|||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||||
|
|
||||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||||
return DmpWrapper(_env, **mp_kwargs)
|
return DmpWrapper(_env, **mp_kwargs)
|
||||||
@ -110,6 +115,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
|
|||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||||
|
|
||||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||||
return DetPMPWrapper(_env, **mp_kwargs)
|
return DetPMPWrapper(_env, **mp_kwargs)
|
||||||
@ -159,3 +165,23 @@ def make_contextual_env(env_id, context, seed, rank):
|
|||||||
# env = gym.make(env_id, context=context)
|
# env = gym.make(env_id, context=context)
|
||||||
# env.seed(seed + rank)
|
# env.seed(seed + rank)
|
||||||
return lambda: env
|
return lambda: env
|
||||||
|
|
||||||
|
|
||||||
|
def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
|
||||||
|
"""
|
||||||
|
When using DMC, check whether a manually specified time limit matches the trajectory duration the MP receives.
|
||||||
|
Mostly, the time_limit for DMC is not specified and the default values from DMC are taken.
|
||||||
|
This check, however, can only be done after instantiating the environment.
|
||||||
|
It can be found in the BaseMP class.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mp_time_limit: max trajectory length of mp in seconds
|
||||||
|
env_time_limit: max trajectory length of DMC environment in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
if mp_time_limit is not None and env_time_limit is not None:
|
||||||
|
assert mp_time_limit == env_time_limit, \
|
||||||
|
f"The manually specified 'time_limit' of {env_time_limit}s does not match " \
|
||||||
|
f"the duration of {mp_time_limit}s for the MP."
|
||||||
|
Loading…
Reference in New Issue
Block a user