updates
This commit is contained in:
parent
a8ffa791b8
commit
be6137ec81
@ -22,6 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
replanning_schedule: Optional[
|
replanning_schedule: Optional[
|
||||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
||||||
|
max_replan_times: int = 1,
|
||||||
desired_conditioning: bool = False
|
desired_conditioning: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -73,6 +74,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.condition_pos = None
|
self.condition_pos = None
|
||||||
self.condition_vel = None
|
self.condition_vel = None
|
||||||
|
|
||||||
|
self.max_replan_times = max_replan_times
|
||||||
|
self.replan_counts = 0
|
||||||
|
|
||||||
def observation(self, observation):
|
def observation(self, observation):
|
||||||
# return context space if we are
|
# return context space if we are
|
||||||
if self.return_context_observation:
|
if self.return_context_observation:
|
||||||
@ -155,6 +159,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
infos = dict()
|
infos = dict()
|
||||||
done = False
|
done = False
|
||||||
|
|
||||||
|
self.replan_counts += 1
|
||||||
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
for t, (pos, vel) in enumerate(zip(position, velocity)):
|
||||||
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
|
||||||
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
|
||||||
@ -173,14 +178,16 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
if self.render_kwargs:
|
if self.render_kwargs:
|
||||||
self.env.render(**self.render_kwargs)
|
self.env.render(**self.render_kwargs)
|
||||||
|
|
||||||
if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
if done or (self.replan_counts < self.max_replan_times
|
||||||
t + 1 + self.current_traj_steps):
|
and self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
||||||
|
t + 1 + self.current_traj_steps)):
|
||||||
if self.desired_conditioning:
|
if self.desired_conditioning:
|
||||||
self.condition_pos = pos
|
self.condition_pos = pos
|
||||||
self.condition_vel = vel
|
self.condition_vel = vel
|
||||||
else:
|
else:
|
||||||
self.condition_pos = self.current_pos
|
self.condition_pos = self.current_pos
|
||||||
self.condition_vel = self.current_vel
|
self.condition_vel = self.current_vel
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
infos.update({k: v[:t+1] for k, v in infos.items()})
|
infos.update({k: v[:t+1] for k, v in infos.items()})
|
||||||
|
@ -498,10 +498,11 @@ for _v in _versions:
|
|||||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||||
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
|
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
|
||||||
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
|
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([100., 166., 500., 1000.])
|
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
|
||||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 1.
|
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
|
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
|
||||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
||||||
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
|
||||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
||||||
register(
|
register(
|
||||||
id=_env_id,
|
id=_env_id,
|
||||||
|
@ -9,7 +9,15 @@ Fancy_ProDMP_IDS = fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
|||||||
|
|
||||||
All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
|
@pytest.mark.parametrize('env_id', Fancy_ProDMP_IDS)
|
||||||
def test_prodmp_envs(env_id: str):
|
def test_replanning_envs(env_id: str):
|
||||||
"""Tests that ProDMP environments run without errors using random actions."""
|
"""Tests that ProDMP environments run without errors using random actions."""
|
||||||
run_env(env_id)
|
run_env(env_id, iterations=4)
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
|
||||||
|
# def test_replanning_determinism(env_id: str):
|
||||||
|
# """Tests that ProDMP environments are deterministic."""
|
||||||
|
# run_env_determinism(env_id, 0)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user