improved docs and modularity of env helpers
This commit is contained in:
parent
29b8c3a6c7
commit
0a1e55d97b
@ -1,11 +1,29 @@
|
||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||
from typing import Iterable, List, Type
|
||||
|
||||
import gym
|
||||
from gym.vector.utils import write_to_shared_memory
|
||||
import sys
|
||||
|
||||
from mp_env_api.envs.mp_env_wrapper import MPEnvWrapper
|
||||
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
||||
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
||||
|
||||
|
||||
def make_env(env_id, seed, rank):
|
||||
def make_env(env_id: str, seed: int, rank: int = 0):
    """
    Build a seeded gym environment and hand it back wrapped in a zero-argument callable.

    The environment is created and seeded right away with `seed + rank`; the returned callable
    simply returns that already-created instance. The rank offset is useful for vector
    environments, e.g. [make_env("my_env_name-v0", 123, i) for i in range(8)] yields 8 environments
    seeded with 123 through 130.
    Testing environments should therefore use a base seed offset by the number of training
    environments, e.g. [make_env("my_env_name-v0", 123 + 8, i) for i in range(5)] for 5 testing
    environments.

    Args:
        env_id: name of the environment
        seed: base seed for deterministic behaviour
        rank: offset added to the base seed so that each environment gets its own seed

    Returns: callable without arguments that returns the created and seeded environment

    """
    effective_seed = seed + rank

    created_env = gym.make(env_id)
    created_env.seed(effective_seed)

    return lambda: created_env
|
||||
@ -17,17 +35,90 @@ def make_contextual_env(env_id, context, seed, rank):
|
||||
return lambda: env
|
||||
|
||||
|
||||
def make_dmp_env(**kwargs):
|
||||
name = kwargs.pop("name")
|
||||
_env = gym.make(name)
|
||||
for wrapper in kwargs.pop("wrappers"):
|
||||
_env = wrapper(_env)
|
||||
return DmpWrapper(_env, **kwargs.get("mp_kwargs"))
|
||||
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]]):
    """
    Helper function for creating a wrapped gym environment using MPs.
    It applies all provided wrappers, in the given order, to the specified environment and verifies
    that at least one of them is an MPEnvWrapper subclass, which exposes the interface for MPs.

    Args:
        env_id: name of the environment
        wrappers: iterable of wrapper classes (at least one MPEnvWrapper subclass); it may be a
            one-shot iterator, it is materialised internally

    Returns: gym environment with all specified wrappers applied

    Raises:
        AssertionError: if none of the wrappers is a subclass of MPEnvWrapper

    """
    # Materialise once: the wrappers are traversed twice (validation, then application). With a
    # one-shot iterator the short-circuiting `any` would otherwise swallow the leading wrappers
    # and they would never be applied.
    wrappers = tuple(wrappers)

    _env = gym.make(env_id)

    assert any(issubclass(w, MPEnvWrapper) for w in wrappers), \
        "At least one of the provided wrappers has to be an MPEnvWrapper."
    for w in wrappers:
        _env = w(_env)

    return _env
|
||||
|
||||
|
||||
def make_detpmp_env(**kwargs):
|
||||
name = kwargs.pop("name")
|
||||
_env = gym.make(name)
|
||||
for wrapper in kwargs.pop("wrappers"):
|
||||
_env = wrapper(_env)
|
||||
return DetPMPWrapper(_env, **kwargs.get("mp_kwargs"))
|
||||
def make_dmp_env(env_id: str, wrappers: Iterable, **mp_kwargs):
    """
    Build a DMP environment on top of a wrapped base environment.
    Besides being used by the registration helper, this can be called directly for manually
    building a custom DMP environment.

    Args:
        env_id: base_env_name,
        wrappers: list of wrappers (at least an MPEnvWrapper),
        mp_kwargs: keyword arguments forwarded to DmpWrapper, at least {num_dof: int, num_basis: int}

    Returns: DMP wrapped gym env

    """
    # First stack the plain wrappers, then put the DMP wrapper on top.
    wrapped_base = _make_wrapped_env(env_id=env_id, wrappers=wrappers)
    dmp_env = DmpWrapper(wrapped_base, **mp_kwargs)
    return dmp_env
|
||||
|
||||
|
||||
def make_detpmp_env(env_id: str, wrappers: Iterable, **mp_kwargs):
    """
    Build a Det ProMP environment on top of a wrapped base environment.
    Besides being used by the registration helper, this can be called directly for manually
    building a custom Det ProMP environment.

    Args:
        env_id: base_env_name,
        wrappers: list of wrappers (at least an MPEnvWrapper),
        mp_kwargs: keyword arguments forwarded to DetPMPWrapper,
            at least {num_dof: int, num_basis: int, width: int}

    Returns: Det ProMP wrapped gym env

    """
    # First stack the plain wrappers, then put the Det ProMP wrapper on top.
    wrapped_base = _make_wrapped_env(env_id=env_id, wrappers=wrappers)
    detpmp_env = DetPMPWrapper(wrapped_base, **mp_kwargs)
    return detpmp_env
|
||||
|
||||
|
||||
def make_dmp_env_helper(**kwargs):
    """
    Helper function for registering DMP gym environments.
    It unpacks the keyword dictionary and forwards its entries to make_dmp_env.

    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP
        }

    Returns: DMP wrapped gym env

    Raises:
        KeyError: if "name" or "wrappers" is missing

    """
    # A missing (or None) "mp_kwargs" must not be expanded with `**`: that fails with an opaque
    # "argument after ** must be a mapping" TypeError. Fall back to an empty mapping instead, so
    # any complaint about missing arguments comes from the wrapper being constructed.
    mp_kwargs = kwargs.get("mp_kwargs") or {}
    return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), **mp_kwargs)
|
||||
|
||||
|
||||
def make_detpmp_env_helper(**kwargs):
    """
    Helper function for registering Det ProMP gym environments.
    It unpacks the keyword dictionary and forwards its entries to make_detpmp_env.

    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
        }

    Returns: Det ProMP wrapped gym env

    Raises:
        KeyError: if "name" or "wrappers" is missing

    """
    # A missing (or None) "mp_kwargs" must not be expanded with `**`: that fails with an opaque
    # "argument after ** must be a mapping" TypeError. Fall back to an empty mapping instead, so
    # any complaint about missing arguments comes from the wrapper being constructed.
    mp_kwargs = kwargs.get("mp_kwargs") or {}
    return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), **mp_kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user