update according to review feedback & fix bugs in box pushing IK

This commit is contained in:
parent fc3051bf57
commit 2674bf80fe
@@ -24,7 +24,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                  Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
                  reward_aggregation: Callable[[np.ndarray], float] = np.sum,
                  max_planning_times: int = None,
-                 desired_traj_bc: bool = False
+                 condition_on_desired: bool = False
                  ):
         """
         gym.Wrapper for leveraging a black box approach with a trajectory generator.
@@ -71,12 +71,12 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         self.verbose = verbose

         # condition value
-        self.desired_traj_bc = desired_traj_bc
+        self.condition_on_desired = condition_on_desired
         self.condition_pos = None
         self.condition_vel = None

         self.max_planning_times = max_planning_times
-        self.plan_counts = 0
+        self.plan_steps = 0

     def observation(self, observation):
         # return context space if we are
@@ -98,15 +98,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
         # TODO we could think about initializing with the previous desired value in order to have a smooth transition
         #  at least from the planning point of view.
-        # self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
-        if self.current_traj_steps == 0:
-            self.condition_pos = self.current_pos
-            self.condition_vel = self.current_vel

-        bc_time = torch.as_tensor(bc_time, dtype=torch.float32)
-        self.condition_pos = torch.as_tensor(self.condition_pos, dtype=torch.float32)
-        self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
-        self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
+        condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
+        condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel
+
+        self.traj_gen.set_boundary_conditions(bc_time, condition_pos, condition_vel)
         self.traj_gen.set_duration(duration, self.dt)
         # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
         position = get_numpy(self.traj_gen.get_traj_pos())
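The net effect of this hunk: the stored condition buffers become an optional override, and `None` now means "condition the next plan on the measured robot state". A minimal standalone sketch of that fallback rule (the function name is illustrative, not part of the commit):

def select_boundary_state(stored, measured):
    # None = no plan state stored (e.g. first plan of an episode, or
    # condition_on_desired disabled) -> use the measured robot state
    return stored if stored is not None else measured

# first plan of an episode: falls back to the measured state
# later plans with condition_on_desired: reuses the stored desired state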
@@ -164,7 +160,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         infos = dict()
         done = False

-        self.plan_counts += 1
+        self.plan_steps += 1
         for t, (pos, vel) in enumerate(zip(position, velocity)):
             step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
             c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
@@ -186,11 +182,11 @@ class BlackBoxWrapper(gym.ObservationWrapper):
             if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                 t + 1 + self.current_traj_steps):

-                if self.max_planning_times is not None and self.plan_counts >= self.max_planning_times:
+                if self.max_planning_times is not None and self.plan_steps >= self.max_planning_times:
                     continue

-                self.condition_pos = pos if self.desired_traj_bc else self.current_pos
-                self.condition_vel = vel if self.desired_traj_bc else self.current_vel
+                self.condition_pos = pos if self.condition_on_desired else None
+                self.condition_vel = vel if self.condition_on_desired else None

                 break

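Together with the hunk above, `plan_steps` and `max_planning_times` implement a plan budget: once the budget is used up, further replanning triggers are ignored (`continue`) and the current plan runs to the episode end. A standalone toy simulation of that schedule (not library code; the constants mirror the box-pushing registration further down):

# plan budget of 2, replanning every 25 steps, 100-step episode
max_planning_times, sub_segment_steps, episode_len = 2, 25, 100
plan_steps, t = 0, 0
while t < episode_len:
    plan_steps += 1  # one new trajectory plan
    for step in range(t, episode_len):
        replan = (step + 1) % sub_segment_steps == 0
        if replan and plan_steps < max_planning_times:
            break  # budget left: cut the rollout short and plan again
    t = step + 1
print(plan_steps)  # -> 2: the second plan runs from t=25 to the episode end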
@@ -215,6 +211,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):

     def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
         self.current_traj_steps = 0
-        self.plan_counts = 0
+        self.plan_steps = 0
         self.traj_gen.reset()
         return super(BlackBoxWrapper, self).reset()
@@ -68,12 +68,9 @@ DEFAULT_BB_DICT_ProDMP = {
    "wrappers": [],
    "trajectory_generator_kwargs": {
        'trajectory_generator_type': 'prodmp',
-       'weights_scale': 1.0,
    },
    "phase_generator_kwargs": {
        'phase_generator_type': 'exp',
-       'learn_delay': False,
-       'learn_tau': False,
    },
    "controller_kwargs": {
        'controller_type': 'motor',
@@ -86,10 +83,6 @@ DEFAULT_BB_DICT_ProDMP = {
        'num_basis': 5,
    },
    "black_box_kwargs": {
-       'replanning_schedule': None,
-       'max_planning_times': None,
-       'desired_traj_bc': False,
-       'verbose': 2
    }
}

@@ -492,7 +485,7 @@ for _v in _versions:

for _v in _versions:
    _name = _v.split("-")
-   _env_id = f'{_name[0]}ProDMP-{_name[1]}'
+   _env_id = f'{_name[0]}ReplanProDMP-{_name[1]}'
    kwargs_dict_box_pushing_prodmp = deepcopy(DEFAULT_BB_DICT_ProDMP)
    kwargs_dict_box_pushing_prodmp['wrappers'].append(mujoco.box_pushing.MPWrapper)
    kwargs_dict_box_pushing_prodmp['name'] = _v
@@ -502,13 +495,12 @@ for _v in _versions:
    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3
    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['auto_scale_basis'] = True
    kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_offset'] = 1.0
-   kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 0
-   kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
+   kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 4
    kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3
    kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_planning_times'] = 2
    kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t: t % 25 == 0
-   kwargs_dict_box_pushing_prodmp['black_box_kwargs']['desired_traj_bc'] = True
+   kwargs_dict_box_pushing_prodmp['black_box_kwargs']['condition_on_desired'] = True
    register(
        id=_env_id,
        entry_point='fancy_gym.utils.make_env_helpers:make_bb_env_helper',
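Once registered, the replanning variants are created like any other fancy_gym environment; a minimal usage sketch (the id follows the `{task}ReplanProDMP-{version}` pattern built above):

import fancy_gym

env = fancy_gym.make("BoxPushingDenseReplanProDMP-v0", seed=1)
env.reset()
# one action is one ProDMP parameter vector; the wrapper rolls it out until
# the replanning schedule (every 25 steps here) fires or the episode ends
obs, reward, done, info = env.step(env.action_space.sample())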
@@ -219,6 +219,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
            q_old = q
            q = q + dt * qd_d
            q = np.clip(q, q_min, q_max)
+           self.data.qpos[:7] = q
+           mujoco.mj_forward(self.model, self.data)
            current_cart_pos = self.data.body("tcp").xpos.copy()
            current_cart_quat = self.data.body("tcp").xquat.copy()

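This hunk is the IK fix named in the commit message: the iteration updated `q` but never wrote it back into the simulator, so `data.body("tcp").xpos`/`xquat` kept reporting the pose of the previous iterate. A minimal sketch of the corrected pattern (the helper name and simplified signature are assumptions, mirroring the lines added above):

import mujoco
import numpy as np

def ik_iterate_once(model, data, q, qd_d, dt, q_min, q_max):
    # integrate the joint-velocity step and respect joint limits
    q = np.clip(q + dt * qd_d, q_min, q_max)
    # the fix: push q into the sim and rerun forward kinematics so the
    # TCP pose (and the Jacobian used later) match the new iterate
    data.qpos[:7] = q
    mujoco.mj_forward(model, data)
    tcp_pos = data.body("tcp").xpos.copy()
    tcp_quat = data.body("tcp").xquat.copy()
    return q, tcp_pos, tcp_quat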
@@ -247,8 +249,10 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
            ### get Jacobian by mujoco
            self.data.qpos[:7] = q
            mujoco.mj_forward(self.model, self.data)

            jacp = self.get_body_jacp("tcp")[:, :7].copy()
            jacr = self.get_body_jacr("tcp")[:, :7].copy()

            J = np.concatenate((jacp, jacr), axis=0)

            Jw = J.dot(w)
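`Jw = J.dot(w)` is the first step of a weighted least-squares velocity solve. For context, a generic damped (singularity-robust) variant of that scheme (the damping constant and the assumption that `w` is a symmetric joint-weight matrix are illustrative, not taken from the env):

import numpy as np

def weighted_damped_qd(J, w, err, damping=1e-4):
    # qd = W J^T (J W J^T + lambda*I)^-1 err, with W = w symmetric
    Jw = J.dot(w)                       # (6, 7) for a 7-DoF arm
    lam = damping * np.eye(J.shape[0])  # regularizes near singularities
    return Jw.T.dot(np.linalg.solve(J.dot(Jw.T) + lam, err))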
@@ -356,14 +360,3 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
        reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward

        return reward

-if __name__=="__main__":
-    env = BoxPushingTemporalSpatialSparse(frame_skip=10)
-    env.reset()
-    for i in range(10):
-        env.reset()
-        for _ in range(100):
-            env.render("human")
-            action = env.action_space.sample()
-            obs, reward, done, info = env.step(action)
-            print("info: {}".format(info))
@@ -1,38 +1,62 @@
 import fancy_gym
-import numpy as np
-import matplotlib.pyplot as plt

-def plot_trajectory(traj):
-    plt.figure()
-    plt.plot(traj[:, 3])
-    plt.legend()
-    plt.show()
-
-def run_replanning_envs(env_name="BoxPushingProDMP-v0", seed=1, iterations=1, render=True):
+def example_run_replanning_env(env_name="BoxPushingDenseReplanProDMP-v0", seed=1, iterations=1, render=False):
     env = fancy_gym.make(env_name, seed=seed)
     env.reset()
     for i in range(iterations):
         done = False
-        desired_pos_traj = np.zeros((100, 7))
-        desired_vel_traj = np.zeros((100, 7))
-        real_pos_traj = np.zeros((100, 7))
-        real_vel_traj = np.zeros((100, 7))
-        t = 0
         while done is False:
             ac = env.action_space.sample()
             obs, reward, done, info = env.step(ac)
-            desired_pos_traj[t: t + 25, :] = info['desired_pos']
-            desired_vel_traj[t: t + 25, :] = info['desired_vel']
-            # real_pos_traj.append(info['current_pos'])
-            # real_vel_traj.append(info['current_vel'])
-            t += 25
             if render:
                 env.render(mode="human")
         if done:
             env.reset()
-    plot_trajectory(desired_pos_traj)
     env.close()
     del env

+def example_custom_replanning_envs(seed=0, iteration=100, render=True):
+    # id for a step-based environment
+    base_env_id = "BoxPushingDense-v0"
+
+    wrappers = [fancy_gym.envs.mujoco.box_pushing.mp_wrapper.MPWrapper]
+
+    trajectory_generator_kwargs = {'trajectory_generator_type': 'prodmp',
+                                   'weight_scale': 1}
+    phase_generator_kwargs = {'phase_generator_type': 'exp'}
+    controller_kwargs = {'controller_type': 'velocity'}
+    basis_generator_kwargs = {'basis_generator_type': 'prodmp',
+                              'num_basis': 5}
+
+    # max_planning_times: the maximum number of plans that can be generated
+    # replanning_schedule: the trigger for replanning
+    # condition_on_desired: use the desired state as the boundary condition for the next plan
+    black_box_kwargs = {'max_planning_times': 4,
+                        'replanning_schedule': lambda pos, vel, obs, action, t: t % 25 == 0,
+                        'condition_on_desired': True}
+
+    env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs=black_box_kwargs,
+                            traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
+                            phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
+                            seed=seed)
+    if render:
+        env.render(mode="human")
+
+    obs = env.reset()
+
+    for i in range(iteration):
+        ac = env.action_space.sample()
+        obs, reward, done, info = env.step(ac)
+        if done:
+            env.reset()
+
+    env.close()
+    del env
+
+
 if __name__ == "__main__":
-    run_replanning_envs(env_name="BoxPushingDenseProDMP-v0", seed=1, iterations=1, render=False)
+    # run a registered replanning environment
+    example_run_replanning_env(env_name="BoxPushingDenseReplanProDMP-v0", seed=1, iterations=1, render=False)
+
+    # run a custom replanning environment
+    example_custom_replanning_envs(seed=0, iteration=100, render=True)
@@ -1,9 +0,0 @@
-import gym_blockpush
-import gym
-
-env = gym.make("blockpush-v0")
-env.start()
-env.scene.reset()
-for i in range(100):
-    env.step(env.action_space.sample())
-    env.render()
@@ -164,7 +164,7 @@ if __name__ == '__main__':
    example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)

    # ProDMP
-   example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=4, render=render)
+   example_mp("BoxPushingDenseReplanProDMP-v0", seed=10, iterations=4, render=render)

    # Altered basis functions
    obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
@@ -175,9 +175,6 @@ def make_bb(
    if phase_kwargs.get('learn_delay'):
        phase_kwargs["delay_bound"] = [0, black_box_kwargs['duration'] - env.dt * 2]

-   if traj_gen_kwargs['trajectory_generator_type'] == 'prodmp':
-       assert basis_kwargs['basis_generator_type'] == 'prodmp', 'prodmp trajectory generator requires prodmp basis generator'
-
    phase_gen = get_phase_generator(**phase_kwargs)
    basis_gen = get_basis_generator(phase_generator=phase_gen, **basis_kwargs)
    controller = get_controller(**controller_kwargs)
@@ -158,7 +158,7 @@ def test_context_space(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapp

@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('num_dof', [0, 1, 2, 5])
-@pytest.mark.parametrize('num_basis', [0, 2, 5])  # should add 1 back after the bug is fixed
+@pytest.mark.parametrize('num_basis', [0, 1, 2, 5])
@pytest.mark.parametrize('learn_tau', [True, False])
@pytest.mark.parametrize('learn_delay', [True, False])
def test_action_space(mp_type: str, num_dof: int, num_basis: int, learn_tau: bool, learn_delay: bool):
@@ -344,31 +344,3 @@ def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
    active_vel = vel[delay_time_steps: joint_time_steps - 2]
    assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
    assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])

-
-@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
-@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
-@pytest.mark.parametrize('sub_segment_steps', [5, 10])
-def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_steps: int):
-    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
-    phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
-    env = fancy_gym.make_bb('toy-v0', [ToyWrapper],
-                            {'max_planning_times': max_planning_times,
-                             'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
-                             'verbose': 2},
-                            {'trajectory_generator_type': mp_type,
-                             },
-                            {'controller_type': 'motor'},
-                            {'phase_generator_type': phase_generator_type,
-                             'learn_tau': False,
-                             'learn_delay': False
-                             },
-                            {'basis_generator_type': basis_generator_type,
-                             },
-                            seed=SEED)
-    _ = env.reset()
-    d = False
-    for i in range(max_planning_times):
-        _, _, d, _ = env.step(env.action_space.sample())
-    assert d
@@ -1,82 +0,0 @@
-from itertools import chain
-from typing import Tuple, Type, Union, Optional, Callable
-
-import gym
-import numpy as np
-import pytest
-from gym import register
-from gym.core import ActType, ObsType
-
-import fancy_gym
-from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
-
-import fancy_gym
-from test.utils import run_env, run_env_determinism
-
-Fancy_ProDMP_IDS = fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
-
-All_ProDMP_IDS = fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS['ProDMP']
-
-
-class Object(object):
-    pass
-
-
-class ToyEnv(gym.Env):
-    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
-    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
-    dt = 0.02
-
-    def __init__(self, a: int = 0, b: float = 0.0, c: list = [], d: dict = {}, e: Object = Object()):
-        self.a, self.b, self.c, self.d, self.e = a, b, c, d, e
-
-    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
-              options: Optional[dict] = None) -> Union[ObsType, Tuple[ObsType, dict]]:
-        return np.array([-1])
-
-    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
-        return np.array([-1]), 1, False, {}
-
-    def render(self, mode="human"):
-        pass
-
-
-class ToyWrapper(RawInterfaceWrapper):
-
-    @property
-    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
-        return np.ones(self.action_space.shape)
-
-    @property
-    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        return np.zeros(self.action_space.shape)
-
-
-@pytest.fixture(scope="session", autouse=True)
-def setup():
-    register(
-        id=f'toy-v0',
-        entry_point='test.test_black_box:ToyEnv',
-        max_episode_steps=50,
-    )
-
-
-# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
-# def test_replanning_envs(env_id: str):
-#     """Tests that ProDMP environments run without errors using random actions."""
-#     run_env(env_id)
-#
-# @pytest.mark.parametrize('env_id', All_ProDMP_IDS)
-# def test_replanning_determinism(env_id: str):
-#     """Tests that ProDMP environments are deterministic."""
-#     run_env_determinism(env_id, 0)
-
-
-@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
-def test_missing_local_state(mp_type: str):
-    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
-
-    env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], {},
-                            {'trajectory_generator_type': mp_type},
-                            {'controller_type': 'motor'},
-                            {'phase_generator_type': 'exp'},
-                            {'basis_generator_type': basis_generator_type})
-    env.reset()
-    with pytest.raises(NotImplementedError):
-        env.step(env.action_space.sample())
@@ -305,3 +305,29 @@ def test_replanning_with_learn_delay_and_tau(mp_type: str, max_planning_times: i
            planning_times += 1

    assert planning_times == max_planning_times

+
+@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
+@pytest.mark.parametrize('max_planning_times', [1, 2, 3, 4])
+@pytest.mark.parametrize('sub_segment_steps', [5, 10])
+def test_replanning_schedule(mp_type: str, max_planning_times: int, sub_segment_steps: int):
+    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
+    phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
+    env = fancy_gym.make_bb('toy-v0', [ToyWrapper],
+                            {'max_planning_times': max_planning_times,
+                             'replanning_schedule': lambda pos, vel, obs, action, t: t % sub_segment_steps == 0,
+                             'verbose': 2},
+                            {'trajectory_generator_type': mp_type,
+                             },
+                            {'controller_type': 'motor'},
+                            {'phase_generator_type': phase_generator_type,
+                             'learn_tau': False,
+                             'learn_delay': False
+                             },
+                            {'basis_generator_type': basis_generator_type,
+                             },
+                            seed=SEED)
+    _ = env.reset()
+    d = False
+    for i in range(max_planning_times):
+        _, _, d, _ = env.step(env.action_space.sample())
+    assert d