2021-06-28 17:25:53 +02:00
|
|
|
import logging
|
2021-06-25 15:51:06 +02:00
|
|
|
from typing import Iterable, List, Type
|
|
|
|
|
2021-04-21 10:45:34 +02:00
|
|
|
import gym
|
|
|
|
|
2021-06-28 17:25:53 +02:00
|
|
|
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
2021-06-25 15:51:06 +02:00
|
|
|
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
|
|
|
|
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
|
|
|
|
|
|
|
|
|
2021-06-28 17:25:53 +02:00
|
|
|
def make_env_rank(env_id: str, seed: int, rank: int = 0):
    """
    TODO: Do we need this?
    Build a zero-argument callable that creates a new, seeded gym environment.

    The environment is not created here; it is created each time the returned callable is invoked.
    The effective seed is ``seed + rank``, which is convenient for vector environments:
    ``[make_env_rank("my_env_name-v0", 123, i) for i in range(8)]`` yields 8 factories whose
    environments use the seeds 123 through 130.
    Testing environments should therefore start from a seed that is offset by the number of training
    environments, e.g. ``[make_env_rank("my_env_name-v0", 123 + 8, i) for i in range(5)]`` for
    5 testing environments.

    Args:
        env_id: name of the environment
        seed: base seed for deterministic behaviour
        rank: offset added to the base seed, typically the index of the environment

    Returns: callable without arguments that returns ``make_env(env_id, seed + rank)``

    """

    def _create_env():
        # make_env is resolved and called lazily, once per invocation of the factory.
        return make_env(env_id, seed + rank)

    return _create_env
|
2021-04-21 10:45:34 +02:00
|
|
|
|
|
|
|
|
2021-06-28 17:25:53 +02:00
|
|
|
def make_env(env_id: str, seed, **kwargs):
    """
    Create a seeded environment with the gym API from an environment id.

    The id is first handed to ``gym.make`` and the resulting environment is seeded via ``env.seed``.
    If either of these two steps raises a ``gym.error.Error``, creation falls back to
    ``alr_envs.utils.make``, which is intended for DeepMind Control Suite tasks whose domain name and
    task name are expected to be separated by "-".

    Args:
        env_id: gym name or env_id of the form "domain_name-task_name" for DMC tasks
        seed: seed passed to ``env.seed`` (gym) or as ``seed=`` to the fallback constructor
        **kwargs: Additional kwargs for the constructor such as pixel observations, etc.

    Returns: Gym environment

    """
    try:
        # Regular gym environment
        gym_env = gym.make(env_id, **kwargs)
        gym_env.seed(seed)
        return gym_env
    except gym.error.Error:
        # Not available through gym directly (e.g. DMC), imported lazily only when needed
        from alr_envs.utils import make as _fallback_make

        return _fallback_make(env_id, seed=seed, **kwargs)
|
2021-04-21 10:45:34 +02:00
|
|
|
|
|
|
|
|
2021-06-28 17:25:53 +02:00
|
|
|
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs):
    """
    Helper function for creating a wrapped gym environment using MPs.
    It applies all provided wrappers, in iteration order, to the specified environment and verifies
    that at least one MPEnvWrapper subclass is provided to expose the interface for MPs.

    Args:
        env_id: name of the environment
        wrappers: iterable of wrapper classes (at least one MPEnvWrapper subclass); one-shot
            iterators such as generators are supported
        seed: seed of environment
        **kwargs: additional keyword arguments forwarded to ``make_env``

    Returns: gym environment with all specified wrappers applied

    Raises:
        AssertionError: if none of the wrappers is a subclass of MPEnvWrapper

    """
    # `wrappers` is only promised to be an Iterable but is traversed twice below. Materialize it
    # once so the check cannot swallow the leading elements of a one-shot iterator.
    wrapper_classes = list(wrappers)

    # Validate before building the environment so that a failing check does not leave behind an
    # environment that was created but is never returned.
    assert any(issubclass(w, MPEnvWrapper) for w in wrapper_classes), \
        "At least an MPEnvWrapper is required in order to leverage motion primitive environments."

    _env = make_env(env_id, seed, **kwargs)
    for w in wrapper_classes:
        _env = w(_env)

    return _env
|
|
|
|
|
|
|
|
|
2021-06-28 17:25:53 +02:00
|
|
|
def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, **mp_kwargs):
    """
    Create a base environment, apply the given wrappers and put a DmpWrapper on top.
    This can also be used standalone for manually building a custom DMP environment.

    Args:
        env_id: base_env_name,
        wrappers: list of wrappers (at least an MPEnvWrapper),
        seed: seed of environment
        mp_kwargs: keyword arguments passed on to DmpWrapper,
            at least {num_dof: int, num_basis: int} for DMP

    Returns: DMP wrapped gym env

    """
    wrapped_base_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed)
    dmp_env = DmpWrapper(wrapped_base_env, **mp_kwargs)
    return dmp_env
|
|
|
|
|
|
|
|
|
2021-06-28 17:25:53 +02:00
|
|
|
def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, **mp_kwargs):
    """
    Create a base environment, apply the given wrappers and put a DetPMPWrapper on top.
    This can also be used standalone for manually building a custom Det ProMP environment.

    Args:
        env_id: base_env_name,
        wrappers: list of wrappers (at least an MPEnvWrapper),
        seed: seed of environment
        mp_kwargs: keyword arguments passed on to DetPMPWrapper,
            at least {num_dof: int, num_basis: int, width: int}

    Returns: DetPMP wrapped gym env

    """
    wrapped_base_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed)
    detpmp_env = DetPMPWrapper(wrapped_base_env, **mp_kwargs)
    return detpmp_env
|
|
|
|
|
|
|
|
|
|
|
|
def make_dmp_env_helper(**kwargs):
    """
    Helper function for registering a DMP gym environments.
    It unpacks the registration kwargs and delegates to ``make_dmp_env``.

    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict of at least {num_dof: int, num_basis: int} for DMP
        }
        An optional "seed" entry is forwarded as the seed of the base environment, unless
        "mp_kwargs" already contains its own "seed". A missing or None "mp_kwargs" is treated as
        an empty dict.

    Returns: DMP wrapped gym env

    Raises:
        KeyError: if "name" or "wrappers" is missing

    """
    env_id = kwargs.pop("name")
    wrappers = kwargs.pop("wrappers")

    # Work on a copy so the dict handed in by the caller is never modified.
    forwarded = dict(kwargs.get("mp_kwargs") or {})
    if "seed" in kwargs:
        # Previously a top-level seed was silently dropped and the default seed was used instead.
        # A seed inside mp_kwargs keeps precedence, as it was already forwarded before.
        forwarded.setdefault("seed", kwargs["seed"])

    return make_dmp_env(env_id=env_id, wrappers=wrappers, **forwarded)
|
|
|
|
|
|
|
|
|
|
|
|
def make_detpmp_env_helper(**kwargs):
    """
    Helper function for registering ProMP gym environments.
    This can also be used standalone for manually building a custom ProMP environment.
    It unpacks the registration kwargs and delegates to ``make_detpmp_env``.

    Args:
        **kwargs: expects at least the following:
        {
        "name": base_env_name,
        "wrappers": list of wrappers (at least an MPEnvWrapper),
        "mp_kwargs": dict of at least {num_dof: int, num_basis: int, width: int}
        }
        An optional "seed" entry is forwarded as the seed of the base environment, unless
        "mp_kwargs" already contains its own "seed". A missing or None "mp_kwargs" is treated as
        an empty dict.

    Returns: DetPMP wrapped gym env

    Raises:
        KeyError: if "name" or "wrappers" is missing

    """
    env_id = kwargs.pop("name")
    wrappers = kwargs.pop("wrappers")

    # Work on a copy so the dict handed in by the caller is never modified.
    forwarded = dict(kwargs.get("mp_kwargs") or {})
    if "seed" in kwargs:
        # Previously a top-level seed was silently dropped and the default seed was used instead.
        # A seed inside mp_kwargs keeps precedence, as it was already forwarded before.
        forwarded.setdefault("seed", kwargs["seed"])

    return make_detpmp_env(env_id=env_id, wrappers=wrappers, **forwarded)
|
2021-06-28 17:25:53 +02:00
|
|
|
|
|
|
|
|
|
|
|
def make_contextual_env(env_id, context, seed, rank):
    """
    Create a seeded contextual gym environment and return a callable that yields it.

    The environment is created and seeded right away, not when the returned callable is invoked.
    Every invocation of the returned callable hands back this one environment instance.

    Args:
        env_id: name of the environment
        context: passed as ``context`` keyword argument to ``gym.make``
        seed: base seed of the environment
        rank: offset added to the base seed

    Returns: callable without arguments that returns the already created environment

    """
    contextual_env = gym.make(env_id, context=context)
    contextual_env.seed(seed + rank)

    def _get_env():
        return contextual_env

    return _get_env
|