diff --git a/alr_envs/__init__.py b/alr_envs/__init__.py index 5b9bf10..04d0bf6 100644 --- a/alr_envs/__init__.py +++ b/alr_envs/__init__.py @@ -216,7 +216,7 @@ register( "random_start": False, "allow_self_collision": False, "allow_wall_collision": False, - "hole_width": None, + "hole_width": 0.25, "hole_depth": 1, "hole_x": None, "collision_penalty": 100, @@ -525,11 +525,13 @@ register( # max_episode_steps=1, kwargs={ "name": f"ball_in_cup-catch", + "time_limit": 1, + "episode_length": 50, "wrappers": [DMCBallInCupMPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, - "duration": 20, + "duration": 1, "learn_goal": True, "alpha_phase": 2, "bandwidth_factor": 2, @@ -549,11 +551,13 @@ register( entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', kwargs={ "name": f"ball_in_cup-catch", + "time_limit": 1, + "episode_length": 50, "wrappers": [DMCBallInCupMPWrapper], "mp_kwargs": { "num_dof": 2, "num_basis": 5, - "duration": 20, + "duration": 1, "width": 0.025, "policy_type": "motor", "weights_scale": 0.2, diff --git a/alr_envs/examples/examples_dmc.py b/alr_envs/examples/examples_dmc.py index b877933..71eab74 100644 --- a/alr_envs/examples/examples_dmc.py +++ b/alr_envs/examples/examples_dmc.py @@ -38,6 +38,7 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True): obs = env.reset() env.close() + del env def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): @@ -78,7 +79,12 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): "d_gains": 0.05 } } - env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs) + kwargs = { + "time_limit": 20, + "episode_length": 1000, + # "frame_skip": 1 + } + env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) # OR for a deterministic ProMP: # env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) @@ -105,6 +111,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True): obs = 
env.reset() env.close() + del env if __name__ == '__main__': @@ -113,16 +120,18 @@ if __name__ == '__main__': # For rendering DMC # export MUJOCO_GL="osmesa" + render = False - # Standard DMC Suite tasks - example_dmc("fish-swim", seed=10, iterations=1000, render=True) - - # Manipulation tasks - # Disclaimer: The vision versions are currently not integrated and yield an error - example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=True) + # # Standard DMC Suite tasks + # example_dmc("fish-swim", seed=10, iterations=1000, render=render) + # + # # Manipulation tasks + # # Disclaimer: The vision versions are currently not integrated and yield an error + # example_dmc("manipulation-reach_site_features", seed=10, iterations=250, render=render) # Gym + DMC hybrid task provided in the MP framework - example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=True) + example_dmc("dmc_ball_in_cup-catch_detpmp-v0", seed=10, iterations=1, render=render) # Custom DMC task - example_custom_dmc_and_mp(seed=10, iterations=1, render=True) + # Different seed, because the episode is longer for this example and the name+seed combo is already registered above + example_custom_dmc_and_mp(seed=11, iterations=1, render=render) diff --git a/alr_envs/utils/__init__.py b/alr_envs/utils/__init__.py index 6bdf5ec..77fdd9f 100644 --- a/alr_envs/utils/__init__.py +++ b/alr_envs/utils/__init__.py @@ -1,4 +1,3 @@ -import collections import re from typing import Union diff --git a/alr_envs/utils/dmc2gym_wrapper.py b/alr_envs/utils/dmc2gym_wrapper.py index cb3658d..5e6a53d 100644 --- a/alr_envs/utils/dmc2gym_wrapper.py +++ b/alr_envs/utils/dmc2gym_wrapper.py @@ -125,6 +125,19 @@ class DMCWrapper(core.Env): def dt(self): return self._env.control_timestep() * self._frame_skip + @property + def base_step_limit(self): + """ + Returns: max_episode_steps of the underlying DMC env + + """ + # Accessing private attribute because DMC does not expose time_limit or 
step_limit. + # Only the current time_step/time as well as the control_timestep can be accessed. + try: + return (self._env._step_limit + self._frame_skip - 1) // self._frame_skip + except AttributeError as e: + return self._env._time_limit / self.dt + def seed(self, seed=None): self._action_space.seed(seed) self._observation_space.seed(seed) diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index 6a8fce2..7c39126 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -1,5 +1,5 @@ import logging -from typing import Iterable, List, Type +from typing import Iterable, List, Type, Union import gym @@ -8,7 +8,7 @@ from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper -def make_env_rank(env_id: str, seed: int, rank: int = 0): +def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs): """ TODO: Do we need this? Generate a callable to create a new gym environment with a given seed. @@ -26,7 +26,7 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0): Returns: """ - return lambda: make_env(env_id, seed + rank) + return lambda: make_env(env_id, seed + rank, **kwargs) def make_env(env_id: str, seed, **kwargs): @@ -54,6 +54,10 @@ def make_env(env_id: str, seed, **kwargs): from alr_envs.utils import make env = make(env_id, seed=seed, **kwargs) + assert env.base_step_limit == env.spec.max_episode_steps, \ + f"The specified 'episode_length' of {env.spec.max_episode_steps} steps for gym is different from " \ + f"the DMC environment specification of {env.base_step_limit} steps." 
+ return env @@ -94,6 +98,7 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs Returns: DMP wrapped gym env """ + verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) return DmpWrapper(_env, **mp_kwargs) @@ -110,6 +115,7 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa Returns: DMP wrapped gym env """ + verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) return DetPMPWrapper(_env, **mp_kwargs) @@ -159,3 +165,23 @@ def make_contextual_env(env_id, context, seed, rank): # env = gym.make(env_id, context=context) # env.seed(seed + rank) return lambda: env + + +def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[None, float]): + """ + When using DMC check if a manually specified time limit matches the trajectory duration the MP receives. + Mostly, the time_limit for DMC is not specified and the default values from DMC are taken. + This check, however, can only be done after instantiating the environment. + It can be found in the BaseMP class. + + Args: + mp_time_limit: max trajectory length of mp in seconds + env_time_limit: max trajectory length of DMC environment in seconds + + Returns: + + """ + if mp_time_limit is not None and env_time_limit is not None: + assert mp_time_limit == env_time_limit, \ + f"The manually specified 'time_limit' of {env_time_limit}s does not match " \ + f"the duration of {mp_time_limit}s for the MP." diff --git a/setup.py b/setup.py index 7170fa6..b30ce41 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setup( 'gym', 'PyQt5', 'matplotlib', - 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git', + # 'mp_env_api @ git+ssh://git@github.com/ALRhub/motion_primitive_env_api.git', 'mujoco_py' ],