updated envs and registering

This commit is contained in:
ottofabian 2021-06-25 16:16:56 +02:00
parent 0a1e55d97b
commit c8742e2934
9 changed files with 248 additions and 179 deletions

View File

@ -1,12 +1,16 @@
import numpy as np import numpy as np
from gym.envs.registration import register from gym.envs.registration import register
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_positional_wrapper import BallInACupPositionalWrapper
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
# from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
# Mujoco # Mujoco
## Reacher
register( register(
id='ALRReacher-v0', id='ALRReacher-v0',
entry_point='alr_envs.mujoco:ALRReacherEnv', entry_point='alr_envs.mujoco:ALRReacherEnv',
@ -177,7 +181,7 @@ register(
register( register(
id='ViaPointReacher-v0', id='ViaPointReacher-v0',
entry_point='alr_envs.classic_control.viapoint_reacher:ViaPointReacher', entry_point='alr_envs.classic_control:ViaPointReacher',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -189,7 +193,7 @@ register(
## Hole Reacher ## Hole Reacher
register( register(
id='HoleReacher-v0', id='HoleReacher-v0',
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv', entry_point='alr_envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -205,7 +209,7 @@ register(
register( register(
id='HoleReacher-v1', id='HoleReacher-v1',
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv', entry_point='alr_envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -221,7 +225,7 @@ register(
register( register(
id='HoleReacher-v2', id='HoleReacher-v2',
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv', entry_point='alr_envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -236,16 +240,19 @@ register(
) )
# MP environments # MP environments
## Simple Reacher ## Simple Reacher
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"] versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
for v in versions: for v in versions:
name = v.split("-") name = v.split("-")
register( register(
id=f'{name[0]}DMP-{name[1]}', id=f'{name[0]}DMP-{name[1]}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": f"alr_envs:{v}", "name": f"alr_envs:{v}",
"wrappers": [SimpleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 2 if "long" not in v.lower() else 5, "num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -254,22 +261,26 @@ for v in versions:
"policy_type": "velocity", "policy_type": "velocity",
"weights_scale": 50, "weights_scale": 50,
} }
}
) )
register( register(
id='ViaPointReacherDMP-v0', id='ViaPointReacherDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": "alr_envs:ViaPointReacher-v0", "name": "alr_envs:ViaPointReacher-v0",
"wrappers": [ViaPointReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
"learn_goal": True,
"alpha_phase": 2, "alpha_phase": 2,
"learn_goal": False,
"policy_type": "velocity", "policy_type": "velocity",
"weights_scale": 50, "weights_scale": 50,
} }
}
) )
## Hole Reacher ## Hole Reacher
@ -277,10 +288,12 @@ versions = ["v0", "v1", "v2"]
for v in versions: for v in versions:
register( register(
id=f'HoleReacherDMP-{v}', id=f'HoleReacherDMP-{v}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": f"alr_envs:HoleReacher-{v}", "name": f"alr_envs:HoleReacher-{v}",
"wrappers": [HoleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -291,13 +304,16 @@ for v in versions:
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1 "goal_scale": 0.1
} }
}
) )
register( register(
id=f'HoleReacherDetPMP-{v}', id=f'HoleReacherDetPMP-{v}',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": f"alr_envs:HoleReacher-{v}", "name": f"alr_envs:HoleReacher-{v}",
"wrappers": [HoleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -306,15 +322,18 @@ for v in versions:
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True "zero_start": True
} }
}
) )
# TODO: properly add final_pos # TODO: properly add final_pos
register( register(
id='HoleReacherFixedGoalDMP-v0', id='HoleReacherFixedGoalDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1, # max_episode_steps=1,
kwargs={ kwargs={
"name": "alr_envs:HoleReacher-v0", "name": "alr_envs:HoleReacher-v0",
"wrappers": [HoleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5, "num_dof": 5,
"num_basis": 5, "num_basis": 5,
"duration": 2, "duration": 2,
@ -324,15 +343,18 @@ register(
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1 "goal_scale": 0.1
} }
}
) )
## Ball in Cup ## Ball in Cup
register( register(
id='ALRBallInACupSimpleDMP-v0', id='ALRBallInACupSimpleDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:ALRBallInACupSimple-v0", "name": "alr_envs:ALRBallInACupSimple-v0",
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 3, "num_dof": 3,
"num_basis": 5, "num_basis": 5,
"duration": 3.5, "duration": 3.5,
@ -343,16 +365,21 @@ register(
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 100, "weights_scale": 100,
"return_to_start": True, "return_to_start": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]), "p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025]) "d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
} }
}
}
) )
register( register(
id='ALRBallInACupDMP-v0', id='ALRBallInACupDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env', entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:ALRBallInACup-v0", "name": "alr_envs:ALRBallInACup-v0",
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 7, "num_dof": 7,
"num_basis": 5, "num_basis": 5,
"duration": 3.5, "duration": 3.5,
@ -363,16 +390,21 @@ register(
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 100, "weights_scale": 100,
"return_to_start": True, "return_to_start": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]), "p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025]) "d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
} }
}
}
) )
register( register(
id='ALRBallInACupSimpleDetPMP-v0', id='ALRBallInACupSimpleDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:ALRBallInACupSimple-v0", "name": "alr_envs:ALRBallInACupSimple-v0",
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 3, "num_dof": 3,
"num_basis": 5, "num_basis": 5,
"duration": 3.5, "duration": 3.5,
@ -383,16 +415,21 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"zero_goal": True, "zero_goal": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]), "p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025]) "d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
} }
}
}
) )
register( register(
id='ALRBallInACupPDSimpleDetPMP-v0', id='ALRBallInACupPDSimpleDetPMP-v0',
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env', entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:ALRBallInACupPDSimple-v0", "name": "alr_envs:ALRBallInACupPDSimple-v0",
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 3, "num_dof": 3,
"num_basis": 5, "num_basis": 5,
"duration": 3.5, "duration": 3.5,
@ -403,9 +440,12 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"zero_goal": True, "zero_goal": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]), "p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025]) "d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
} }
}
}
) )
register( register(
@ -430,9 +470,11 @@ register(
register( register(
id='ALRBallInACupDetPMP-v0', id='ALRBallInACupDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env', entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={ kwargs={
"name": "alr_envs:ALRBallInACupSimple-v0", "name": "alr_envs:ALRBallInACupSimple-v0",
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 7, "num_dof": 7,
"num_basis": 5, "num_basis": 5,
"duration": 3.5, "duration": 3.5,
@ -442,9 +484,13 @@ register(
"weights_scale": 0.2, "weights_scale": 0.2,
"zero_start": True, "zero_start": True,
"zero_goal": True, "zero_goal": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]), "p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025]) "d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
} }
}
}
) )
register( register(
@ -452,6 +498,8 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env', entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
kwargs={ kwargs={
"name": "alr_envs:ALRBallInACupGoal-v0", "name": "alr_envs:ALRBallInACupGoal-v0",
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 7, "num_dof": 7,
"num_basis": 5, "num_basis": 5,
"duration": 3.5, "duration": 3.5,
@ -462,9 +510,12 @@ register(
"policy_type": "motor", "policy_type": "motor",
"weights_scale": 50, "weights_scale": 50,
"goal_scale": 0.1, "goal_scale": 0.1,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]), "p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025]) "d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
} }
}
}
) )
# BBO functions # BBO functions

View File

@ -1,3 +1,3 @@
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv

View File

@ -21,14 +21,14 @@ class HoleReacherEnv(gym.Env):
self.random_start = random_start self.random_start = random_start
# provided initial parameters # provided initial parameters
self.hole_x = hole_x # x-position of center of hole self.initial_x = hole_x # x-position of center of hole
self.hole_width = hole_width # width of hole self.initial_width = hole_width # width of hole
self.hole_depth = hole_depth # depth of hole self.initial_depth = hole_depth # depth of hole
# temp container for current env state # temp container for current env state
self._tmp_hole_x = None self._tmp_x = None
self._tmp_hole_width = None self._tmp_width = None
self._tmp_hole_depth = None self._tmp_depth = None
self._goal = None # x-y coordinates for reaching the center at the bottom of the hole self._goal = None # x-y coordinates for reaching the center at the bottom of the hole
# collision # collision
@ -69,6 +69,10 @@ class HoleReacherEnv(gym.Env):
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self._dt return self._dt
@property
def start_pos(self):
return self._start_pos
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
""" """
A single step with an action in joint velocity space A single step with an action in joint velocity space
@ -110,13 +114,13 @@ class HoleReacherEnv(gym.Env):
return self._get_obs().copy() return self._get_obs().copy()
def _generate_hole(self): def _generate_hole(self):
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x) self._tmp_x = self.np_random.uniform(1, 3.5, 1) if self.initial_x is None else np.copy(self.initial_x)
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy( self._tmp_width = self.np_random.uniform(0.15, 0.5, 1) if self.initial_width is None else np.copy(
self.hole_width) self.initial_width)
# TODO we do not want this right now. # TODO we do not want this right now.
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy( self._tmp_depth = self.np_random.uniform(1, 1, 1) if self.initial_depth is None else np.copy(
self.hole_depth) self.initial_depth)
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth]) self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
def _update_joints(self): def _update_joints(self):
""" """
@ -164,7 +168,7 @@ class HoleReacherEnv(gym.Env):
np.cos(theta), np.cos(theta),
np.sin(theta), np.sin(theta),
self._angle_velocity, self._angle_velocity,
self._tmp_hole_width, self._tmp_width,
# self._tmp_hole_depth, # self._tmp_hole_depth,
self.end_effector - self._goal, self.end_effector - self._goal,
self._steps self._steps
@ -192,7 +196,7 @@ class HoleReacherEnv(gym.Env):
def _check_wall_collision(self, line_points): def _check_wall_collision(self, line_points):
# all points that are before the hole in x # all points that are before the hole in x
r, c = np.where(line_points[:, :, 0] < (self._tmp_hole_x - self._tmp_hole_width / 2)) r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))
# check if any of those points are below surface # check if any of those points are below surface
nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0) nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)
@ -201,7 +205,7 @@ class HoleReacherEnv(gym.Env):
return True return True
# all points that are after the hole in x # all points that are after the hole in x
r, c = np.where(line_points[:, :, 0] > (self._tmp_hole_x + self._tmp_hole_width / 2)) r, c = np.where(line_points[:, :, 0] > (self._tmp_x + self._tmp_width / 2))
# check if any of those points are below surface # check if any of those points are below surface
nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0) nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)
@ -210,11 +214,11 @@ class HoleReacherEnv(gym.Env):
return True return True
# all points that are above the hole # all points that are above the hole
r, c = np.where((line_points[:, :, 0] > (self._tmp_hole_x - self._tmp_hole_width / 2)) & ( r, c = np.where((line_points[:, :, 0] > (self._tmp_x - self._tmp_width / 2)) & (
line_points[:, :, 0] < (self._tmp_hole_x + self._tmp_hole_width / 2))) line_points[:, :, 0] < (self._tmp_x + self._tmp_width / 2)))
# check if any of those points are below surface # check if any of those points are below surface
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_hole_depth) nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_depth)
if nr_line_points_below_surface_in_hole > 0: if nr_line_points_below_surface_in_hole > 0:
return True return True
@ -257,17 +261,17 @@ class HoleReacherEnv(gym.Env):
def _set_patches(self): def _set_patches(self):
if self.fig is not None: if self.fig is not None:
self.fig.gca().patches = [] self.fig.gca().patches = []
left_block = patches.Rectangle((-self.n_links, -self._tmp_hole_depth), left_block = patches.Rectangle((-self.n_links, -self._tmp_depth),
self.n_links + self._tmp_hole_x - self._tmp_hole_width / 2, self.n_links + self._tmp_x - self._tmp_width / 2,
self._tmp_hole_depth, self._tmp_depth,
fill=True, edgecolor='k', facecolor='k') fill=True, edgecolor='k', facecolor='k')
right_block = patches.Rectangle((self._tmp_hole_x + self._tmp_hole_width / 2, -self._tmp_hole_depth), right_block = patches.Rectangle((self._tmp_x + self._tmp_width / 2, -self._tmp_depth),
self.n_links - self._tmp_hole_x + self._tmp_hole_width / 2, self.n_links - self._tmp_x + self._tmp_width / 2,
self._tmp_hole_depth, self._tmp_depth,
fill=True, edgecolor='k', facecolor='k') fill=True, edgecolor='k', facecolor='k')
hole_floor = patches.Rectangle((self._tmp_hole_x - self._tmp_hole_width / 2, -self._tmp_hole_depth), hole_floor = patches.Rectangle((self._tmp_x - self._tmp_width / 2, -self._tmp_depth),
self._tmp_hole_width, self._tmp_width,
1 - self._tmp_hole_depth, 1 - self._tmp_depth,
fill=True, edgecolor='k', facecolor='k') fill=True, edgecolor='k', facecolor='k')
# Add the patch to the Axes # Add the patch to the Axes

View File

@ -12,7 +12,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
[self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin [self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity [self.env.random_start] * self.env.n_links, # velocity
[self.env.hole_width is None], # hole width [self.env.initial_width is None], # hole width
# [self.env.hole_depth is None], # hole depth # [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance [True] * 2, # x-y coordinates of target distance
[False] # env steps [False] # env steps
@ -20,7 +20,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
@property @property
def start_pos(self) -> Union[float, int, np.ndarray]: def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos return self.env.start_pos
@property @property
def goal_pos(self) -> Union[float, int, np.ndarray]: def goal_pos(self) -> Union[float, int, np.ndarray]:

View File

@ -22,15 +22,18 @@ class SimpleReacherEnv(gym.Env):
self.random_start = random_start self.random_start = random_start
# provided initial parameters
self.inital_target = target
# temp container for current env state
self._goal = None
self._joints = None self._joints = None
self._joint_angles = None self._joint_angles = None
self._angle_velocity = None self._angle_velocity = None
self._start_pos = np.zeros(self.n_links) self._start_pos = np.zeros(self.n_links)
self._start_vel = np.zeros(self.n_links) self._start_vel = np.zeros(self.n_links)
self._target = target # provided target value
self._goal = None # updated goal value, does not change when target != None
self.max_torque = 1 self.max_torque = 1
self.steps_before_reward = 199 self.steps_before_reward = 199
@ -56,6 +59,10 @@ class SimpleReacherEnv(gym.Env):
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self._dt return self._dt
@property
def start_pos(self):
return self._start_pos
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
""" """
A single step with action in torque space A single step with action in torque space
@ -129,14 +136,14 @@ class SimpleReacherEnv(gym.Env):
def _generate_goal(self): def _generate_goal(self):
if self._target is None: if self.inital_target is None:
total_length = np.sum(self.link_lengths) total_length = np.sum(self.link_lengths)
goal = np.array([total_length, total_length]) goal = np.array([total_length, total_length])
while np.linalg.norm(goal) >= total_length: while np.linalg.norm(goal) >= total_length:
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2) goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
else: else:
goal = np.copy(self._target) goal = np.copy(self.inital_target)
self._goal = goal self._goal = goal

View File

@ -18,7 +18,7 @@ class SimpleReacherMPWrapper(MPEnvWrapper):
@property @property
def start_pos(self): def start_pos(self):
return self._start_pos return self.env.start_pos
@property @property
def goal_pos(self): def goal_pos(self):

View File

@ -10,7 +10,7 @@ from alr_envs.classic_control.utils import check_self_collision
class ViaPointReacher(gym.Env): class ViaPointReacher(gym.Env):
def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None, def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000): target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
self.n_links = n_links self.n_links = n_links
@ -19,8 +19,8 @@ class ViaPointReacher(gym.Env):
self.random_start = random_start self.random_start = random_start
# provided initial parameters # provided initial parameters
self.target = target # provided target value self.intitial_target = target # provided target value
self.via_target = via_target # provided via point target value self.initial_via_target = via_target # provided via point target value
# temp container for current env state # temp container for current env state
self._via_point = np.ones(2) self._via_point = np.ones(2)
@ -63,6 +63,10 @@ class ViaPointReacher(gym.Env):
def dt(self): def dt(self):
return self._dt return self._dt
@property
def start_pos(self):
return self._start_pos
def step(self, action: np.ndarray): def step(self, action: np.ndarray):
""" """
a single step with an action in joint velocity space a single step with an action in joint velocity space
@ -107,22 +111,22 @@ class ViaPointReacher(gym.Env):
total_length = np.sum(self.link_lengths) total_length = np.sum(self.link_lengths)
# rejection sampled point in inner circle with 0.5*Radius # rejection sampled point in inner circle with 0.5*Radius
if self.via_target is None: if self.initial_via_target is None:
via_target = np.array([total_length, total_length]) via_target = np.array([total_length, total_length])
while np.linalg.norm(via_target) >= 0.5 * total_length: while np.linalg.norm(via_target) >= 0.5 * total_length:
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2) via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
else: else:
via_target = np.copy(self.via_target) via_target = np.copy(self.initial_via_target)
# rejection sampled point in outer circle # rejection sampled point in outer circle
if self.target is None: if self.intitial_target is None:
goal = np.array([total_length, total_length]) goal = np.array([total_length, total_length])
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length: while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2) goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
else: else:
goal = np.copy(self.target) goal = np.copy(self.intitial_target)
self.via_target = via_target self._via_point = via_target
self._goal = goal self._goal = goal
def _update_joints(self): def _update_joints(self):

View File

@ -12,18 +12,19 @@ class ViaPointReacherMPWrapper(MPEnvWrapper):
[self.env.random_start] * self.env.n_links, # cos [self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin [self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity [self.env.random_start] * self.env.n_links, # velocity
[self.env.via_target is None] * 2, # x-y coordinates of via point distance [self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
[True] * 2, # x-y coordinates of target distance [True] * 2, # x-y coordinates of target distance
[False] # env steps [False] # env steps
]) ])
@property @property
def start_pos(self) -> Union[float, int, np.ndarray]: def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos return self.env.start_pos
@property @property
def goal_pos(self) -> Union[float, int, np.ndarray]: def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.") raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]: def dt(self) -> Union[float, int]:
return self.env.dt return self.env.dt

View File

@ -1,7 +1,9 @@
from collections import defaultdict from collections import defaultdict
import gym import gym
import numpy as np import numpy as np
from alr_envs.utils.mp_env_async_sampler import AlrMpEnvSampler, AlrContextualMpEnvSampler, DummyDist
from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
def example_mujoco(): def example_mujoco():
@ -14,8 +16,8 @@ def example_mujoco():
obs, reward, done, info = env.step(env.action_space.sample()) obs, reward, done, info = env.step(env.action_space.sample())
rewards += reward rewards += reward
if i % 1 == 0: # if i % 1 == 0:
env.render() # env.render()
if done: if done:
print(rewards) print(rewards)
@ -23,8 +25,7 @@ def example_mujoco():
obs = env.reset() obs = env.reset()
def example_mp(env_name="alr_envs:HoleReacherDMP-v0"): def example_mp(env_name="alr_envs:HoleReacherDMP-v1"):
# env = gym.make("alr_envs:ViaPointReacherDMP-v0")
env = gym.make(env_name) env = gym.make(env_name)
rewards = 0 rewards = 0
# env.render(mode=None) # env.render(mode=None)
@ -105,11 +106,12 @@ def example_async_contextual_sampler(env_name="alr_envs:SimpleReacherDMP-v1", n_
if __name__ == '__main__': if __name__ == '__main__':
example_mp("alr_envs:HoleReacherDetPMP-v0")
# example_mujoco() # example_mujoco()
# example_mp("alr_envs:SimpleReacherDMP-v1") # example_mp("alr_envs:SimpleReacherDMP-v1")
# example_async("alr_envs:LongSimpleReacherDMP-v0", 4) # example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
# example_async_contextual_sampler() # example_async_contextual_sampler()
# env = gym.make("alr_envs:HoleReacherDetPMP-v1") # env = gym.make("alr_envs:HoleReacherDetPMP-v1")
env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0" # env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
example_async_sampler(env_name) # example_async_sampler(env_name)
# example_mp(env_name) # example_mp(env_name)