bug fixes
This commit is contained in:
parent
ba0b612868
commit
c2db2f8064
@ -19,6 +19,7 @@ Currently we have the following environments:
|
||||
|`ALRLongReacher-v0`|Modified (7 links) Mujoco gym's `Reacher-v2` (2 links)| 200 | 7 | 27
|
||||
|`ALRLongReacherSparse-v0`|Same as `ALRLongReacher-v0`, but the distance penalty is only provided in the last time step.| 200 | 7 | 27
|
||||
|`ALRLongReacherSparseBalanced-v0`|Same as `ALRLongReacherSparse-v0`, but the end-effector has to remain upright.| 200 | 7 | 27
|
||||
|`ALRBallInACupSimple-v0`| Ball-in-a-cup task where a robot needs to catch a ball attached to a cup at its end-effector. | 4000 | 3 | wip
|
||||
|`ALRBallInACup-v0`| Ball-in-a-cup task where a robot needs to catch a ball attached to a cup at its end-effector | 4000 | 7 | wip
|
||||
|`ALRBallInACupGoal-v0`| Similiar to `ALRBallInACupSimple-v0` but the ball needs to be caught at a specified goal position | 4000 | 7 | wip
|
||||
|
||||
@ -41,7 +42,7 @@ All environments provide the full episode reward and additional information abou
|
||||
|`ViaPointReacherDMP-v0`| A DMP provides a trajectory for the `ViaPointReacher-v0` task. | 200 | 25
|
||||
|`HoleReacherFixedGoalDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task with a fixed goal attractor. | 200 | 25
|
||||
|`HoleReacherDMP-v0`| A DMP provides a trajectory for the `HoleReacher-v0` task. The goal attractor needs to be learned. | 200 | 30
|
||||
|`ALRBallInACupSimpleDMP-v0`| A DMP provides a trajectory for a simplified `ALRBallInACup-v0` task where only 3 joints are actuated. | 4000 | 15
|
||||
|`ALRBallInACupSimpleDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupSimple-v0` task where only 3 joints are actuated. | 4000 | 15
|
||||
|`ALRBallInACupDMP-v0`| A DMP provides a trajectory for the `ALRBallInACup-v0` task. | 4000 | 35
|
||||
|`ALRBallInACupGoalDMP-v0`| A DMP provides a trajectory for the `ALRBallInACupGoal-v0` task. | 4000 | 35 | 3
|
||||
|
||||
|
@ -71,12 +71,22 @@ register(
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACupSimple-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"simplified": True,
|
||||
"reward_type": "no_context"
|
||||
}
|
||||
)
|
||||
|
||||
register(
|
||||
id='ALRBallInACup-v0',
|
||||
entry_point='alr_envs.mujoco:ALRBallInACupEnv',
|
||||
max_episode_steps=4000,
|
||||
kwargs={
|
||||
"reward_type": "simple"
|
||||
"reward_type": "no_context"
|
||||
}
|
||||
)
|
||||
|
||||
@ -209,7 +219,7 @@ register(
|
||||
id='ALRBallInACupSimpleDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACup-v0",
|
||||
"name": "alr_envs:ALRBallInACupSimple-v0",
|
||||
"num_dof": 3,
|
||||
"num_basis": 5,
|
||||
"duration": 3.5,
|
||||
@ -243,7 +253,7 @@ register(
|
||||
|
||||
register(
|
||||
id='ALRBallInACupGoalDMP-v0',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_dmp_env',
|
||||
entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
|
||||
kwargs={
|
||||
"name": "alr_envs:ALRBallInACupGoal-v0",
|
||||
"num_dof": 7,
|
||||
|
@ -5,7 +5,8 @@ from alr_envs.mujoco import alr_mujoco_env
|
||||
|
||||
|
||||
class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
||||
def __init__(self, n_substeps=4, apply_gravity_comp=True, reward_type: str = None, context: np.ndarray = None):
|
||||
def __init__(self, n_substeps=4, apply_gravity_comp=True, simplified: bool = False,
|
||||
reward_type: str = None, context: np.ndarray = None):
|
||||
self._steps = 0
|
||||
|
||||
self.xml_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets",
|
||||
@ -31,9 +32,11 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
||||
self._start_pos = np.array([0.0, 0.58760536, 0.0, 1.36004913, 0.0, -0.32072943, -1.57])
|
||||
self._start_vel = np.zeros(7)
|
||||
|
||||
self.simplified = simplified
|
||||
|
||||
self.sim_time = 8 # seconds
|
||||
self.sim_steps = int(self.sim_time / self.dt)
|
||||
if reward_type == "simple":
|
||||
if reward_type == "no_context":
|
||||
from alr_envs.mujoco.ball_in_a_cup.ball_in_a_cup_reward_simple import BallInACupReward
|
||||
reward_function = BallInACupReward
|
||||
elif reward_type == "contextual_goal":
|
||||
@ -44,6 +47,20 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
||||
self.reward_function = reward_function(self.sim_steps)
|
||||
self.configure(context)
|
||||
|
||||
@property
|
||||
def start_pos(self):
|
||||
if self.simplified:
|
||||
return self._start_pos[1::2]
|
||||
else:
|
||||
return self._start_pos
|
||||
|
||||
@property
|
||||
def start_vel(self):
|
||||
if self.simplified:
|
||||
return self._start_vel[1::2]
|
||||
else:
|
||||
return self._start_vel
|
||||
|
||||
@property
|
||||
def current_pos(self):
|
||||
return self.sim.data.qpos[0:7].copy()
|
||||
@ -58,7 +75,7 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
||||
|
||||
def reset_model(self):
|
||||
init_pos_all = self.init_qpos.copy()
|
||||
init_pos_robot = self.start_pos
|
||||
init_pos_robot = self._start_pos
|
||||
init_vel = np.zeros_like(init_pos_all)
|
||||
|
||||
self._steps = 0
|
||||
@ -114,14 +131,14 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
||||
|
||||
# These functions are for the task with 3 joint actuations
|
||||
def extend_des_pos(self, des_pos):
|
||||
des_pos_full = self.start_pos.copy()
|
||||
des_pos_full = self._start_pos.copy()
|
||||
des_pos_full[1] = des_pos[0]
|
||||
des_pos_full[3] = des_pos[1]
|
||||
des_pos_full[5] = des_pos[2]
|
||||
return des_pos_full
|
||||
|
||||
def extend_des_vel(self, des_vel):
|
||||
des_vel_full = self.start_vel.copy()
|
||||
des_vel_full = self._start_vel.copy()
|
||||
des_vel_full[1] = des_vel[0]
|
||||
des_vel_full[3] = des_vel[1]
|
||||
des_vel_full[5] = des_vel[2]
|
||||
|
@ -27,111 +27,3 @@ def make_detpmp_env(**kwargs):
|
||||
name = kwargs.pop("name")
|
||||
_env = gym.make(name)
|
||||
return DetPMPWrapper(_env, **kwargs)
|
||||
|
||||
|
||||
# def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
# assert shared_memory is None
|
||||
# env = env_fn()
|
||||
# parent_pipe.close()
|
||||
# try:
|
||||
# while True:
|
||||
# command, data = pipe.recv()
|
||||
# if command == 'reset':
|
||||
# observation = env.reset()
|
||||
# pipe.send((observation, True))
|
||||
# elif command == 'configure':
|
||||
# env.configure(data)
|
||||
# pipe.send((None, True))
|
||||
# elif command == 'step':
|
||||
# observation, reward, done, info = env.step(data)
|
||||
# if done:
|
||||
# observation = env.reset()
|
||||
# pipe.send(((observation, reward, done, info), True))
|
||||
# elif command == 'seed':
|
||||
# env.seed(data)
|
||||
# pipe.send((None, True))
|
||||
# elif command == 'close':
|
||||
# pipe.send((None, True))
|
||||
# break
|
||||
# elif command == '_check_observation_space':
|
||||
# pipe.send((data == env.observation_space, True))
|
||||
# else:
|
||||
# raise RuntimeError('Received unknown command `{0}`. Must '
|
||||
# 'be one of {`reset`, `step`, `seed`, `close`, '
|
||||
# '`_check_observation_space`}.'.format(command))
|
||||
# except (KeyboardInterrupt, Exception):
|
||||
# error_queue.put((index,) + sys.exc_info()[:2])
|
||||
# pipe.send((None, False))
|
||||
# finally:
|
||||
# env.close()
|
||||
#
|
||||
#
|
||||
# def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
|
||||
# assert shared_memory is not None
|
||||
# env = env_fn()
|
||||
# observation_space = env.observation_space
|
||||
# parent_pipe.close()
|
||||
# try:
|
||||
# while True:
|
||||
# command, data = pipe.recv()
|
||||
# if command == 'reset':
|
||||
# observation = env.reset()
|
||||
# write_to_shared_memory(index, observation, shared_memory,
|
||||
# observation_space)
|
||||
# pipe.send((None, True))
|
||||
# elif command == 'configure':
|
||||
# env.configure(data)
|
||||
# pipe.send((None, True))
|
||||
# elif command == 'step':
|
||||
# observation, reward, done, info = env.step(data)
|
||||
# if done:
|
||||
# observation = env.reset()
|
||||
# write_to_shared_memory(index, observation, shared_memory,
|
||||
# observation_space)
|
||||
# pipe.send(((None, reward, done, info), True))
|
||||
# elif command == 'seed':
|
||||
# env.seed(data)
|
||||
# pipe.send((None, True))
|
||||
# elif command == 'close':
|
||||
# pipe.send((None, True))
|
||||
# break
|
||||
# elif command == '_check_observation_space':
|
||||
# pipe.send((data == observation_space, True))
|
||||
# else:
|
||||
# raise RuntimeError('Received unknown command `{0}`. Must '
|
||||
# 'be one of {`reset`, `step`, `seed`, `close`, '
|
||||
# '`_check_observation_space`}.'.format(command))
|
||||
# except (KeyboardInterrupt, Exception):
|
||||
# error_queue.put((index,) + sys.exc_info()[:2])
|
||||
# pipe.send((None, False))
|
||||
# finally:
|
||||
# env.close()
|
||||
|
||||
|
||||
# def viapoint_dmp(**kwargs):
|
||||
# _env = gym.make("alr_envs:ViaPointReacher-v0")
|
||||
# # _env = ViaPointReacher(**kwargs)
|
||||
# return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, alpha_phase=2.5, dt=_env.dt,
|
||||
# start_pos=_env.start_pos, learn_goal=False, policy_type="velocity", weights_scale=50)
|
||||
#
|
||||
#
|
||||
# def holereacher_dmp(**kwargs):
|
||||
# _env = gym.make("alr_envs:HoleReacher-v0")
|
||||
# # _env = HoleReacher(**kwargs)
|
||||
# return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, dt=_env.dt, learn_goal=True, alpha_phase=2,
|
||||
# start_pos=_env.start_pos, policy_type="velocity", weights_scale=50, goal_scale=0.1)
|
||||
#
|
||||
#
|
||||
# def holereacher_fix_goal_dmp(**kwargs):
|
||||
# _env = gym.make("alr_envs:HoleReacher-v0")
|
||||
# # _env = HoleReacher(**kwargs)
|
||||
# return DmpWrapper(_env, num_dof=5, num_basis=5, duration=2, dt=_env.dt, learn_goal=False, alpha_phase=2,
|
||||
# start_pos=_env.start_pos, policy_type="velocity", weights_scale=50, goal_scale=1,
|
||||
# final_pos=np.array([2.02669572, -1.25966385, -1.51618198, -0.80946476, 0.02012344]))
|
||||
#
|
||||
#
|
||||
# def holereacher_detpmp(**kwargs):
|
||||
# _env = gym.make("alr_envs:HoleReacher-v0")
|
||||
# # _env = HoleReacher(**kwargs)
|
||||
# return DetPMPWrapper(_env, num_dof=5, num_basis=5, width=0.005, policy_type="velocity", start_pos=_env.start_pos,
|
||||
# duration=2, post_traj_time=0, dt=_env.dt, weights_scale=0.25, zero_start=True, zero_goal=False)
|
||||
|
@ -23,7 +23,7 @@ def split_array(ary, size):
|
||||
split = [k * size for k in range(1, repeat)]
|
||||
sub_arys = np.split(ary, split)
|
||||
|
||||
if n_samples % repeat != 0:
|
||||
if n_samples % size != 0:
|
||||
tmp = np.zeros_like(sub_arys[0])
|
||||
last = sub_arys[-1]
|
||||
tmp[0: len(last)] = last
|
||||
@ -42,8 +42,8 @@ def _flatten_list(l):
|
||||
|
||||
class AlrMpEnvSampler:
|
||||
"""
|
||||
An asynchronous sampler for MPWrapper environments. A sampler object can be called with a set of parameters and
|
||||
returns the corresponding final obs, rewards, dones and info dicts.
|
||||
An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of
|
||||
parameters and returns the corresponding final obs, rewards, dones and info dicts.
|
||||
"""
|
||||
def __init__(self, env_id, num_envs, seed=0):
|
||||
self.num_envs = num_envs
|
||||
@ -68,10 +68,10 @@ class AlrMpEnvSampler:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_name = "alr_envs:HoleReacherDMP-v0"
|
||||
env_name = "alr_envs:ALRBallInACupSimpleDMP-v0"
|
||||
n_cpu = 8
|
||||
dim = 30
|
||||
n_samples = 20
|
||||
dim = 15
|
||||
n_samples = 10
|
||||
|
||||
sampler = AlrMpEnvSampler(env_name, num_envs=n_cpu)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
||||
|
||||
if self.post_traj_steps > 0:
|
||||
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.dmp.num_dimensions))])
|
||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dimensions))])
|
||||
|
||||
# self._trajectory = trajectory
|
||||
# self._velocity = velocity
|
||||
@ -76,6 +76,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
||||
|
||||
# TODO: @Max Why do we need this configure, states should be part of the model
|
||||
# TODO: Ask Onur if the context distribution needs to be outside the environment
|
||||
# TODO: For now create a new env with each context
|
||||
# self.env.configure(context)
|
||||
obs = self.env.reset()
|
||||
info = {}
|
||||
|
Loading…
Reference in New Issue
Block a user