added sanity check for computed trajectory duration/length and environment episode length
This commit is contained in:
parent
c6b4cff3a3
commit
7e2f5d664b
@ -216,7 +216,7 @@ register(
|
||||
"random_start": False,
|
||||
"allow_self_collision": False,
|
||||
"allow_wall_collision": False,
|
||||
"hole_width": None,
|
||||
"hole_width": 0.25,
|
||||
"hole_depth": 1,
|
||||
"hole_x": None,
|
||||
"collision_penalty": 100,
|
||||
@ -525,11 +525,13 @@ register(
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"ball_in_cup-catch",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCBallInCupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
"duration": 1,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"bandwidth_factor": 2,
|
||||
@ -549,11 +551,13 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"ball_in_cup-catch",
|
||||
"time_limit": 1,
|
||||
"episode_length": 50,
|
||||
"wrappers": [DMCBallInCupMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2,
|
||||
"num_basis": 5,
|
||||
"duration": 20,
|
||||
"duration": 1,
|
||||
"width": 0.025,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
|
@ -38,6 +38,7 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
||||
obs = env.reset()
|
||||
|
||||
env.close()
|
||||
del env
|
||||
|
||||
|
||||
def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
@ -78,7 +79,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
"d_gains": 0.05
|
||||
}
|
||||
}
|
||||
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs)
|
||||
kwargs = {
|
||||
"time_limit": 20,
|
||||
"episode_length": 1000,
|
||||
# "frame_skip": 1
|
||||
}
|
||||
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
|
||||
# OR for a deterministic ProMP:
|
||||
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
|
||||
|
||||
@ -105,6 +111,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
||||
obs = env.reset()
|
||||
|
||||
env.close()
|
||||
del env
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -113,16 +120,18 @@ if __name__ == '__main__':
|
||||
|
||||
# For rendering DMC
|
||||
# export MUJOCO_GL="osmesa"
|
||||
render = False
|
||||
|
||||
# Standard DMC Suite tasks
|
||||
example_dmc("fish-swim", seed=10, iterations=1000, render=True)
|
||||
|
||||
# Manipulation tasks
|
||||
# Disclaimer: The vision versions are currently not integrated and yield an error
|
||||
example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=True)
|
||||
# # Standard DMC Suite tasks
|
||||
# example_dmc("fish-swim", seed=10, iterations=1000, render=render)
|
||||
#
|
||||
# # Manipulation tasks
|
||||
# # Disclaimer: The vision versions are currently not integrated and yield an error
|
||||
# example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
|
||||
|
||||
# Gym + DMC hybrid task provided in the MP framework
|
||||
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=True)
|
||||
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
|
||||
|
||||
# Custom DMC task
|
||||
example_custom_dmc_and_mp(seed=10, iterations=1, render=True)
|
||||
# Different seed, because the episode is longer for this example and the name+seed combo is already registered above
|
||||
example_custom_dmc_and_mp(seed=11, iterations=1, render=render)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import collections
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
|
@ -125,6 +125,19 @@ class DMCWrapper(core.Env):
|
||||
def dt(self):
|
||||
return self._env.control_timestep() * self._frame_skip
|
||||
|
||||
@property
|
||||
def base_step_limit(self):
|
||||
"""
|
||||
Returns: max_episode_steps of the underlying DMC env
|
||||
|
||||
"""
|
||||
# Accessing private attribute because DMC does not expose time_limit or step_limit.
|
||||
# Only the current time_step/time as well as the control_timestep can be accessed.
|
||||
try:
|
||||
return (self._env._step_limit + self._frame_skip - 1) // self._frame_skip
|
||||
except AttributeError as e:
|
||||
return self._env._time_limit / self.dt
|
||||
|
||||
def seed(self, seed=None):
|
||||
self._action_space.seed(seed)
|
||||
self._observation_space.seed(seed)
|
||||
|
@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Iterable, List, Type
|
||||
from typing import Iterable, List, Type, Union
|
||||
|
||||
import gym
|
||||
|
||||
@ -8,7 +8,7 @@ from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||
|
||||
|
||||
def make_env_rank(env_id: str, seed: int, rank: int = 0):
|
||||
def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
|
||||
"""
|
||||
TODO: Do we need this?
|
||||
Generate a callable to create a new gym environment with a given seed.
|
||||
@ -26,7 +26,7 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0):
|
||||
Returns:
|
||||
|
||||
"""
|
||||
return lambda: make_env(env_id, seed + rank)
|
||||
return lambda: make_env(env_id, seed + rank, **kwargs)
|
||||
|
||||
|
||||
def make_env(env_id: str, seed, **kwargs):
|
||||
@ -54,6 +54,10 @@ def make_env(env_id: str, seed, **kwargs):
|
||||
from alr_envs.utils import make
|
||||
env = make(env_id, seed=seed, **kwargs)
|
||||
|
||||
assert env.base_step_limit == env.spec.max_episode_steps, \
|
||||
f"The specified 'episode_length' of {env.spec.max_episode_steps} steps for gym is different from " \
|
||||
f"the DMC environment specification of {env.base_step_limit} steps."
|
||||
|
||||
return env
|
||||
|
||||
|
||||
@ -94,6 +98,7 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
|
||||
Returns: DMP wrapped gym env
|
||||
|
||||
"""
|
||||
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
return DmpWrapper(_env, **mp_kwargs)
|
||||
@ -110,6 +115,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
|
||||
Returns: DMP wrapped gym env
|
||||
|
||||
"""
|
||||
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
|
||||
|
||||
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||
return DetPMPWrapper(_env, **mp_kwargs)
|
||||
@ -159,3 +165,23 @@ def make_contextual_env(env_id, context, seed, rank):
|
||||
# env = gym.make(env_id, context=context)
|
||||
# env.seed(seed + rank)
|
||||
return lambda: env
|
||||
|
||||
|
||||
def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
|
||||
"""
|
||||
When using DMC check if a manually specified time limit matches the trajectory duration the MP receives.
|
||||
Mostly, the time_limit for DMC is not specified and the default values from DMC are taken.
|
||||
This check, however, can only been done after instantiating the environment.
|
||||
It can be found in the BaseMP class.
|
||||
|
||||
Args:
|
||||
mp_time_limit: max trajectory length of mp in seconds
|
||||
env_time_limit: max trajectory length of DMC environment in seconds
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if mp_time_limit is not None and env_time_limit is not None:
|
||||
assert mp_time_limit == env_time_limit, \
|
||||
f"The manually specified 'time_limit' of {env_time_limit}s does not match " \
|
||||
f"the duration of {mp_time_limit}s for the MP."
|
||||
|
Loading…
Reference in New Issue
Block a user