Added sanity checks comparing the computed trajectory duration/length against the environment episode length

This commit is contained in:
ottofabian 2021-07-19 14:05:25 +02:00
parent c6b4cff3a3
commit 7e2f5d664b
6 changed files with 68 additions and 17 deletions

View File

@ -216,7 +216,7 @@ register(
"random_start": False, "random_start": False,
"allow_self_collision": False, "allow_self_collision": False,
"allow_wall_collision": False, "allow_wall_collision": False,
"hole_width": None, "hole_width": 0.25,
"hole_depth": 1, "hole_depth": 1,
"hole_x": None, "hole_x": None,
"collision_penalty": 100, "collision_penalty": 100,
@ -525,11 +525,13 @@ register(
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": f"ball_in_cup-catch", "name": f"ball_in_cup-catch",
"time_limit": 1,
"episode_length": 50,
"wrappers": [DMCBallInCupMPWrapper], "wrappers": [DMCBallInCupMPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 1,
"learn_goal": True, "learn_goal": True,
"alpha_phase": 2, "alpha_phase": 2,
"bandwidth_factor": 2, "bandwidth_factor": 2,
@ -549,11 +551,13 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": f"ball_in_cup-catch", "name": f"ball_in_cup-catch",
"time_limit": 1,
"episode_length": 50,
"wrappers": [DMCBallInCupMPWrapper], "wrappers": [DMCBallInCupMPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
"duration": 20, "duration": 1,
"width": 0.025, "width": 0.025,
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 0.2, "weights_scale": 0.2,

View File

@ -38,6 +38,7 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
obs = env.reset() obs = env.reset()
env.close() env.close()
del env
def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
@ -78,7 +79,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
"d_gains": 0.05 "d_gains": 0.05
} }
} }
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) kwargs = {
"time_limit": 20,
"episode_length": 1000,
# "frame_skip": 1
}
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP: # OR for a deterministic ProMP:
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
@ -105,6 +111,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
obs = env.reset() obs = env.reset()
env.close() env.close()
del env
if __name__ == '__main__': if __name__ == '__main__':
@ -113,16 +120,18 @@ if __name__ == '__main__':
# For rendering DMC # For rendering DMC
# export MUJOCO_GL="osmesa" # export MUJOCO_GL="osmesa"
render = False
# Standard DMC Suite tasks # # Standard DMC Suite tasks
example_dmc("fish-swim", seed=10, iterations=1000, render=True) # example_dmc("fish-swim", seed=10, iterations=1000, render=render)
#
# Manipulation tasks # # Manipulation tasks
# Disclaimer: The vision versions are currently not integrated and yield an error # # Disclaimer: The vision versions are currently not integrated and yield an error
example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=True) # example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render)
# Gym + DMC hybrid task provided in the MP framework # Gym + DMC hybrid task provided in the MP framework
example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=True) example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render)
# Custom DMC task # Custom DMC task
example_custom_dmc_and_mp(seed=10, iterations=1, render=True) # Different seed, because the episode is longer for this example and the name+seed combo is already registered above
example_custom_dmc_and_mp(seed=11, iterations=1, render=render)

View File

@ -1,4 +1,3 @@
import collections
import re import re
from typing import Union from typing import Union

View File

@ -125,6 +125,19 @@ class DMCWrapper(core.Env):
def dt(self): def dt(self):
return self._env.control_timestep() * self._frame_skip return self._env.control_timestep() * self._frame_skip
@property
def base_step_limit(self):
    """
    Maximum episode length of the underlying DMC env, in agent steps.

    Returns:
        max_episode_steps of the underlying DMC env, i.e. the raw DMC step
        limit divided by the frame skip (rounded up).
    """
    # Accessing private attributes because DMC does not expose time_limit or step_limit.
    # Only the current time_step/time as well as the control_timestep can be accessed.
    try:
        # Ceiling division so a partial final frame-skip chunk still counts as one step.
        return (self._env._step_limit + self._frame_skip - 1) // self._frame_skip
    except AttributeError:
        # Fallback for envs without _step_limit that expose _time_limit (seconds) instead.
        # NOTE(review): this branch yields a float while the branch above yields an int;
        # callers comparing against integer step counts may need rounding — confirm.
        return self._env._time_limit / self.dt
def seed(self, seed=None): def seed(self, seed=None):
self._action_space.seed(seed) self._action_space.seed(seed)
self._observation_space.seed(seed) self._observation_space.seed(seed)

View File

@ -1,5 +1,5 @@
import logging import logging
from typing import Iterable, List, Type from typing import Iterable, List, Type, Union
import gym import gym
@ -8,7 +8,7 @@ from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
def make_env_rank(env_id: str, seed: int, rank: int = 0): def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
""" """
TODO: Do we need this? TODO: Do we need this?
Generate a callable to create a new gym environment with a given seed. Generate a callable to create a new gym environment with a given seed.
@ -26,7 +26,7 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0):
Returns: Returns:
""" """
return lambda: make_env(env_id, seed + rank) return lambda: make_env(env_id, seed + rank, **kwargs)
def make_env(env_id: str, seed, **kwargs): def make_env(env_id: str, seed, **kwargs):
@ -54,6 +54,10 @@ def make_env(env_id: str, seed, **kwargs):
from alr_envs.utils import make from alr_envs.utils import make
env = make(env_id, seed=seed, **kwargs) env = make(env_id, seed=seed, **kwargs)
assert env.base_step_limit == env.spec.max_episode_steps, \
f"The specified 'episode_length' of {env.spec.max_episode_steps} steps for gym is different from " \
f"the DMC environment specification of {env.base_step_limit} steps."
return env return env
@ -94,6 +98,7 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
Returns: DMP wrapped gym env Returns: DMP wrapped gym env
""" """
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
return DmpWrapper(_env, **mp_kwargs) return DmpWrapper(_env, **mp_kwargs)
@ -110,6 +115,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
Returns: DMP wrapped gym env Returns: DMP wrapped gym env
""" """
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
return DetPMPWrapper(_env, **mp_kwargs) return DetPMPWrapper(_env, **mp_kwargs)
@ -159,3 +165,23 @@ def make_contextual_env(env_id, context, seed, rank):
# env = gym.make(env_id, context=context) # env = gym.make(env_id, context=context)
# env.seed(seed + rank) # env.seed(seed + rank)
return lambda: env return lambda: env
def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]):
    """
    Check that a manually specified DMC time limit matches the trajectory duration of the MP.

    Mostly the time_limit for DMC is not specified and DMC's defaults are used; in that
    case (either value being None) nothing is verified. The check itself can only be done
    after instantiating the environment. It can be found in the BaseMP class.

    Args:
        mp_time_limit: max trajectory length of the MP in seconds
        env_time_limit: max trajectory length of the DMC environment in seconds

    Returns:
        None. Raises AssertionError when both limits are given and disagree.
    """
    if mp_time_limit is None or env_time_limit is None:
        # At least one side relies on its default — nothing to compare.
        return
    assert mp_time_limit == env_time_limit, \
        f"The manually specified 'time_limit' of {env_time_limit}s does not match " \
        f"the duration of {mp_time_limit}s for the MP."

View File

@ -9,7 +9,7 @@ setup(
'gym', 'gym',
'PyQt5', 'PyQt5',
'matplotlib', 'matplotlib',
'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git', # 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git',
'mujoco_py' 'mujoco_py'
], ],