finalized examples and added seed control
This commit is contained in:
parent
3b215cd877
commit
7c04b25eec
@ -4,6 +4,7 @@ from gym.envs.registration import register
|
|||||||
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
|
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
|
||||||
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
|
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
|
||||||
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
|
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
|
||||||
|
from alr_envs.dmc.Ball_in_the_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
|
||||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_positional_wrapper import BallInACupPositionalWrapper
|
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_positional_wrapper import BallInACupPositionalWrapper
|
||||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||||
@ -518,6 +519,48 @@ register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## DMC
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=f'dmc_ball_in_cup_dmp-v0',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||||
|
# max_episode_steps=1,
|
||||||
|
kwargs={
|
||||||
|
"name": f"ball_in_cup-catch",
|
||||||
|
"wrappers": [DMCBallInCupMPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 2,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"learn_goal": True,
|
||||||
|
"alpha_phase": 2,
|
||||||
|
"bandwidth_factor": 2,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 50,
|
||||||
|
"goal_scale": 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
register(
|
||||||
|
id=f'dmc_ball_in_cup_detpmp-v0',
|
||||||
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
|
kwargs={
|
||||||
|
"name": f"ball_in_cup-catch",
|
||||||
|
"wrappers": [DMCBallInCupMPWrapper],
|
||||||
|
"mp_kwargs": {
|
||||||
|
"num_dof": 2,
|
||||||
|
"num_basis": 5,
|
||||||
|
"duration": 2,
|
||||||
|
"width": 0.025,
|
||||||
|
"policy_type": "velocity",
|
||||||
|
"weights_scale": 0.2,
|
||||||
|
"zero_start": True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# BBO functions
|
# BBO functions
|
||||||
|
|
||||||
for dim in [5, 10, 25, 50, 100]:
|
for dim in [5, 10, 25, 50, 100]:
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class BallInCupMPWrapper(MPEnvWrapper):
|
class DMCBallInCupMPWrapper(MPEnvWrapper):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def active_obs(self):
|
def active_obs(self):
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
from alr_envs.dmc.Ball_in_the_cup_mp_wrapper import BallInCupMPWrapper
|
from alr_envs.dmc.Ball_in_the_cup_mp_wrapper import DMCBallInCupMPWrapper
|
||||||
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
|
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
|
||||||
|
|
||||||
|
|
||||||
def example_dmc(env_name="fish-swim", seed=1):
|
def example_dmc(env_name="fish-swim", seed=1, iterations=1000):
|
||||||
env = make_env(env_name, seed)
|
env = make_env(env_name, seed)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
print(obs)
|
||||||
|
|
||||||
# number of samples/full trajectories (multiple environment steps)
|
# number of samples (multiple environment steps)
|
||||||
for i in range(2000):
|
for i in range(10):
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
obs, reward, done, info = env.step(ac)
|
obs, reward, done, info = env.step(ac)
|
||||||
rewards += reward
|
rewards += reward
|
||||||
@ -37,7 +38,7 @@ def example_custom_dmc_and_mp(seed=1):
|
|||||||
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
|
||||||
# You can also add other gym.Wrappers in case they are needed.
|
# You can also add other gym.Wrappers in case they are needed.
|
||||||
# wrappers = [HoleReacherMPWrapper]
|
# wrappers = [HoleReacherMPWrapper]
|
||||||
wrappers = [BallInCupMPWrapper]
|
wrappers = [DMCBallInCupMPWrapper]
|
||||||
mp_kwargs = {
|
mp_kwargs = {
|
||||||
"num_dof": 2, # env.start_pos
|
"num_dof": 2, # env.start_pos
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
@ -69,5 +70,14 @@ def example_custom_dmc_and_mp(seed=1):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
example_dmc()
|
# Disclaimer: DMC environments require the seed to be specified in the beginning.
|
||||||
|
# Adjusting it afterwards with env.seed() is not recommended as it does not affect the underlying physics.
|
||||||
|
|
||||||
|
# Standard DMC task
|
||||||
|
example_dmc("fish_swim", seed=10, iterations=1000)
|
||||||
|
|
||||||
|
# Gym + DMC hybrid task provided in the MP framework
|
||||||
|
example_dmc("dmc_ball_in_cup_dmp-v0", seed=10, iterations=10)
|
||||||
|
|
||||||
|
# Custom DMC task
|
||||||
example_custom_dmc_and_mp()
|
example_custom_dmc_and_mp()
|
||||||
|
@ -15,6 +15,22 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1):
|
|||||||
# While in this case gym.make() is possible to use as well, we recommend our custom make env function.
|
# While in this case gym.make() is possible to use as well, we recommend our custom make env function.
|
||||||
# First, it already takes care of seeding and second enables the use of DMC tasks within the gym interface.
|
# First, it already takes care of seeding and second enables the use of DMC tasks within the gym interface.
|
||||||
env = make_env(env_name, seed)
|
env = make_env(env_name, seed)
|
||||||
|
|
||||||
|
# Changing the mp_kwargs is possible by providing them to gym.
|
||||||
|
# E.g. here by providing way too many basis functions
|
||||||
|
# mp_kwargs = {
|
||||||
|
# "num_dof": 5,
|
||||||
|
# "num_basis": 1000,
|
||||||
|
# "duration": 2,
|
||||||
|
# "learn_goal": True,
|
||||||
|
# "alpha_phase": 2,
|
||||||
|
# "bandwidth_factor": 2,
|
||||||
|
# "policy_type": "velocity",
|
||||||
|
# "weights_scale": 50,
|
||||||
|
# "goal_scale": 0.1
|
||||||
|
# }
|
||||||
|
# env = make_env(env_name, seed, mp_kwargs=mp_kwargs)
|
||||||
|
|
||||||
rewards = 0
|
rewards = 0
|
||||||
# env.render(mode=None)
|
# env.render(mode=None)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
@ -40,8 +56,9 @@ def example_mp(env_name="alr_envs:HoleReacherDMP-v1", seed=1):
|
|||||||
def example_custom_mp(seed=1):
|
def example_custom_mp(seed=1):
|
||||||
"""
|
"""
|
||||||
Example for running a custom motion primitive based environments.
|
Example for running a custom motion primitive based environments.
|
||||||
Our already registered environments follow the same structure, but do not directly allow for modifications.
|
Our already registered environments follow the same structure.
|
||||||
Hence, this also allows to adjust hyperparameters of the motion primitives more easily.
|
Hence, this also allows to adjust hyperparameters of the motion primitives.
|
||||||
|
Yet, we recommend the method above if you are just interested in changing those parameters for existing tasks.
|
||||||
We appreciate PRs for custom environments (especially MP wrappers of existing tasks)
|
We appreciate PRs for custom environments (especially MP wrappers of existing tasks)
|
||||||
for our repo: https://github.com/ALRhub/alr_envs/
|
for our repo: https://github.com/ALRhub/alr_envs/
|
||||||
Args:
|
Args:
|
||||||
|
@ -35,7 +35,7 @@ def make(
|
|||||||
|
|
||||||
if env_id not in gym.envs.registry.env_specs:
|
if env_id not in gym.envs.registry.env_specs:
|
||||||
task_kwargs = {}
|
task_kwargs = {}
|
||||||
if seed is not None:
|
# if seed is not None:
|
||||||
task_kwargs['random'] = seed
|
task_kwargs['random'] = seed
|
||||||
if time_limit is not None:
|
if time_limit is not None:
|
||||||
task_kwargs['time_limit'] = time_limit
|
task_kwargs['time_limit'] = time_limit
|
||||||
|
@ -42,6 +42,10 @@ def make_env(env_id: str, seed, **kwargs):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Add seed to kwargs in case it is a predefined dmc environment.
|
||||||
|
if env_id.startswith("dmc"):
|
||||||
|
kwargs.update({"seed": seed})
|
||||||
|
|
||||||
# Gym
|
# Gym
|
||||||
env = gym.make(env_id, **kwargs)
|
env = gym.make(env_id, **kwargs)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
@ -125,7 +129,9 @@ def make_dmp_env_helper(**kwargs):
|
|||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), **kwargs.get("mp_kwargs"))
|
seed = kwargs.get("seed", None)
|
||||||
|
return make_dmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
||||||
|
**kwargs.get("mp_kwargs"))
|
||||||
|
|
||||||
|
|
||||||
def make_detpmp_env_helper(**kwargs):
|
def make_detpmp_env_helper(**kwargs):
|
||||||
@ -143,7 +149,9 @@ def make_detpmp_env_helper(**kwargs):
|
|||||||
Returns: DMP wrapped gym env
|
Returns: DMP wrapped gym env
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), **kwargs.get("mp_kwargs"))
|
seed = kwargs.get("seed", None)
|
||||||
|
return make_detpmp_env(env_id=kwargs.pop("name"), wrappers=kwargs.pop("wrappers"), seed=seed,
|
||||||
|
**kwargs.get("mp_kwargs"))
|
||||||
|
|
||||||
|
|
||||||
def make_contextual_env(env_id, context, seed, rank):
|
def make_contextual_env(env_id, context, seed, rank):
|
||||||
|
Loading…
Reference in New Issue
Block a user