updated envs and registering

This commit is contained in:
ottofabian 2021-06-25 16:16:56 +02:00
parent 0a1e55d97b
commit c8742e2934
9 changed files with 248 additions and 179 deletions

View File

@ -1,12 +1,16 @@
import numpy as np
from gym.envs.registration import register
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_positional_wrapper import BallInACupPositionalWrapper
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
# from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
# Mujoco
## Reacher
register(
id='ALRReacher-v0',
entry_point='alr_envs.mujoco:ALRReacherEnv',
@ -177,7 +181,7 @@ register(
register(
id='ViaPointReacher-v0',
entry_point='alr_envs.classic_control.viapoint_reacher:ViaPointReacher',
entry_point='alr_envs.classic_control:ViaPointReacher',
max_episode_steps=200,
kwargs={
"n_links": 5,
@ -189,7 +193,7 @@ register(
## Hole Reacher
register(
id='HoleReacher-v0',
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
entry_point='alr_envs.classic_control:HoleReacherEnv',
max_episode_steps=200,
kwargs={
"n_links": 5,
@ -205,7 +209,7 @@ register(
register(
id='HoleReacher-v1',
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
entry_point='alr_envs.classic_control:HoleReacherEnv',
max_episode_steps=200,
kwargs={
"n_links": 5,
@ -221,7 +225,7 @@ register(
register(
id='HoleReacher-v2',
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
entry_point='alr_envs.classic_control:HoleReacherEnv',
max_episode_steps=200,
kwargs={
"n_links": 5,
@ -236,39 +240,46 @@ register(
)
# MP environments
## Simple Reacher
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
for v in versions:
name = v.split("-")
register(
id=f'{name[0]}DMP-{name[1]}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": f"alr_envs:{v}",
"num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5,
"duration": 2,
"alpha_phase": 2,
"learn_goal": True,
"policy_type": "velocity",
"weights_scale": 50,
"wrappers": [SimpleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 2 if "long" not in v.lower() else 5,
"num_basis": 5,
"duration": 2,
"alpha_phase": 2,
"learn_goal": True,
"policy_type": "velocity",
"weights_scale": 50,
}
}
)
register(
id='ViaPointReacherDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:ViaPointReacher-v0",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"alpha_phase": 2,
"learn_goal": False,
"policy_type": "velocity",
"weights_scale": 50,
"wrappers": [ViaPointReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"policy_type": "velocity",
"weights_scale": 50,
}
}
)
@ -277,52 +288,61 @@ versions = ["v0", "v1", "v2"]
for v in versions:
register(
id=f'HoleReacherDMP-{v}',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": f"alr_envs:HoleReacher-{v}",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
"wrappers": [HoleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": True,
"alpha_phase": 2,
"bandwidth_factor": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
}
)
register(
id=f'HoleReacherDetPMP-{v}',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": f"alr_envs:HoleReacher-{v}",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
"wrappers": [HoleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"width": 0.025,
"policy_type": "velocity",
"weights_scale": 0.2,
"zero_start": True
}
}
)
# TODO: properly add final_pos
register(
id='HoleReacherFixedGoalDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
# max_episode_steps=1,
kwargs={
"name": "alr_envs:HoleReacher-v0",
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": False,
"alpha_phase": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
"wrappers": [HoleReacherMPWrapper],
"mp_kwargs": {
"num_dof": 5,
"num_basis": 5,
"duration": 2,
"learn_goal": False,
"alpha_phase": 2,
"policy_type": "velocity",
"weights_scale": 50,
"goal_scale": 0.1
}
}
)
@ -330,81 +350,101 @@ register(
register(
id='ALRBallInACupSimpleDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
kwargs={
"name": "alr_envs:ALRBallInACupSimple-v0",
"num_dof": 3,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"learn_goal": False,
"alpha_phase": 3,
"bandwidth_factor": 2.5,
"policy_type": "motor",
"weights_scale": 100,
"return_to_start": True,
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 3,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"learn_goal": False,
"alpha_phase": 3,
"bandwidth_factor": 2.5,
"policy_type": "motor",
"weights_scale": 100,
"return_to_start": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
}
}
}
)
register(
id='ALRBallInACupDMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
kwargs={
"name": "alr_envs:ALRBallInACup-v0",
"num_dof": 7,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"learn_goal": False,
"alpha_phase": 3,
"bandwidth_factor": 2.5,
"policy_type": "motor",
"weights_scale": 100,
"return_to_start": True,
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 7,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"learn_goal": False,
"alpha_phase": 3,
"bandwidth_factor": 2.5,
"policy_type": "motor",
"weights_scale": 100,
"return_to_start": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
}
}
}
)
register(
id='ALRBallInACupSimpleDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": "alr_envs:ALRBallInACupSimple-v0",
"num_dof": 3,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"width": 0.0035,
# "off": -0.05,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"zero_goal": True,
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 3,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"width": 0.0035,
# "off": -0.05,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"zero_goal": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
}
}
}
)
register(
id='ALRBallInACupPDSimpleDetPMP-v0',
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env',
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env_helper',
kwargs={
"name": "alr_envs:ALRBallInACupPDSimple-v0",
"num_dof": 3,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"width": 0.0035,
# "off": -0.05,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"zero_goal": True,
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 3,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"width": 0.0035,
# "off": -0.05,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"zero_goal": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
}
}
}
)
@ -430,20 +470,26 @@ register(
register(
id='ALRBallInACupDetPMP-v0',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
kwargs={
"name": "alr_envs:ALRBallInACupSimple-v0",
"num_dof": 7,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"width": 0.0035,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"zero_goal": True,
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 7,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"width": 0.0035,
"policy_type": "motor",
"weights_scale": 0.2,
"zero_start": True,
"zero_goal": True,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
}
}
}
)
@ -452,18 +498,23 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
kwargs={
"name": "alr_envs:ALRBallInACupGoal-v0",
"num_dof": 7,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"learn_goal": True,
"alpha_phase": 3,
"bandwidth_factor": 2.5,
"policy_type": "motor",
"weights_scale": 50,
"goal_scale": 0.1,
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
"mp_kwargs": {
"num_dof": 7,
"num_basis": 5,
"duration": 3.5,
"post_traj_time": 4.5,
"learn_goal": True,
"alpha_phase": 3,
"bandwidth_factor": 2.5,
"policy_type": "motor",
"weights_scale": 50,
"goal_scale": 0.1,
"policy_kwargs": {
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
}
}
}
)

View File

@ -1,3 +1,3 @@
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv

View File

@ -21,14 +21,14 @@ class HoleReacherEnv(gym.Env):
self.random_start = random_start
# provided initial parameters
self.hole_x = hole_x # x-position of center of hole
self.hole_width = hole_width # width of hole
self.hole_depth = hole_depth # depth of hole
self.initial_x = hole_x # x-position of center of hole
self.initial_width = hole_width # width of hole
self.initial_depth = hole_depth # depth of hole
# temp container for current env state
self._tmp_hole_x = None
self._tmp_hole_width = None
self._tmp_hole_depth = None
self._tmp_x = None
self._tmp_width = None
self._tmp_depth = None
self._goal = None # x-y coordinates for reaching the center at the bottom of the hole
# collision
@ -69,6 +69,10 @@ class HoleReacherEnv(gym.Env):
def dt(self) -> Union[float, int]:
return self._dt
@property
def start_pos(self):
return self._start_pos
def step(self, action: np.ndarray):
"""
A single step with an action in joint velocity space
@ -110,13 +114,13 @@ class HoleReacherEnv(gym.Env):
return self._get_obs().copy()
def _generate_hole(self):
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x)
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy(
self.hole_width)
self._tmp_x = self.np_random.uniform(1, 3.5, 1) if self.initial_x is None else np.copy(self.initial_x)
self._tmp_width = self.np_random.uniform(0.15, 0.5, 1) if self.initial_width is None else np.copy(
self.initial_width)
# TODO we do not want this right now.
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy(
self.hole_depth)
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
self._tmp_depth = self.np_random.uniform(1, 1, 1) if self.initial_depth is None else np.copy(
self.initial_depth)
self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
def _update_joints(self):
"""
@ -164,7 +168,7 @@ class HoleReacherEnv(gym.Env):
np.cos(theta),
np.sin(theta),
self._angle_velocity,
self._tmp_hole_width,
self._tmp_width,
# self._tmp_hole_depth,
self.end_effector - self._goal,
self._steps
@ -192,7 +196,7 @@ class HoleReacherEnv(gym.Env):
def _check_wall_collision(self, line_points):
# all points that are before the hole in x
r, c = np.where(line_points[:, :, 0] < (self._tmp_hole_x - self._tmp_hole_width / 2))
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))
# check if any of those points are below surface
nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)
@ -201,7 +205,7 @@ class HoleReacherEnv(gym.Env):
return True
# all points that are after the hole in x
r, c = np.where(line_points[:, :, 0] > (self._tmp_hole_x + self._tmp_hole_width / 2))
r, c = np.where(line_points[:, :, 0] > (self._tmp_x + self._tmp_width / 2))
# check if any of those points are below surface
nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)
@ -210,11 +214,11 @@ class HoleReacherEnv(gym.Env):
return True
# all points that are above the hole
r, c = np.where((line_points[:, :, 0] > (self._tmp_hole_x - self._tmp_hole_width / 2)) & (
line_points[:, :, 0] < (self._tmp_hole_x + self._tmp_hole_width / 2)))
r, c = np.where((line_points[:, :, 0] > (self._tmp_x - self._tmp_width / 2)) & (
line_points[:, :, 0] < (self._tmp_x + self._tmp_width / 2)))
# check if any of those points are below surface
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_hole_depth)
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_depth)
if nr_line_points_below_surface_in_hole > 0:
return True
@ -257,17 +261,17 @@ class HoleReacherEnv(gym.Env):
def _set_patches(self):
if self.fig is not None:
self.fig.gca().patches = []
left_block = patches.Rectangle((-self.n_links, -self._tmp_hole_depth),
self.n_links + self._tmp_hole_x - self._tmp_hole_width / 2,
self._tmp_hole_depth,
left_block = patches.Rectangle((-self.n_links, -self._tmp_depth),
self.n_links + self._tmp_x - self._tmp_width / 2,
self._tmp_depth,
fill=True, edgecolor='k', facecolor='k')
right_block = patches.Rectangle((self._tmp_hole_x + self._tmp_hole_width / 2, -self._tmp_hole_depth),
self.n_links - self._tmp_hole_x + self._tmp_hole_width / 2,
self._tmp_hole_depth,
right_block = patches.Rectangle((self._tmp_x + self._tmp_width / 2, -self._tmp_depth),
self.n_links - self._tmp_x + self._tmp_width / 2,
self._tmp_depth,
fill=True, edgecolor='k', facecolor='k')
hole_floor = patches.Rectangle((self._tmp_hole_x - self._tmp_hole_width / 2, -self._tmp_hole_depth),
self._tmp_hole_width,
1 - self._tmp_hole_depth,
hole_floor = patches.Rectangle((self._tmp_x - self._tmp_width / 2, -self._tmp_depth),
self._tmp_width,
1 - self._tmp_depth,
fill=True, edgecolor='k', facecolor='k')
# Add the patch to the Axes

View File

@ -12,7 +12,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.hole_width is None], # hole width
[self.env.initial_width is None], # hole width
# [self.env.hole_depth is None], # hole depth
[True] * 2, # x-y coordinates of target distance
[False] # env steps
@ -20,7 +20,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
@property
def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos
return self.env.start_pos
@property
def goal_pos(self) -> Union[float, int, np.ndarray]:
@ -28,4 +28,4 @@ class HoleReacherMPWrapper(MPEnvWrapper):
@property
def dt(self) -> Union[float, int]:
return self.env.dt
return self.env.dt

View File

@ -22,15 +22,18 @@ class SimpleReacherEnv(gym.Env):
self.random_start = random_start
# provided initial parameters
self.inital_target = target
# temp container for current env state
self._goal = None
self._joints = None
self._joint_angles = None
self._angle_velocity = None
self._start_pos = np.zeros(self.n_links)
self._start_vel = np.zeros(self.n_links)
self._target = target # provided target value
self._goal = None # updated goal value, does not change when target != None
self.max_torque = 1
self.steps_before_reward = 199
@ -56,6 +59,10 @@ class SimpleReacherEnv(gym.Env):
def dt(self) -> Union[float, int]:
return self._dt
@property
def start_pos(self):
return self._start_pos
def step(self, action: np.ndarray):
"""
A single step with action in torque space
@ -129,14 +136,14 @@ class SimpleReacherEnv(gym.Env):
def _generate_goal(self):
if self._target is None:
if self.inital_target is None:
total_length = np.sum(self.link_lengths)
goal = np.array([total_length, total_length])
while np.linalg.norm(goal) >= total_length:
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
else:
goal = np.copy(self._target)
goal = np.copy(self.inital_target)
self._goal = goal

View File

@ -18,7 +18,7 @@ class SimpleReacherMPWrapper(MPEnvWrapper):
@property
def start_pos(self):
return self._start_pos
return self.env.start_pos
@property
def goal_pos(self):

View File

@ -10,7 +10,7 @@ from alr_envs.classic_control.utils import check_self_collision
class ViaPointReacher(gym.Env):
def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None,
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
self.n_links = n_links
@ -19,8 +19,8 @@ class ViaPointReacher(gym.Env):
self.random_start = random_start
# provided initial parameters
self.target = target # provided target value
self.via_target = via_target # provided via point target value
self.intitial_target = target # provided target value
self.initial_via_target = via_target # provided via point target value
# temp container for current env state
self._via_point = np.ones(2)
@ -63,6 +63,10 @@ class ViaPointReacher(gym.Env):
def dt(self):
return self._dt
@property
def start_pos(self):
return self._start_pos
def step(self, action: np.ndarray):
"""
a single step with an action in joint velocity space
@ -107,22 +111,22 @@ class ViaPointReacher(gym.Env):
total_length = np.sum(self.link_lengths)
# rejection sampled point in inner circle with 0.5*Radius
if self.via_target is None:
if self.initial_via_target is None:
via_target = np.array([total_length, total_length])
while np.linalg.norm(via_target) >= 0.5 * total_length:
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
else:
via_target = np.copy(self.via_target)
via_target = np.copy(self.initial_via_target)
# rejection sampled point in outer circle
if self.target is None:
if self.intitial_target is None:
goal = np.array([total_length, total_length])
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
else:
goal = np.copy(self.target)
goal = np.copy(self.intitial_target)
self.via_target = via_target
self._via_point = via_target
self._goal = goal
def _update_joints(self):

View File

@ -12,18 +12,19 @@ class ViaPointReacherMPWrapper(MPEnvWrapper):
[self.env.random_start] * self.env.n_links, # cos
[self.env.random_start] * self.env.n_links, # sin
[self.env.random_start] * self.env.n_links, # velocity
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
[self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
[True] * 2, # x-y coordinates of target distance
[False] # env steps
])
@property
def start_pos(self) -> Union[float, int, np.ndarray]:
return self._start_pos
return self.env.start_pos
@property
def goal_pos(self) -> Union[float, int, np.ndarray]:
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
@property
def dt(self) -> Union[float, int]:
return self.env.dt

View File

@ -1,7 +1,9 @@
from collections import defaultdict
import gym
import numpy as np
from alr_envs.utils.mp_env_async_sampler import AlrMpEnvSampler, AlrContextualMpEnvSampler, DummyDist
from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
def example_mujoco():
@ -14,8 +16,8 @@ def example_mujoco():
obs, reward, done, info = env.step(env.action_space.sample())
rewards += reward
if i % 1 == 0:
env.render()
# if i % 1 == 0:
# env.render()
if done:
print(rewards)
@ -23,8 +25,7 @@ def example_mujoco():
obs = env.reset()
def example_mp(env_name="alr_envs:HoleReacherDMP-v0"):
# env = gym.make("alr_envs:ViaPointReacherDMP-v0")
def example_mp(env_name="alr_envs:HoleReacherDMP-v1"):
env = gym.make(env_name)
rewards = 0
# env.render(mode=None)
@ -105,11 +106,12 @@ def example_async_contextual_sampler(env_name="alr_envs:SimpleReacherDMP-v1", n_
if __name__ == '__main__':
example_mp("alr_envs:HoleReacherDetPMP-v0")
# example_mujoco()
# example_mp("alr_envs:SimpleReacherDMP-v1")
# example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
# example_async_contextual_sampler()
# env = gym.make("alr_envs:HoleReacherDetPMP-v1")
env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
example_async_sampler(env_name)
# env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
# example_async_sampler(env_name)
# example_mp(env_name)