updated envs and registering
This commit is contained in:
parent
0a1e55d97b
commit
c8742e2934
@ -1,12 +1,16 @@
|
||||
import numpy as np
|
||||
from gym.envs.registration import register
|
||||
|
||||
from alr_envs.classic_control.hole_reacher.hole_reacher_mp_wrapper import HoleReacherMPWrapper
|
||||
from alr_envs.classic_control.simple_reacher.simple_reacher_mp_wrapper import SimpleReacherMPWrapper
|
||||
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher_mp_wrapper import ViaPointReacherMPWrapper
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_mp_wrapper import BallInACupMPWrapper
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_positional_wrapper import BallInACupPositionalWrapper
|
||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||
|
||||
# from alr_envs.utils.mps.dmp_wrapper import DmpWrapper
|
||||
|
||||
# Mujoco
|
||||
|
||||
## Reacher
|
||||
register(
|
||||
id='ALRReacher-v0',
|
||||
entry_point='alr_envs.mujoco:ALRReacherEnv',
|
||||
@ -177,7 +181,7 @@ register(
|
||||
|
||||
register(
|
||||
id='ViaPointReacher-v0',
|
||||
entry_point='alr_envs.classic_control.viapoint_reacher:ViaPointReacher',
|
||||
entry_point='alr_envs.classic_control:ViaPointReacher',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -189,7 +193,7 @@ register(
|
||||
## Hole Reacher
|
||||
register(
|
||||
id='HoleReacher-v0',
|
||||
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
|
||||
entry_point='alr_envs.classic_control:HoleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -205,7 +209,7 @@ register(
|
||||
|
||||
register(
|
||||
id='HoleReacher-v1',
|
||||
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
|
||||
entry_point='alr_envs.classic_control:HoleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -221,7 +225,7 @@ register(
|
||||
|
||||
register(
|
||||
id='HoleReacher-v2',
|
||||
entry_point='alr_envs.classic_control.hole_reacher:HoleReacherEnv',
|
||||
entry_point='alr_envs.classic_control:HoleReacherEnv',
|
||||
max_episode_steps=200,
|
||||
kwargs={
|
||||
"n_links": 5,
|
||||
@ -236,16 +240,19 @@ register(
|
||||
)
|
||||
|
||||
# MP environments
|
||||
|
||||
## Simple Reacher
|
||||
versions = ["SimpleReacher-v0", "SimpleReacher-v1", "LongSimpleReacher-v0", "LongSimpleReacher-v1"]
|
||||
for v in versions:
|
||||
name = v.split("-")
|
||||
register(
|
||||
id=f'{name[0]}DMP-{name[1]}',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"alr_envs:{v}",
|
||||
"wrappers": [SimpleReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 2 if "long" not in v.lower() else 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -254,22 +261,26 @@ for v in versions:
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 50,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ViaPointReacherDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:ViaPointReacher-v0",
|
||||
"wrappers": [ViaPointReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
"learn_goal": True,
|
||||
"alpha_phase": 2,
|
||||
"learn_goal": False,
|
||||
"policy_type": "velocity",
|
||||
"weights_scale": 50,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
## Hole Reacher
|
||||
@ -277,10 +288,12 @@ versions = ["v0", "v1", "v2"]
|
||||
for v in versions:
|
||||
register(
|
||||
id=f'HoleReacherDMP-{v}',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": f"alr_envs:HoleReacher-{v}",
|
||||
"wrappers": [HoleReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -291,13 +304,16 @@ for v in versions:
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id=f'HoleReacherDetPMP-{v}',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": f"alr_envs:HoleReacher-{v}",
|
||||
"wrappers": [HoleReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -306,15 +322,18 @@ for v in versions:
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# TODO: properly add final_pos
|
||||
register(
|
||||
id='HoleReacherFixedGoalDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
# max_episode_steps=1,
|
||||
kwargs={
|
||||
"name": "alr_envs:HoleReacher-v0",
|
||||
"wrappers": [HoleReacherMPWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 5,
|
||||
"num_basis": 5,
|
||||
"duration": 2,
|
||||
@ -324,15 +343,18 @@ register(
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
## Ball in Cup
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimpleDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -343,16 +365,21 @@ register(
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 100,
|
||||
"return_to_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACup-v0",
|
||||
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -363,16 +390,21 @@ register(
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 100,
|
||||
"return_to_start": True,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimpleDetPMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -383,16 +415,21 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupPDSimpleDetPMP-v0',
|
||||
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env',
|
||||
entry_point='alr_envs.mujoco.ball_in_a_cup.biac_pd:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupPDSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -403,9 +440,12 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
@ -430,9 +470,11 @@ register(
|
||||
|
||||
register(
|
||||
id='ALRBallInACupDetPMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -442,9 +484,13 @@ register(
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True,
|
||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||
"policy_kwargs": {
|
||||
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
@ -452,6 +498,8 @@ register(
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupGoal-v0",
|
||||
"wrappers": [BallInACupMPWrapper, BallInACupPositionalWrapper],
|
||||
"mp_kwargs": {
|
||||
"num_dof": 7,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -462,9 +510,12 @@ register(
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1,
|
||||
"policy_kwargs": {
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# BBO functions
|
||||
|
@ -1,3 +1,3 @@
|
||||
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
||||
from alr_envs.classic_control.viapoint_reacher.viapoint_reacher import ViaPointReacher
|
||||
from alr_envs.classic_control.simple_reacher.simple_reacher import SimpleReacherEnv
|
||||
from alr_envs.classic_control.hole_reacher.hole_reacher import HoleReacherEnv
|
||||
|
@ -21,14 +21,14 @@ class HoleReacherEnv(gym.Env):
|
||||
self.random_start = random_start
|
||||
|
||||
# provided initial parameters
|
||||
self.hole_x = hole_x # x-position of center of hole
|
||||
self.hole_width = hole_width # width of hole
|
||||
self.hole_depth = hole_depth # depth of hole
|
||||
self.initial_x = hole_x # x-position of center of hole
|
||||
self.initial_width = hole_width # width of hole
|
||||
self.initial_depth = hole_depth # depth of hole
|
||||
|
||||
# temp container for current env state
|
||||
self._tmp_hole_x = None
|
||||
self._tmp_hole_width = None
|
||||
self._tmp_hole_depth = None
|
||||
self._tmp_x = None
|
||||
self._tmp_width = None
|
||||
self._tmp_depth = None
|
||||
self._goal = None # x-y coordinates for reaching the center at the bottom of the hole
|
||||
|
||||
# collision
|
||||
@ -69,6 +69,10 @@ class HoleReacherEnv(gym.Env):
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self._dt
|
||||
|
||||
@property
|
||||
def start_pos(self):
|
||||
return self._start_pos
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
"""
|
||||
A single step with an action in joint velocity space
|
||||
@ -110,13 +114,13 @@ class HoleReacherEnv(gym.Env):
|
||||
return self._get_obs().copy()
|
||||
|
||||
def _generate_hole(self):
|
||||
self._tmp_hole_x = self.np_random.uniform(1, 3.5, 1) if self.hole_x is None else np.copy(self.hole_x)
|
||||
self._tmp_hole_width = self.np_random.uniform(0.15, 0.5, 1) if self.hole_width is None else np.copy(
|
||||
self.hole_width)
|
||||
self._tmp_x = self.np_random.uniform(1, 3.5, 1) if self.initial_x is None else np.copy(self.initial_x)
|
||||
self._tmp_width = self.np_random.uniform(0.15, 0.5, 1) if self.initial_width is None else np.copy(
|
||||
self.initial_width)
|
||||
# TODO we do not want this right now.
|
||||
self._tmp_hole_depth = self.np_random.uniform(1, 1, 1) if self.hole_depth is None else np.copy(
|
||||
self.hole_depth)
|
||||
self._goal = np.hstack([self._tmp_hole_x, -self._tmp_hole_depth])
|
||||
self._tmp_depth = self.np_random.uniform(1, 1, 1) if self.initial_depth is None else np.copy(
|
||||
self.initial_depth)
|
||||
self._goal = np.hstack([self._tmp_x, -self._tmp_depth])
|
||||
|
||||
def _update_joints(self):
|
||||
"""
|
||||
@ -164,7 +168,7 @@ class HoleReacherEnv(gym.Env):
|
||||
np.cos(theta),
|
||||
np.sin(theta),
|
||||
self._angle_velocity,
|
||||
self._tmp_hole_width,
|
||||
self._tmp_width,
|
||||
# self._tmp_hole_depth,
|
||||
self.end_effector - self._goal,
|
||||
self._steps
|
||||
@ -192,7 +196,7 @@ class HoleReacherEnv(gym.Env):
|
||||
def _check_wall_collision(self, line_points):
|
||||
|
||||
# all points that are before the hole in x
|
||||
r, c = np.where(line_points[:, :, 0] < (self._tmp_hole_x - self._tmp_hole_width / 2))
|
||||
r, c = np.where(line_points[:, :, 0] < (self._tmp_x - self._tmp_width / 2))
|
||||
|
||||
# check if any of those points are below surface
|
||||
nr_line_points_below_surface_before_hole = np.sum(line_points[r, c, 1] < 0)
|
||||
@ -201,7 +205,7 @@ class HoleReacherEnv(gym.Env):
|
||||
return True
|
||||
|
||||
# all points that are after the hole in x
|
||||
r, c = np.where(line_points[:, :, 0] > (self._tmp_hole_x + self._tmp_hole_width / 2))
|
||||
r, c = np.where(line_points[:, :, 0] > (self._tmp_x + self._tmp_width / 2))
|
||||
|
||||
# check if any of those points are below surface
|
||||
nr_line_points_below_surface_after_hole = np.sum(line_points[r, c, 1] < 0)
|
||||
@ -210,11 +214,11 @@ class HoleReacherEnv(gym.Env):
|
||||
return True
|
||||
|
||||
# all points that are above the hole
|
||||
r, c = np.where((line_points[:, :, 0] > (self._tmp_hole_x - self._tmp_hole_width / 2)) & (
|
||||
line_points[:, :, 0] < (self._tmp_hole_x + self._tmp_hole_width / 2)))
|
||||
r, c = np.where((line_points[:, :, 0] > (self._tmp_x - self._tmp_width / 2)) & (
|
||||
line_points[:, :, 0] < (self._tmp_x + self._tmp_width / 2)))
|
||||
|
||||
# check if any of those points are below surface
|
||||
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_hole_depth)
|
||||
nr_line_points_below_surface_in_hole = np.sum(line_points[r, c, 1] < -self._tmp_depth)
|
||||
|
||||
if nr_line_points_below_surface_in_hole > 0:
|
||||
return True
|
||||
@ -257,17 +261,17 @@ class HoleReacherEnv(gym.Env):
|
||||
def _set_patches(self):
|
||||
if self.fig is not None:
|
||||
self.fig.gca().patches = []
|
||||
left_block = patches.Rectangle((-self.n_links, -self._tmp_hole_depth),
|
||||
self.n_links + self._tmp_hole_x - self._tmp_hole_width / 2,
|
||||
self._tmp_hole_depth,
|
||||
left_block = patches.Rectangle((-self.n_links, -self._tmp_depth),
|
||||
self.n_links + self._tmp_x - self._tmp_width / 2,
|
||||
self._tmp_depth,
|
||||
fill=True, edgecolor='k', facecolor='k')
|
||||
right_block = patches.Rectangle((self._tmp_hole_x + self._tmp_hole_width / 2, -self._tmp_hole_depth),
|
||||
self.n_links - self._tmp_hole_x + self._tmp_hole_width / 2,
|
||||
self._tmp_hole_depth,
|
||||
right_block = patches.Rectangle((self._tmp_x + self._tmp_width / 2, -self._tmp_depth),
|
||||
self.n_links - self._tmp_x + self._tmp_width / 2,
|
||||
self._tmp_depth,
|
||||
fill=True, edgecolor='k', facecolor='k')
|
||||
hole_floor = patches.Rectangle((self._tmp_hole_x - self._tmp_hole_width / 2, -self._tmp_hole_depth),
|
||||
self._tmp_hole_width,
|
||||
1 - self._tmp_hole_depth,
|
||||
hole_floor = patches.Rectangle((self._tmp_x - self._tmp_width / 2, -self._tmp_depth),
|
||||
self._tmp_width,
|
||||
1 - self._tmp_depth,
|
||||
fill=True, edgecolor='k', facecolor='k')
|
||||
|
||||
# Add the patch to the Axes
|
||||
|
@ -12,7 +12,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
|
||||
[self.env.random_start] * self.env.n_links, # cos
|
||||
[self.env.random_start] * self.env.n_links, # sin
|
||||
[self.env.random_start] * self.env.n_links, # velocity
|
||||
[self.env.hole_width is None], # hole width
|
||||
[self.env.initial_width is None], # hole width
|
||||
# [self.env.hole_depth is None], # hole depth
|
||||
[True] * 2, # x-y coordinates of target distance
|
||||
[False] # env steps
|
||||
@ -20,7 +20,7 @@ class HoleReacherMPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||
return self._start_pos
|
||||
return self.env.start_pos
|
||||
|
||||
@property
|
||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||
|
@ -22,15 +22,18 @@ class SimpleReacherEnv(gym.Env):
|
||||
|
||||
self.random_start = random_start
|
||||
|
||||
# provided initial parameters
|
||||
self.inital_target = target
|
||||
|
||||
# temp container for current env state
|
||||
self._goal = None
|
||||
|
||||
self._joints = None
|
||||
self._joint_angles = None
|
||||
self._angle_velocity = None
|
||||
self._start_pos = np.zeros(self.n_links)
|
||||
self._start_vel = np.zeros(self.n_links)
|
||||
|
||||
self._target = target # provided target value
|
||||
self._goal = None # updated goal value, does not change when target != None
|
||||
|
||||
self.max_torque = 1
|
||||
self.steps_before_reward = 199
|
||||
|
||||
@ -56,6 +59,10 @@ class SimpleReacherEnv(gym.Env):
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self._dt
|
||||
|
||||
@property
|
||||
def start_pos(self):
|
||||
return self._start_pos
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
"""
|
||||
A single step with action in torque space
|
||||
@ -129,14 +136,14 @@ class SimpleReacherEnv(gym.Env):
|
||||
|
||||
def _generate_goal(self):
|
||||
|
||||
if self._target is None:
|
||||
if self.inital_target is None:
|
||||
|
||||
total_length = np.sum(self.link_lengths)
|
||||
goal = np.array([total_length, total_length])
|
||||
while np.linalg.norm(goal) >= total_length:
|
||||
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||
else:
|
||||
goal = np.copy(self._target)
|
||||
goal = np.copy(self.inital_target)
|
||||
|
||||
self._goal = goal
|
||||
|
||||
|
@ -18,7 +18,7 @@ class SimpleReacherMPWrapper(MPEnvWrapper):
|
||||
|
||||
@property
|
||||
def start_pos(self):
|
||||
return self._start_pos
|
||||
return self.env.start_pos
|
||||
|
||||
@property
|
||||
def goal_pos(self):
|
||||
|
@ -10,7 +10,7 @@ from alr_envs.classic_control.utils import check_self_collision
|
||||
|
||||
class ViaPointReacher(gym.Env):
|
||||
|
||||
def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None,
|
||||
def __init__(self, n_links, random_start: bool = False, via_target: Union[None, Iterable] = None,
|
||||
target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
|
||||
|
||||
self.n_links = n_links
|
||||
@ -19,8 +19,8 @@ class ViaPointReacher(gym.Env):
|
||||
self.random_start = random_start
|
||||
|
||||
# provided initial parameters
|
||||
self.target = target # provided target value
|
||||
self.via_target = via_target # provided via point target value
|
||||
self.intitial_target = target # provided target value
|
||||
self.initial_via_target = via_target # provided via point target value
|
||||
|
||||
# temp container for current env state
|
||||
self._via_point = np.ones(2)
|
||||
@ -63,6 +63,10 @@ class ViaPointReacher(gym.Env):
|
||||
def dt(self):
|
||||
return self._dt
|
||||
|
||||
@property
|
||||
def start_pos(self):
|
||||
return self._start_pos
|
||||
|
||||
def step(self, action: np.ndarray):
|
||||
"""
|
||||
a single step with an action in joint velocity space
|
||||
@ -107,22 +111,22 @@ class ViaPointReacher(gym.Env):
|
||||
total_length = np.sum(self.link_lengths)
|
||||
|
||||
# rejection sampled point in inner circle with 0.5*Radius
|
||||
if self.via_target is None:
|
||||
if self.initial_via_target is None:
|
||||
via_target = np.array([total_length, total_length])
|
||||
while np.linalg.norm(via_target) >= 0.5 * total_length:
|
||||
via_target = self.np_random.uniform(low=-0.5 * total_length, high=0.5 * total_length, size=2)
|
||||
else:
|
||||
via_target = np.copy(self.via_target)
|
||||
via_target = np.copy(self.initial_via_target)
|
||||
|
||||
# rejection sampled point in outer circle
|
||||
if self.target is None:
|
||||
if self.intitial_target is None:
|
||||
goal = np.array([total_length, total_length])
|
||||
while np.linalg.norm(goal) >= total_length or np.linalg.norm(goal) <= 0.5 * total_length:
|
||||
goal = self.np_random.uniform(low=-total_length, high=total_length, size=2)
|
||||
else:
|
||||
goal = np.copy(self.target)
|
||||
goal = np.copy(self.intitial_target)
|
||||
|
||||
self.via_target = via_target
|
||||
self._via_point = via_target
|
||||
self._goal = goal
|
||||
|
||||
def _update_joints(self):
|
||||
|
@ -12,18 +12,19 @@ class ViaPointReacherMPWrapper(MPEnvWrapper):
|
||||
[self.env.random_start] * self.env.n_links, # cos
|
||||
[self.env.random_start] * self.env.n_links, # sin
|
||||
[self.env.random_start] * self.env.n_links, # velocity
|
||||
[self.env.via_target is None] * 2, # x-y coordinates of via point distance
|
||||
[self.env.initial_via_target is None] * 2, # x-y coordinates of via point distance
|
||||
[True] * 2, # x-y coordinates of target distance
|
||||
[False] # env steps
|
||||
])
|
||||
|
||||
@property
|
||||
def start_pos(self) -> Union[float, int, np.ndarray]:
|
||||
return self._start_pos
|
||||
return self.env.start_pos
|
||||
|
||||
@property
|
||||
def goal_pos(self) -> Union[float, int, np.ndarray]:
|
||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
return self.env.dt
|
||||
|
16
example.py
16
example.py
@ -1,7 +1,9 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from alr_envs.utils.mp_env_async_sampler import AlrMpEnvSampler, AlrContextualMpEnvSampler, DummyDist
|
||||
|
||||
from alr_envs.utils.mp_env_async_sampler import AlrContextualMpEnvSampler, AlrMpEnvSampler, DummyDist
|
||||
|
||||
|
||||
def example_mujoco():
|
||||
@ -14,8 +16,8 @@ def example_mujoco():
|
||||
obs, reward, done, info = env.step(env.action_space.sample())
|
||||
rewards += reward
|
||||
|
||||
if i % 1 == 0:
|
||||
env.render()
|
||||
# if i % 1 == 0:
|
||||
# env.render()
|
||||
|
||||
if done:
|
||||
print(rewards)
|
||||
@ -23,8 +25,7 @@ def example_mujoco():
|
||||
obs = env.reset()
|
||||
|
||||
|
||||
def example_mp(env_name="alr_envs:HoleReacherDMP-v0"):
|
||||
# env = gym.make("alr_envs:ViaPointReacherDMP-v0")
|
||||
def example_mp(env_name="alr_envs:HoleReacherDMP-v1"):
|
||||
env = gym.make(env_name)
|
||||
rewards = 0
|
||||
# env.render(mode=None)
|
||||
@ -105,11 +106,12 @@ def example_async_contextual_sampler(env_name="alr_envs:SimpleReacherDMP-v1", n_
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
example_mp("alr_envs:HoleReacherDetPMP-v0")
|
||||
# example_mujoco()
|
||||
# example_mp("alr_envs:SimpleReacherDMP-v1")
|
||||
# example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
|
||||
# example_async_contextual_sampler()
|
||||
# env = gym.make("alr_envs:HoleReacherDetPMP-v1")
|
||||
env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
|
||||
example_async_sampler(env_name)
|
||||
# env_name = "alr_envs:ALRBallInACupPDSimpleDetPMP-v0"
|
||||
# example_async_sampler(env_name)
|
||||
# example_mp(env_name)
|
||||
|
Loading…
Reference in New Issue
Block a user