From 0a1e55d97bc1f5bc6fd41e3273e01f94f8d6e460 Mon Sep 17 00:00:00 2001 From: ottofabian Date: Fri, 25 Jun 2021 15:51:06 +0200 Subject: [PATCH] improved docs and modularity of env helpers --- alr_envs/utils/make_env_helpers.py | 125 +++++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 17 deletions(-) diff --git a/alr_envs/utils/make_env_helpers.py b/alr_envs/utils/make_env_helpers.py index 29ddb9d..246cd7a 100644 --- a/alr_envs/utils/make_env_helpers.py +++ b/alr_envs/utils/make_env_helpers.py @@ -1,11 +1,29 @@ -from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper -from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper +from typing import Iterable, List, Type + import gym -from gym.vector.utils import write_to_shared_memory -import sys + +from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper +from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper +from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper -def make_env(env_id, seed, rank): +def make_env(env_id: str, seed: int, rank: int = 0): + """ + Create a new gym environment with given seed. + The rank is added to the seed and can be used for example when using vector environments. + E.g. [make_env("my_env_name-v0", 123, i) for i in range(8)] creates a list of 8 environments + with seeds 123 through 130. + Hence, testing environments should be seeded with a value which is offset by the number of training environments. + Here e.g. [make_env("my_env_name-v0", 123 + 8, i) for i in range(5)] for 5 testing environmetns + + Args: + env_id: name of the environment + seed: seed for deterministic behaviour + rank: environment rank for deterministic over multiple seeds behaviour + + Returns: + + """ env = gym.make(env_id) env.seed(seed + rank) return lambda: env @@ -17,17 +35,90 @@ def make_contextual_env(env_id, context, seed, rank): return lambda: env -def make_dmp_env(**kwargs): - name = kwargs.pop("name") - _env = gym.make(name) - for wrapper in kwargs.pop("wrappers"): - _env = wrapper(_env) - return DmpWrapper(_env, **kwargs.get("mp_kwargs")) +def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]]): + """ + Helper function for creating a wrapped gym environment using MPs. + It adds all provided wrappers to the specified environment and verifies at least one MPEnvWrapper is + provided to expose the interface for MPs. + + Args: + env_id: name of the environment + wrappers: list of wrappers (at least an MPEnvWrapper), + + Returns: gym environment with all specified wrappers applied + + """ + _env = gym.make(env_id) + + assert any(issubclass(w, MPEnvWrapper) for w in wrappers) + for w in wrappers: + _env = w(_env) + + return _env -def make_detpmp_env(**kwargs): - name = kwargs.pop("name") - _env = gym.make(name) - for wrapper in kwargs.pop("wrappers"): - _env = wrapper(_env) - return DetPMPWrapper(_env, **kwargs.get("mp_kwargs")) +def make_dmp_env(env_id: str, wrappers: Iterable, **mp_kwargs): + """ + This can also be used standalone for manually building a custom DMP environment. + Args: + env_id: base_env_name, + wrappers: list of wrappers (at least an MPEnvWrapper), + mp_kwargs: dict of at least {num_dof: int, num_basis: int} for DMP + + Returns: DMP wrapped gym env + + """ + + _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers) + return DmpWrapper(_env, **mp_kwargs) + + +def make_detpmp_env(env_id: str, wrappers: Iterable, **mp_kwargs): + """ + This can also be used standalone for manually building a custom Det ProMP environment. + Args: + env_id: base_env_name, + wrappers: list of wrappers (at least an MPEnvWrapper), + mp_kwargs: dict of at least {num_dof: int, num_basis: int, width: int} + + Returns: DMP wrapped gym env + + """ + + _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers) + return DetPMPWrapper(_env, **mp_kwargs) + + +def make_dmp_env_helper(**kwargs): + """ + Helper function for registering a DMP gym environments. + Args: + **kwargs: expects at least the following: + { + "name": base_env_name, + "wrappers": list of wrappers (at least an MPEnvWrapper), + "mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP + } + + Returns: DMP wrapped gym env + + """ + return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), **kwargs.get("mp_kwargs")) + + +def make_detpmp_env_helper(**kwargs): + """ + Helper function for registering ProMP gym environments. + This can also be used standalone for manually building a custom ProMP environment. + Args: + **kwargs: expects at least the following: + { + "name": base_env_name, + "wrappers": list of wrappers (at least an MPEnvWrapper), + "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int} + } + + Returns: DMP wrapped gym env + + """ + return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), **kwargs.get("mp_kwargs"))