Updates on MP wrappers and some bugfixes
commit f5f12c846f (parent 17c489d622)
@@ -272,12 +272,20 @@ for v in versions:
         }
     )

-    # register(
-    #     id='HoleReacherDetPMP-v0',
-    #     entry_point='alr_envs.classic_control.hole_reacher:holereacher_detpmp',
-    #     # max_episode_steps=1,
-    #     # TODO: add mp kwargs
-    # )
+    register(
+        id=f'HoleReacherDetPMP-{v}',
+        entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
+        kwargs={
+            "name": f"alr_envs:HoleReacher-{v}",
+            "num_dof": 5,
+            "num_basis": 5,
+            "duration": 2,
+            "width": 0.01,
+            "policy_type": "velocity",
+            "weights_scale": 0.2,
+            "zero_start": True
+        }
+    )

     # TODO: properly add final_pos
     register(
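For orientation, a minimal usage sketch of the DetPMP variant registered above (assuming v0 is among the registered versions; the flat weight vector of size num_dof * num_basis = 25 and the single-call rollout are inferred from the DetPMPWrapper changes further down, not verified output):

    import gym
    import numpy as np

    env = gym.make("alr_envs:HoleReacherDetPMP-v0")
    obs = env.reset()
    weights = 0.1 * np.random.randn(5 * 5)        # one weight per (basis, joint) pair
    obs, reward, done, info = env.step(weights)   # one step = one full trajectory rollout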
@@ -335,6 +343,40 @@ register(
     }
 )

+register(
+    id='ALRBallInACupSimpleDetPMP-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
+    kwargs={
+        "name": "alr_envs:ALRBallInACupSimple-v0",
+        "num_dof": 3,
+        "num_basis": 5,
+        "duration": 3.5,
+        "post_traj_time": 4.5,
+        "width": 0.005,
+        "policy_type": "motor",
+        "weights_scale": 0.2,
+        "zero_start": True,
+        "zero_goal": True
+    }
+)
+
+register(
+    id='ALRBallInACupDetPMP-v0',
+    entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env',
+    kwargs={
+        "name": "alr_envs:ALRBallInACupSimple-v0",
+        "num_dof": 7,
+        "num_basis": 5,
+        "duration": 3.5,
+        "post_traj_time": 4.5,
+        "width": 0.0035,
+        "policy_type": "motor",
+        "weights_scale": 0.2,
+        "zero_start": True,
+        "zero_goal": True
+    }
+)
+
 register(
     id='ALRBallInACupGoalDMP-v0',
     entry_point='alr_envs.utils.make_env_helpers:make_contextual_env',
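Both new ball-in-a-cup entries wrap the same underlying ALRBallInACupSimple-v0 task; the Simple variant parameterizes 3 joints, the full one all 7, with a narrower basis width. A small sanity-check sketch (the expected action dimensions 3 * 5 = 15 and 7 * 5 = 35 follow from the DetPMPWrapper's n_basis * n_dof action space below and are an assumption, not verified output; requires a working MuJoCo setup):

    import gym

    for env_id, expected_dim in [("alr_envs:ALRBallInACupSimpleDetPMP-v0", 15),
                                 ("alr_envs:ALRBallInACupDetPMP-v0", 35)]:
        env = gym.make(env_id)
        assert env.action_space.shape == (expected_dim,)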
@@ -7,10 +7,10 @@ from gym.utils import seeding
 from matplotlib import patches

 from alr_envs.classic_control.utils import check_self_collision
-from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_environments import AlrEnv


-class HoleReacherEnv(MPEnv):
+class HoleReacherEnv(AlrEnv):

     def __init__(self, n_links: int, hole_x: Union[None, float] = None, hole_depth: Union[None, float] = None,
                  hole_width: float = 1., random_start: bool = False, allow_self_collision: bool = False,
@@ -71,11 +71,11 @@ class HoleReacherEnv(MPEnv):
         A single step with an action in joint velocity space
         """

+        acc = (action - self._angle_velocity) / self.dt
         self._angle_velocity = action
         self._joint_angles = self._joint_angles + self.dt * self._angle_velocity
         self._update_joints()

-        acc = (action - self._angle_velocity) / self.dt
         reward, info = self._get_reward(acc)

         info.update({"is_collided": self._is_collided})
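Moving the acc computation above the velocity update is the actual fix here: in the old order self._angle_velocity had already been overwritten with the incoming action, so the finite-difference acceleration handed to _get_reward was always zero. A tiny illustration with made-up numbers:

    dt = 0.01
    prev_vel, action = 0.0, 1.0

    acc_new = (action - prev_vel) / dt   # new order: 100.0
    vel = action                         # velocity update
    acc_old = (action - vel) / dt        # old order: always 0.0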
@@ -5,10 +5,10 @@ import numpy as np
 from gym import spaces
 from gym.utils import seeding

-from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_environments import AlrEnv


-class SimpleReacherEnv(MPEnv):
+class SimpleReacherEnv(AlrEnv):
     """
     Simple Reaching Task without any physics simulation.
     Returns no reward until 150 time steps. This allows the agent to explore the space, but requires precise actions
@@ -6,10 +6,10 @@ import numpy as np
 from gym.utils import seeding

 from alr_envs.classic_control.utils import check_self_collision
-from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_environments import AlrEnv


-class ViaPointReacher(MPEnv):
+class ViaPointReacher(AlrEnv):

     def __init__(self, n_links, random_start: bool = True, via_target: Union[None, Iterable] = None,
                  target: Union[None, Iterable] = None, allow_self_collision=False, collision_penalty=1000):
@@ -1,5 +1,6 @@
 from collections import OrderedDict
 import os
+from abc import abstractmethod


 from gym import error, spaces
@@ -142,18 +143,20 @@ class AlrMujocoEnv(gym.Env):
     # methods to override:
     # ----------------------------

+    @property
+    @abstractmethod
+    def active_obs(self):
+        """Returns boolean mask for each observation entry
+        whether the observation is returned for the contextual case or not.
+        This effectively allows to filter unwanted or unnecessary observations from the full step-based case.
+        """
+        return np.ones(self.observation_space.shape, dtype=bool)
+
     def _get_obs(self):
         """Returns the observation.
         """
         raise NotImplementedError()

-    def configure(self, *args, **kwargs):
-        """
-        Helper method to set certain environment properties such as contexts in contextual environments since reset()
-        doesn't take arguments. Should be called before reset().
-        """
-        pass
-
     def reset_model(self):
         """
         Reset the robot degrees of freedom (qpos and qvel).
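The new active_obs default is the boolean mask that MPWrapper.reset() indexes with (self.env.reset()[self.env.active_obs], see the mp_wrapper hunk below), so subclasses can hide observation entries from the episodic / contextual setting. A small illustration of the masking, with hypothetical shapes:

    import numpy as np

    full_obs = np.arange(17.0)                              # hypothetical step-based observation
    mask = np.hstack([[False] * 14, [True] * 2, [False]])   # e.g. keep only target-related entries
    episodic_obs = full_obs[mask]                           # what the MP wrapper would expose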
@@ -22,7 +22,7 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
         self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
         self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])

-        self.context = None
+        self.context = context

         utils.EzPickle.__init__(self)
         alr_mujoco_env.AlrMujocoEnv.__init__(self,
@@ -45,7 +45,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
         else:
             raise ValueError("Unknown reward type")
         self.reward_function = reward_function(self.sim_steps)
-        self.configure(context)

     @property
     def start_pos(self):
@@ -69,10 +68,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
     def current_vel(self):
         return self.sim.data.qvel[0:7].copy()

-    def configure(self, context):
-        self.context = context
-        self.reward_function.reset(context)
-
     def reset_model(self):
         init_pos_all = self.init_qpos.copy()
         init_pos_robot = self._start_pos
@@ -129,6 +124,16 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
             [self._steps],
         ])

+    # TODO
+    @property
+    def active_obs(self):
+        return np.hstack([
+            [False] * 7,  # cos
+            [False] * 7,  # sin
+            # [True] * 2,  # x-y coordinates of target distance
+            [False]  # env steps
+        ])
+
     # These functions are for the task with 3 joint actuations
     def extend_des_pos(self, des_pos):
         des_pos_full = self._start_pos.copy()
@@ -105,6 +105,7 @@ class ALRBeerpongEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
     def check_traj_in_joint_limits(self):
         return any(self.current_pos > self.j_max) or any(self.current_pos < self.j_min)

+    # TODO
     def _get_obs(self):
         theta = self.sim.data.qpos.flat[:7]
         return np.concatenate([
@@ -114,6 +115,10 @@ class ALRBeerpongEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
             [self._steps],
         ])

+    # TODO
+    def active_obs(self):
+        pass
+


 if __name__ == "__main__":
@@ -40,6 +40,16 @@ def _flatten_list(l):
     return [l__ for l_ in l for l__ in l_]


+class DummyDist:
+    def __init__(self, dim):
+        self.dim = dim
+
+    def sample(self, contexts):
+        contexts = np.atleast_2d(contexts)
+        n_samples = contexts.shape[0]
+        return np.random.normal(size=(n_samples, self.dim)), contexts
+
+
 class AlrMpEnvSampler:
     """
     An asynchronous sampler for non contextual MPWrapper environments. A sampler object can be called with a set of
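DummyDist is a stand-in search distribution for the contextual sampler: sample() draws isotropic Gaussian parameters, one row per context, which is enough to smoke-test AlrContextualMpEnvSampler (see example_async_contextual_sampler in example.py below). A minimal sketch of the contract, assuming the class above is importable:

    import numpy as np

    dist = DummyDist(dim=25)
    contexts = np.zeros((10, 2))            # 10 hypothetical 2d contexts
    thetas, ctxts = dist.sample(contexts)   # thetas.shape == (10, 25), ctxts echoes the contexts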
@@ -100,9 +110,9 @@ class AlrContextualMpEnvSampler:


 if __name__ == "__main__":
-    env_name = "alr_envs:ALRBallInACupSimpleDMP-v0"
+    env_name = "alr_envs:HoleReacherDetPMP-v1"
     n_cpu = 8
-    dim = 15
+    dim = 25
     n_samples = 10

     sampler = AlrMpEnvSampler(env_name, num_envs=n_cpu)
@@ -2,12 +2,12 @@ import gym
 import numpy as np
 from mp_lib import det_promp

-from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_environments import AlrEnv
 from alr_envs.utils.mps.mp_wrapper import MPWrapper


 class DetPMPWrapper(MPWrapper):
-    def __init__(self, env: MPEnv, num_dof: int, num_basis: int, width: int, duration: int = 1, dt: float = 0.01,
+    def __init__(self, env: AlrEnv, num_dof: int, num_basis: int, width: float, duration: float = 1, dt: float = 0.01,
                  post_traj_time: float = 0., policy_type: str = None, weights_scale: float = 1.,
                  zero_start: bool = False, zero_goal: bool = False, **mp_kwargs):
         self.duration = duration  # seconds
@@ -15,15 +15,16 @@ class DetPMPWrapper(MPWrapper):
         super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, num_basis=num_basis,
                          width=width, zero_start=zero_start, zero_goal=zero_goal, **mp_kwargs)

-        self.dt = dt
+        self.dt = env.dt if hasattr(env, "dt") else dt
+        assert self.dt is not None

         action_bounds = np.inf * np.ones((self.mp.n_basis * self.mp.n_dof))
         self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)


     def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, width: float = None,
-                      zero_start: bool = False, zero_goal: bool = False):
-        pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
+                      off: float = 0.01, zero_start: bool = False, zero_goal: bool = False):
+        pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=off,
                                            zero_start=zero_start, zero_goal=zero_goal)

         weights = np.zeros(shape=(num_basis, num_dof))
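The dt handling now prefers the wrapped environment's own control timestep and keeps the constructor argument only as a fallback; the new assert rejects a missing timestep early instead of failing later in compute_trajectory(1 / self.dt, ...). A minimal sketch of the behaviour under both cases (hypothetical envs, not from the repo):

    class EnvWithDt:
        dt = 0.02

    class EnvWithoutDt:
        pass

    def resolve_dt(env, fallback):
        # mirrors the hasattr check introduced in the diff
        return env.dt if hasattr(env, "dt") else fallback

    assert resolve_dt(EnvWithDt(), 0.01) == 0.02      # env timestep wins
    assert resolve_dt(EnvWithoutDt(), 0.01) == 0.01   # kwargs value only as fallback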
@@ -32,10 +33,10 @@ class DetPMPWrapper(MPWrapper):
         return pmp

     def mp_rollout(self, action):
-        params = np.reshape(action, (self.mp.n_basis, self.mp.n_dof)) * self.weights_scale
+        params = np.reshape(action, newshape=(self.mp.n_basis, self.mp.n_dof)) * self.weights_scale
         self.mp.set_weights(self.duration, params)
         _, des_pos, des_vel, _ = self.mp.compute_trajectory(1 / self.dt, 1.)
         if self.mp.zero_start:
-            des_pos += self.start_pos[None, :]
+            des_pos += self.env.start_pos[None, :]

         return des_pos, des_vel
@@ -4,13 +4,13 @@ from mp_lib import dmps
 from mp_lib.basis import DMPBasisGenerator
 from mp_lib.phase import ExpDecayPhaseGenerator

-from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_environments import AlrEnv
 from alr_envs.utils.mps.mp_wrapper import MPWrapper


 class DmpWrapper(MPWrapper):

-    def __init__(self, env: MPEnv, num_dof: int, num_basis: int,
+    def __init__(self, env: AlrEnv, num_dof: int, num_basis: int,
                  duration: int = 1, alpha_phase: float = 2., dt: float = None,
                  learn_goal: bool = False, post_traj_time: float = 0.,
                  weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
@@ -40,7 +40,7 @@ class DmpWrapper(MPWrapper):
         super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, render_mode,
                          num_basis=num_basis, alpha_phase=alpha_phase, bandwidth_factor=bandwidth_factor)

-        action_bounds = np.inf * np.ones((np.prod(self.mp.dmp_weights.shape) + (num_dof if learn_goal else 0)))
+        action_bounds = np.inf * np.ones((np.prod(self.mp.weights.shape) + (num_dof if learn_goal else 0)))
         self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)

     def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, alpha_phase: float = 2.,
@@ -51,7 +51,7 @@
                                             basis_bandwidth_factor=bandwidth_factor)

         dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
-                       num_time_steps=int(duration / dt), dt=dt)
+                       duration=duration, dt=dt)

         return dmp

@@ -66,7 +66,7 @@
         goal_pos = self.env.goal_pos
         assert goal_pos is not None

-        weight_matrix = np.reshape(params, self.mp.dmp_weights.shape)  # [num_basis, num_dof]
+        weight_matrix = np.reshape(params, self.mp.weights.shape)  # [num_basis, num_dof]
         return goal_pos * self.goal_scale, weight_matrix * self.weights_scale

     def mp_rollout(self, action):
@@ -5,7 +5,7 @@ import gym
 import numpy as np


-class MPEnv(gym.Env):
+class AlrEnv(gym.Env):

     @property
     @abstractmethod
@@ -3,13 +3,13 @@ from abc import ABC, abstractmethod
 import gym
 import numpy as np

-from alr_envs.utils.mps.mp_environments import MPEnv
+from alr_envs.utils.mps.mp_environments import AlrEnv
 from alr_envs.utils.policies import get_policy_class


 class MPWrapper(gym.Wrapper, ABC):

-    def __init__(self, env: MPEnv, num_dof: int, dt: float, duration: int = 1, post_traj_time: float = 0.,
+    def __init__(self, env: AlrEnv, num_dof: int, dt: float, duration: float = 1, post_traj_time: float = 0.,
                  policy_type: str = None, weights_scale: float = 1., render_mode: str = None, **mp_kwargs):
         super().__init__(env)

@@ -53,9 +53,6 @@ class MPWrapper(gym.Wrapper, ABC):

         return obs, np.array(rewards), dones, infos

-    def configure(self, context):
-        self.env.configure(context)
-
     def reset(self):
         return self.env.reset()[self.env.active_obs]

@@ -65,7 +62,7 @@

         if self.post_traj_steps > 0:
             trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
-            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dimensions))])
+            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.n_dof))])

         # self._trajectory = trajectory
         # self._velocity = velocity
@@ -105,7 +102,7 @@
         raise NotImplementedError()

     @abstractmethod
-    def initialize_mp(self, num_dof: int, duration: int, dt: float, **kwargs):
+    def initialize_mp(self, num_dof: int, duration: float, dt: float, **kwargs):
        """
        Create respective instance of MP
        Returns:
example.py
@@ -1,6 +1,7 @@
 from collections import defaultdict
 import gym
 import numpy as np
+from alr_envs.utils.mp_env_async_sampler import AlrMpEnvSampler, AlrContextualMpEnvSampler, DummyDist


 def example_mujoco():
@@ -22,9 +23,9 @@ def example_mujoco():
     obs = env.reset()


-def example_dmp():
+def example_mp(env_name="alr_envs:HoleReacherDMP-v0"):
     # env = gym.make("alr_envs:ViaPointReacherDMP-v0")
-    env = gym.make("alr_envs:HoleReacherDMP-v0")
+    env = gym.make(env_name)
     rewards = 0
     # env.render(mode=None)
     obs = env.reset()
@@ -79,9 +80,36 @@ def example_async(env_id="alr_envs:HoleReacherDMP-v0", n_cpu=4, seed=int('533D',
     print(sample(envs, 16))


+def example_async_sampler(env_name="alr_envs:HoleReacherDetPMP-v1", n_cpu=4):
+    n_samples = 10
+
+    sampler = AlrMpEnvSampler(env_name, num_envs=n_cpu)
+    dim = sampler.env.action_space.spaces[0].shape[0]
+
+    thetas = np.random.randn(n_samples, dim)  # usually form a search distribution
+
+    _, rewards, __, ___ = sampler(thetas)
+
+    print(rewards)
+
+
+def example_async_contextual_sampler(env_name="alr_envs:SimpleReacherDMP-v1", n_cpu=4):
+    sampler = AlrContextualMpEnvSampler(env_name, num_envs=n_cpu)
+    dim = sampler.env.action_space.spaces[0].shape[0]
+    dist = DummyDist(dim)  # needs a sample function
+
+    n_samples = 10
+    new_samples, new_contexts, obs, new_rewards, done, infos = sampler(dist, n_samples)
+
+    print(new_rewards)
+
+
 if __name__ == '__main__':
     # example_mujoco()
-    # example_dmp()
-    example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
-    # env = gym.make("alr_envs:HoleReacherDMP-v0", context=0.1)
-    # env = gym.make("alr_envs:HoleReacherDMP-v1")
+    # example_dmp("alr_envs:SimpleReacherDMP-v1")
+    # example_async("alr_envs:LongSimpleReacherDMP-v0", 4)
+    # example_async_contextual_sampler()
+    # env = gym.make("alr_envs:HoleReacherDetPMP-v1")
+    env_name = "alr_envs:ALRBallInACupSimpleDetPMP-v0"
+    # example_async_sampler(env_name)
+    example_mp(env_name)