fixed OpenAI fetch tasks; added nicer imports

This commit is contained in:
ottofabian 2021-07-30 11:59:02 +02:00
parent f5fcbf7f54
commit a11965827d
41 changed files with 316 additions and 159 deletions

View File

@ -1,15 +1,12 @@
import numpy as np
from gym.envs.registration import register from gym.envs.registration import register
from gym.wrappers import FlattenObservation
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper from alr_envs import classic_control, dmc, open_ai
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper from alr_envs.utils.make_env_helpers import make_dmp_env
from alr_envs.dmc.manipulation.reach.reach_mp_wrapper import DMCReachSiteMPWrapper from alr_envs.utils.make_env_helpers import make_detpmp_env
from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper from alr_envs.utils.make_env_helpers import make_env
from alr_envs.dmc.suite.cartpole.cartpole_mp_wrapper import DMCCartpoleMPWrapper, DMCCartpoleThreePolesMPWrapper, \ from alr_envs.utils.make_env_helpers import make_env_rank
DMCCartpoleTwoPolesMPWrapper
from alr_envs.open_ai import reacher_v2, continuous_mountain_car, fetch
from alr_envs.dmc.suite.reacher.reacher_mp_wrapper import DMCReacherMPWrapper
# Mujoco # Mujoco
@ -206,7 +203,7 @@ for v in versions:
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": f"alr_envs:{v}", "name": f"alr_envs:{v}",
"wrappers": [SimpleReacherMPWrapper], "wrappers": [classic_control.simple_reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2 if "long" not in v.lower() else 5, "num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5, "num_basis": 5,
@ -225,7 +222,7 @@ register(
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": "alr_envs:ViaPointReacher-v0", "name": "alr_envs:ViaPointReacher-v0",
"wrappers": [ViaPointReacherMPWrapper], "wrappers": [classic_control.viapoint_reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
@ -238,6 +235,25 @@ register(
} }
) )
register(
id='ViaPointReacherDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:ViaPointReacher-v0",
"wrappers": [classic_control.viapoint_reacher.MPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
## Hole Reacher ## Hole Reacher
versions = ["v0", "v1", "v2"] versions = ["v0", "v1", "v2"]
for v in versions: for v in versions:
@ -247,7 +263,7 @@ for v in versions:
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": f"alr_envs:HoleReacher-{v}", "name": f"alr_envs:HoleReacher-{v}",
"wrappers": [HoleReacherMPWrapper], "wrappers": [classic_control.hole_reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
@ -267,7 +283,7 @@ for v in versions:
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": f"alr_envs:HoleReacher-{v}", "name": f"alr_envs:HoleReacher-{v}",
"wrappers": [HoleReacherMPWrapper], "wrappers": [classic_control.hole_reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
@ -283,11 +299,6 @@ for v in versions:
## Deep Mind Control Suite (DMC) ## Deep Mind Control Suite (DMC)
### Suite ### Suite
# tasks = ["ball_in_cup-catch", "reacher-easy", "reacher-hard", "cartpole-balance", "cartpole-balance_sparse",
# "cartpole-swingup", "cartpole-swingup_sparse", "cartpole-two_poles", "cartpole-three_poles"]
# wrappers = [DMCBallInCupMPWrapper, DMCReacherMPWrapper, DMCReacherMPWrapper, DMCCartpoleMPWrapper,
# partial(DMCCartpoleMPWrapper)]
# for t, w in zip(tasks, wrappers):
register( register(
id=f'dmc_ball_in_cup-catch_dmp-v0', id=f'dmc_ball_in_cup-catch_dmp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
@ -296,7 +307,7 @@ register(
"name": f"ball_in_cup-catch", "name": f"ball_in_cup-catch",
"time_limit": 1, "time_limit": 1,
"episode_length": 50, "episode_length": 50,
"wrappers": [DMCBallInCupMPWrapper], "wrappers": [dmc.suite.ball_in_cup.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -322,7 +333,7 @@ register(
"name": f"ball_in_cup-catch", "name": f"ball_in_cup-catch",
"time_limit": 1, "time_limit": 1,
"episode_length": 50, "episode_length": 50,
"wrappers": [DMCBallInCupMPWrapper], "wrappers": [dmc.suite.ball_in_cup.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -339,7 +350,7 @@ register(
} }
) )
# TODO tune gains and episode length for all below # TODO tune episode length for all below
register( register(
id=f'dmc_reacher-easy_dmp-v0', id=f'dmc_reacher-easy_dmp-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
@ -348,7 +359,7 @@ register(
"name": f"reacher-easy", "name": f"reacher-easy",
"time_limit": 1, "time_limit": 1,
"episode_length": 50, "episode_length": 50,
"wrappers": [DMCReacherMPWrapper], "wrappers": [dmc.suite.reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -374,7 +385,7 @@ register(
"name": f"reacher-easy", "name": f"reacher-easy",
"time_limit": 1, "time_limit": 1,
"episode_length": 50, "episode_length": 50,
"wrappers": [DMCReacherMPWrapper], "wrappers": [dmc.suite.reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -399,7 +410,7 @@ register(
"name": f"reacher-hard", "name": f"reacher-hard",
"time_limit": 1, "time_limit": 1,
"episode_length": 50, "episode_length": 50,
"wrappers": [DMCReacherMPWrapper], "wrappers": [dmc.suite.reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -425,7 +436,7 @@ register(
"name": f"reacher-hard", "name": f"reacher-hard",
"time_limit": 1, "time_limit": 1,
"episode_length": 50, "episode_length": 50,
"wrappers": [DMCReacherMPWrapper], "wrappers": [dmc.suite.reacher.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -448,8 +459,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-balance", "name": f"cartpole-balance",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -461,8 +473,8 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -474,8 +486,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-balance", "name": f"cartpole-balance",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -485,8 +498,8 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -498,8 +511,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-balance_sparse", "name": f"cartpole-balance_sparse",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -511,8 +525,8 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -524,8 +538,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-balance_sparse", "name": f"cartpole-balance_sparse",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -535,8 +550,8 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -549,8 +564,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-swingup", "name": f"cartpole-swingup",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -562,8 +578,8 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -575,8 +591,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-swingup", "name": f"cartpole-swingup",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -586,8 +603,8 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -599,8 +616,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-swingup_sparse", "name": f"cartpole-swingup_sparse",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -612,8 +630,8 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -625,8 +643,9 @@ register(
kwargs={ kwargs={
"name": f"cartpole-swingup_sparse", "name": f"cartpole-swingup_sparse",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
"wrappers": [DMCCartpoleMPWrapper], "wrappers": [dmc.suite.cartpole.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -636,8 +655,8 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -649,9 +668,10 @@ register(
kwargs={ kwargs={
"name": f"cartpole-two_poles", "name": f"cartpole-two_poles",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)], # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
"wrappers": [DMCCartpoleTwoPolesMPWrapper], "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -663,8 +683,8 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -676,9 +696,10 @@ register(
kwargs={ kwargs={
"name": f"cartpole-two_poles", "name": f"cartpole-two_poles",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)], # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=2)],
"wrappers": [DMCCartpoleTwoPolesMPWrapper], "wrappers": [dmc.suite.cartpole.TwoPolesMPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -688,8 +709,8 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -701,9 +722,10 @@ register(
kwargs={ kwargs={
"name": f"cartpole-three_poles", "name": f"cartpole-three_poles",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)], # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
"wrappers": [DMCCartpoleThreePolesMPWrapper], "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -715,8 +737,8 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -728,9 +750,10 @@ register(
kwargs={ kwargs={
"name": f"cartpole-three_poles", "name": f"cartpole-three_poles",
# "time_limit": 1, # "time_limit": 1,
"camera_id": 0,
"episode_length": 1000, "episode_length": 1000,
# "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)], # "wrappers": [partial(DMCCartpoleMPWrapper, n_poles=3)],
"wrappers": [DMCCartpoleThreePolesMPWrapper], "wrappers": [dmc.suite.cartpole.ThreePolesMPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 5, "num_basis": 5,
@ -740,8 +763,8 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"policy_kwargs": { "policy_kwargs": {
"p_gains": 50, "p_gains": 10,
"d_gains": 1 "d_gains": 10
} }
} }
} }
@ -757,7 +780,7 @@ register(
"name": f"manipulation-reach_site_features", "name": f"manipulation-reach_site_features",
# "time_limit": 1, # "time_limit": 1,
"episode_length": 250, "episode_length": 250,
"wrappers": [DMCReachSiteMPWrapper], "wrappers": [dmc.manipulation.reach.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 9, "num_dof": 9,
"num_basis": 5, "num_basis": 5,
@ -779,7 +802,7 @@ register(
"name": f"manipulation-reach_site_features", "name": f"manipulation-reach_site_features",
# "time_limit": 1, # "time_limit": 1,
"episode_length": 250, "episode_length": 250,
"wrappers": [DMCReachSiteMPWrapper], "wrappers": [dmc.manipulation.reach.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 9, "num_dof": 9,
"num_basis": 5, "num_basis": 5,
@ -798,7 +821,7 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.classic_control:MountainCarContinuous-v0", "name": "gym.envs.classic_control:MountainCarContinuous-v0",
"wrappers": [continuous_mountain_car.MPWrapper], "wrappers": [open_ai.classic_control.continuous_mountain_car.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 1, "num_dof": 1,
"num_basis": 4, "num_basis": 4,
@ -819,7 +842,7 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.mujoco:Reacher-v2", "name": "gym.envs.mujoco:Reacher-v2",
"wrappers": [reacher_v2.MPWrapper], "wrappers": [open_ai.mujoco.reacher_v2.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 2, "num_dof": 2,
"num_basis": 6, "num_basis": 6,
@ -840,7 +863,7 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchSlideDense-v1", "name": "gym.envs.robotics:FetchSlideDense-v1",
"wrappers": [fetch.MPWrapper], "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,
@ -857,7 +880,7 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "gym.envs.robotics:FetchReachDense-v1", "name": "gym.envs.robotics:FetchReachDense-v1",
"wrappers": [fetch.MPWrapper], "wrappers": [FlattenObservation, open_ai.robotics.fetch.MPWrapper],
"mp_kwargs": { "mp_kwargs": {
"num_dof": 4, "num_dof": 4,
"num_basis": 5, "num_basis": 5,

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class HoleReacherMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):
return np.hstack([ return np.hstack([

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class SimpleReacherMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):
return np.hstack([ return np.hstack([

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class ViaPointReacherMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):
return np.hstack([ return np.hstack([

View File

@ -0,0 +1,5 @@
# from alr_envs.dmc import manipulation, suite
from alr_envs.dmc.suite import ball_in_cup
from alr_envs.dmc.suite import reacher
from alr_envs.dmc.suite import cartpole
from alr_envs.dmc.manipulation import reach

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class DMCReachSiteMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class DMCBallInCupMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):

View File

@ -0,0 +1,3 @@
from .mp_wrapper import MPWrapper
from .mp_wrapper import TwoPolesMPWrapper
from .mp_wrapper import ThreePolesMPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class DMCCartpoleMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
def __init__(self, env, n_poles: int = 1): def __init__(self, env, n_poles: int = 1):
self.n_poles = n_poles self.n_poles = n_poles
@ -39,13 +39,13 @@ class DMCCartpoleMPWrapper(MPEnvWrapper):
return self.env.dt return self.env.dt
class DMCCartpoleTwoPolesMPWrapper(DMCCartpoleMPWrapper): class TwoPolesMPWrapper(MPWrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env, n_poles=2) super().__init__(env, n_poles=2)
class DMCCartpoleThreePolesMPWrapper(DMCCartpoleMPWrapper): class ThreePolesMPWrapper(MPWrapper):
def __init__(self, env): def __init__(self, env):
super().__init__(env, n_poles=3) super().__init__(env, n_poles=3)

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -2,10 +2,10 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class DMCReacherMPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):
@property @property
def active_obs(self): def active_obs(self):

View File

@ -1,5 +1,5 @@
from alr_envs.dmc.suite.ball_in_cup.ball_in_cup_mp_wrapper import DMCBallInCupMPWrapper import alr_envs
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env from alr_envs.dmc.suite.ball_in_cup.mp_wrapper import MPWrapper
def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True): def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
@ -17,13 +17,12 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
Returns: Returns:
""" """
env = make_env(env_id, seed) env = alr_envs.make_env(env_id, seed)
rewards = 0 rewards = 0
obs = env.reset() obs = env.reset()
print("observation shape:", env.observation_space.shape) print("observation shape:", env.observation_space.shape)
print("action shape:", env.action_space.shape) print("action shape:", env.action_space.shape)
# number of samples(multiple environment steps)
for i in range(iterations): for i in range(iterations):
ac = env.action_space.sample() ac = env.action_space.sample()
obs, reward, done, info = env.step(ac) obs, reward, done, info = env.step(ac)
@ -63,7 +62,7 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
# You can also add other gym.Wrappers in case they are needed. # You can also add other gym.Wrappers in case they are needed.
wrappers = [DMCBallInCupMPWrapper] wrappers = [MPWrapper]
mp_kwargs = { mp_kwargs = {
"num_dof": 2, "num_dof": 2,
"num_basis": 5, "num_basis": 5,
@ -84,9 +83,9 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
"episode_length": 1000, "episode_length": 1000,
# "frame_skip": 1 # "frame_skip": 1
} }
env = make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs) env = alr_envs.make_dmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_kwargs, **kwargs)
# OR for a deterministic ProMP: # OR for a deterministic ProMP:
# env = make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args) # env = alr_envs.make_detpmp_env(base_env, wrappers=wrappers, seed=seed, mp_kwargs=mp_args)
# This renders the full MP trajectory # This renders the full MP trajectory
# It is only required to call render() once in the beginning, which renders every consecutive trajectory. # It is only required to call render() once in the beginning, which renders every consecutive trajectory.

View File

@ -1,11 +1,9 @@
import warnings
from collections import defaultdict from collections import defaultdict
import gym import gym
import numpy as np import numpy as np
from alr_envs.utils.make_env_helpers import make_env, make_env_rank import alr_envs
from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True): def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):
@ -23,7 +21,7 @@ def example_general(env_id="Pendulum-v0", seed=1, iterations=1000, render=True):
""" """
env = make_env(env_id, seed) env = alr_envs.make_env(env_id, seed)
rewards = 0 rewards = 0
obs = env.reset() obs = env.reset()
print("Observation shape: ", env.observation_space.shape) print("Observation shape: ", env.observation_space.shape)
@ -58,7 +56,7 @@ def example_async(env_id="alr_envs:HoleReacher-v0", n_cpu=4, seed=int('533D', 16
Returns: Tuple of (obs, reward, done, info) with type np.ndarray Returns: Tuple of (obs, reward, done, info) with type np.ndarray
""" """
env = gym.vector.AsyncVectorEnv([make_env_rank(env_id, seed, i) for i in range(n_cpu)]) env = gym.vector.AsyncVectorEnv([alr_envs.make_env_rank(env_id, seed, i) for i in range(n_cpu)])
# OR # OR
# envs = gym.vector.AsyncVectorEnv([make_env(env_id, seed + i) for i in range(n_cpu)]) # envs = gym.vector.AsyncVectorEnv([make_env(env_id, seed + i) for i in range(n_cpu)])

View File

@ -1,4 +1,4 @@
from alr_envs import HoleReacherMPWrapper from alr_envs import MPWrapper
from alr_envs.utils.make_env_helpers import make_dmp_env, make_env from alr_envs.utils.make_env_helpers import make_dmp_env, make_env
@ -113,7 +113,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
# Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper. # Replace this wrapper with the custom wrapper for your environment by inheriting from the MPEnvWrapper.
# You can also add other gym.Wrappers in case they are needed. # You can also add other gym.Wrappers in case they are needed.
wrappers = [HoleReacherMPWrapper] wrappers = [MPWrapper]
mp_kwargs = { mp_kwargs = {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,

View File

@ -0,0 +1,74 @@
import numpy as np
from matplotlib import pyplot as plt
from alr_envs import dmc
from alr_envs.utils.make_env_helpers import make_detpmp_env
# This might work for some environments, however, please verify either way the correct trajectory information
# for your environment are extracted below
SEED = 10
env_id = "cartpole-swingup"
wrappers = [dmc.suite.cartpole.MPWrapper]
mp_kwargs = {
"num_dof": 1,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"policy_kwargs": {
"p_gains": 10,
"d_gains": 10 # a good starting point is the sqrt of p_gains
}
}
kwargs = dict(time_limit=2, episode_length=200)
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
**kwargs)
# Plot difference between real trajectory and target MP trajectory
env.reset()
pos, vel = env.mp_rollout(env.action_space.sample())
base_shape = env.full_action_space.shape
actual_pos = np.zeros((len(pos), *base_shape))
actual_pos_ball = np.zeros((len(pos), *base_shape))
actual_vel = np.zeros((len(pos), *base_shape))
act = np.zeros((len(pos), *base_shape))
for t, pos_vel in enumerate(zip(pos, vel)):
actions = env.policy.get_action(pos_vel[0], pos_vel[1])
actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
_, _, _, _ = env.env.step(actions)
act[t, :] = actions
# TODO verify for your environment
actual_pos[t, :] = env.current_pos
# actual_pos_ball[t, :] = env.physics.data.qpos[2:]
actual_vel[t, :] = env.current_vel
plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.title("Position")
plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
# plt.plot(actual_pos_ball, label="true pos ball")
plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.xlabel("Episode steps")
plt.legend()
plt.subplot(132)
plt.title("Velocity")
plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.xlabel("Episode steps")
plt.legend()
plt.subplot(133)
plt.title("Actions")
plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
plt.xlabel("Episode steps")
# plt.legend()
plt.show()

View File

@ -2,7 +2,7 @@ from typing import Tuple, Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class BallInACupMPWrapper(MPEnvWrapper): class BallInACupMPWrapper(MPEnvWrapper):

View File

@ -0,0 +1,3 @@
from alr_envs.open_ai.mujoco import reacher_v2
from alr_envs.open_ai.robotics import fetch
from alr_envs.open_ai.classic_control import continuous_mountain_car

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -1,7 +1,7 @@
from typing import Union from typing import Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class MPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):

View File

@ -1 +0,0 @@
from alr_envs.open_ai.continuous_mountain_car.mp_wrapper import MPWrapper

View File

@ -1 +0,0 @@
from alr_envs.open_ai.fetch.mp_wrapper import MPWrapper

View File

@ -1,22 +0,0 @@
from typing import Union
import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
class MPWrapper(MPEnvWrapper):
@property
def current_vel(self) -> Union[float, int, np.ndarray]:
return self.unwrapped._get_obs()["observation"][-5:-1]
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
return self.unwrapped._get_obs()["observation"][:4]
@property
def goal_pos(self):
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -1,7 +1,7 @@
from typing import Union from typing import Union
import numpy as np import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
class MPWrapper(MPEnvWrapper): class MPWrapper(MPEnvWrapper):

View File

@ -1 +0,0 @@
from alr_envs.open_ai.reacher_v2.mp_wrapper import MPWrapper

View File

View File

@ -0,0 +1 @@
from .mp_wrapper import MPWrapper

View File

@ -0,0 +1,49 @@
from typing import Union
import numpy as np
from mp_env_api import MPEnvWrapper
class MPWrapper(MPEnvWrapper):
@property
def active_obs(self):
return np.hstack([
[False] * 3, # achieved goal
[True] * 3, # desired/true goal
[False] * 3, # grip pos
[True, True, False] * int(self.has_object), # object position
[True, True, False] * int(self.has_object), # object relative position
[False] * 2, # gripper state
[False] * 3 * int(self.has_object), # object rotation
[False] * 3 * int(self.has_object), # object velocity position
[False] * 3 * int(self.has_object), # object velocity rotation
[False] * 3, # grip velocity position
[False] * 2, # gripper velocity
]).astype(bool)
@property
def current_vel(self) -> Union[float, int, np.ndarray]:
dt = self.sim.nsubsteps * self.sim.model.opt.timestep
grip_velp = self.sim.data.get_site_xvelp("robot0:grip") * dt
# gripper state should be symmetric for left and right.
# They are controlled with only one action for both gripper joints
gripper_state = self.sim.data.get_joint_qvel('robot0:r_gripper_finger_joint') * dt
return np.hstack([grip_velp, gripper_state])
@property
def current_pos(self) -> Union[float, int, np.ndarray]:
grip_pos = self.sim.data.get_site_xpos("robot0:grip")
# gripper state should be symmetric for left and right.
# They are controlled with only one action for both gripper joints
gripper_state = self.sim.data.get_joint_qpos('robot0:r_gripper_finger_joint')
return np.hstack([grip_pos, gripper_state])
@property
def goal_pos(self):
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -37,7 +37,6 @@ def make(
episode_length = 250 if domain_name == "manipulation" else 1000 episode_length = 250 if domain_name == "manipulation" else 1000
max_episode_steps = (episode_length + frame_skip - 1) // frame_skip max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
if env_id not in gym.envs.registry.env_specs: if env_id not in gym.envs.registry.env_specs:
task_kwargs = {'random': seed} task_kwargs = {'random': seed}
# if seed is not None: # if seed is not None:
@ -46,7 +45,7 @@ def make(
task_kwargs['time_limit'] = time_limit task_kwargs['time_limit'] = time_limit
register( register(
id=env_id, id=env_id,
entry_point='alr_envs.utils.dmc2gym_wrapper:DMCWrapper', entry_point='alr_envs.utils.dmc_wrapper:DMCWrapper',
kwargs=dict( kwargs=dict(
domain_name=domain_name, domain_name=domain_name,
task_name=task_name, task_name=task_name,

View File

@ -33,11 +33,14 @@ def _spec_to_box(spec):
def _flatten_obs(obs: collections.MutableMapping): def _flatten_obs(obs: collections.MutableMapping):
# obs_pieces = [] """
# for v in obs.values(): Flattens an observation of type MutableMapping, e.g. a dict to a 1D array.
# flat = np.array([v]) if np.isscalar(v) else v.ravel() Args:
# obs_pieces.append(flat) obs: observation to flatten
# return np.concatenate(obs_pieces, axis=0)
Returns: 1D array of observation
"""
if not isinstance(obs, collections.MutableMapping): if not isinstance(obs, collections.MutableMapping):
raise ValueError(f'Requires dict-like observations structure. {type(obs)} found.') raise ValueError(f'Requires dict-like observations structure. {type(obs)} found.')
@ -52,19 +55,19 @@ def _flatten_obs(obs: collections.MutableMapping):
class DMCWrapper(core.Env): class DMCWrapper(core.Env):
def __init__( def __init__(
self, self,
domain_name, domain_name: str,
task_name, task_name: str,
task_kwargs={}, task_kwargs: dict = {},
visualize_reward=True, visualize_reward: bool = True,
from_pixels=False, from_pixels: bool = False,
height=84, height: int = 84,
width=84, width: int = 84,
camera_id=0, camera_id: int = 0,
frame_skip=1, frame_skip: int = 1,
environment_kwargs=None, environment_kwargs: dict = None,
channels_first=True channels_first: bool = True
): ):
assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' assert 'random' in task_kwargs, 'Please specify a seed for deterministic behavior.'
self._from_pixels = from_pixels self._from_pixels = from_pixels
self._height = height self._height = height
self._width = width self._width = width
@ -74,7 +77,7 @@ class DMCWrapper(core.Env):
# create task # create task
if domain_name == "manipulation": if domain_name == "manipulation":
assert not from_pixels, \ assert not from_pixels and not task_name.endswith("_vision"), \
"TODO: Vision interface for manipulation is different to suite and needs to be implemented" "TODO: Vision interface for manipulation is different to suite and needs to be implemented"
self._env = manipulation.load(environment_name=task_name, seed=task_kwargs['random']) self._env = manipulation.load(environment_name=task_name, seed=task_kwargs['random'])
else: else:
@ -169,11 +172,12 @@ class DMCWrapper(core.Env):
if self._last_state is None: if self._last_state is None:
raise ValueError('Environment not ready to render. Call reset() first.') raise ValueError('Environment not ready to render. Call reset() first.')
camera_id = camera_id or self._camera_id
# assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode # assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode
if mode == "rgb_array": if mode == "rgb_array":
height = height or self._height height = height or self._height
width = width or self._width width = width or self._width
camera_id = camera_id or self._camera_id
return self._env.physics.render(height=height, width=width, camera_id=camera_id) return self._env.physics.render(height=height, width=width, camera_id=camera_id)
elif mode == 'human': elif mode == 'human':
@ -184,7 +188,8 @@ class DMCWrapper(core.Env):
self.viewer = rendering.SimpleImageViewer() self.viewer = rendering.SimpleImageViewer()
# Render max available buffer size. Larger is only possible by altering the XML. # Render max available buffer size. Larger is only possible by altering the XML.
img = self._env.physics.render(height=self._env.physics.model.vis.global_.offheight, img = self._env.physics.render(height=self._env.physics.model.vis.global_.offheight,
width=self._env.physics.model.vis.global_.offwidth) width=self._env.physics.model.vis.global_.offwidth,
camera_id=camera_id)
self.viewer.imshow(img) self.viewer.imshow(img)
return self.viewer.isopen return self.viewer.isopen

View File

@ -2,13 +2,14 @@ import logging
from typing import Iterable, List, Type, Union from typing import Iterable, List, Type, Union
import gym import gym
import numpy as np
from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper from mp_env_api import MPEnvWrapper
from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper from mp_env_api.mp_wrappers.detpmp_wrapper import DetPMPWrapper
from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper from mp_env_api.mp_wrappers.dmp_wrapper import DmpWrapper
def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs): def make_env_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwargs):
""" """
TODO: Do we need this? TODO: Do we need this?
Generate a callable to create a new gym environment with a given seed. Generate a callable to create a new gym environment with a given seed.
@ -22,11 +23,16 @@ def make_env_rank(env_id: str, seed: int, rank: int = 0, **kwargs):
env_id: name of the environment env_id: name of the environment
seed: seed for deterministic behaviour seed: seed for deterministic behaviour
rank: environment rank for deterministic over multiple seeds behaviour rank: environment rank for deterministic over multiple seeds behaviour
return_callable: If True returns a callable to create the environment instead of the environment itself.
Returns: Returns:
""" """
return lambda: make_env(env_id, seed + rank, **kwargs)
def f():
return make_env(env_id, seed + rank, **kwargs)
return f if return_callable else f()
def make_env(env_id: str, seed, **kwargs): def make_env(env_id: str, seed, **kwargs):
@ -103,6 +109,9 @@ def make_dmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwargs
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
verify_dof(_env, mp_kwargs.get("num_dof"))
return DmpWrapper(_env, **mp_kwargs) return DmpWrapper(_env, **mp_kwargs)
@ -120,6 +129,9 @@ def make_detpmp_env(env_id: str, wrappers: Iterable, seed=1, mp_kwargs={}, **kwa
verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None)) verify_time_limit(mp_kwargs.get("duration", None), kwargs.get("time_limit", None))
_env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) _env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
verify_dof(_env, mp_kwargs.get("num_dof"))
return DetPMPWrapper(_env, **mp_kwargs) return DetPMPWrapper(_env, **mp_kwargs)
@ -185,5 +197,12 @@ def verify_time_limit(mp_time_limit: Union[None, float], env_time_limit: Union[N
""" """
if mp_time_limit is not None and env_time_limit is not None: if mp_time_limit is not None and env_time_limit is not None:
assert mp_time_limit == env_time_limit, \ assert mp_time_limit == env_time_limit, \
f"The manually specified 'time_limit' of {env_time_limit}s does not match " \ f"The specified 'time_limit' of {env_time_limit}s does not match " \
f"the duration of {mp_time_limit}s for the MP." f"the duration of {mp_time_limit}s for the MP."
def verify_dof(base_env: gym.Env, dof: int):
action_shape = np.prod(base_env.action_space.shape)
assert dof == action_shape, \
f"The specified degrees of freedom ('num_dof') {dof} do not match " \
f"the action space of {action_shape} the base environments"

View File

@ -15,8 +15,7 @@ def angle_normalize(x, type="deg"):
if type not in ["deg", "rad"]: raise ValueError(f"Invalid type {type}. Choose one of 'deg' or 'rad'.") if type not in ["deg", "rad"]: raise ValueError(f"Invalid type {type}. Choose one of 'deg' or 'rad'.")
if type == "deg": if type == "deg":
x = np.deg2rad(x) # x * pi / 180 x = np.deg2rad(x) # x * pi / 180
two_pi = 2 * np.pi two_pi = 2 * np.pi
return x - two_pi * np.floor((x + np.pi) / two_pi) return x - two_pi * np.floor((x + np.pi) / two_pi)

View File

@ -3,7 +3,7 @@ from setuptools import setup
setup( setup(
name='alr_envs', name='alr_envs',
version='0.0.1', version='0.0.1',
packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.stochastic_search', packages=['alr_envs', 'alr_envs.classic_control', 'alr_envs.open_ai', 'alr_envs.mujoco', 'alr_envs.dmc',
'alr_envs.utils'], 'alr_envs.utils'],
install_requires=[ install_requires=[
'gym', 'gym',

View File

@ -88,11 +88,8 @@ class TestEnvironments(unittest.TestCase):
def test_environment_determinism(self): def test_environment_determinism(self):
"""Tests that identical seeds produce identical trajectories.""" """Tests that identical seeds produce identical trajectories."""
seed = 0 seed = 0
# Iterate over two trajectories generated using identical sequences of # Iterate over two trajectories, which should have the same state and action sequence
# random actions, and with identical task random states. Check that the
# observations, rewards, discounts and step types are identical.
for spec in ALL_SPECS: for spec in ALL_SPECS:
# try:
with self.subTest(msg=spec.id): with self.subTest(msg=spec.id):
self._run_env(spec.id) self._run_env(spec.id)
traj1 = self._run_env(spec.id, seed=seed) traj1 = self._run_env(spec.id, seed=seed)