Update ProMP's config
This commit is contained in:
parent
7d16b420c1
commit
2a39a72af0
@ -72,8 +72,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
# self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
|
||||
self.action_space.low[0] = 0.5
|
||||
self.action_space.high[0] = 1.5
|
||||
self.action_space.low[1] = 0.05
|
||||
self.action_space.high[1] = 0.2
|
||||
self.action_space.low[1] = 0.02
|
||||
self.action_space.high[1] = 0.15
|
||||
self.observation_space = self._get_observation_space()
|
||||
|
||||
# rendering
|
||||
|
@ -546,6 +546,7 @@ for _v in _versions:
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
|
||||
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
|
||||
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
|
||||
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
|
||||
kwargs_dict_tt_promp['black_box_kwargs']['duration'] = 2.
|
||||
kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2
|
||||
register(
|
||||
|
@ -37,8 +37,8 @@ class MPWrapper(RawInterfaceWrapper):
|
||||
|
||||
def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
||||
-> Tuple[np.ndarray, float, bool, dict]:
|
||||
tau_invalid_penalty = 0.3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
delay_invalid_penalty = 0.3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||
violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
|
||||
invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
|
||||
|
@ -128,7 +128,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
def reset_model(self):
|
||||
self._steps = 0
|
||||
self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
|
||||
self._goal_pos = self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
|
||||
self._goal_pos = self._generate_goal_pos(random=True)
|
||||
self.data.joint("tar_x").qpos = self._init_ball_state[0]
|
||||
self.data.joint("tar_y").qpos = self._init_ball_state[1]
|
||||
self.data.joint("tar_z").qpos = self._init_ball_state[2]
|
||||
@ -152,6 +152,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
self._racket_traj = []
|
||||
return self._get_obs()
|
||||
|
||||
def _generate_goal_pos(self, random=True):
    """Return the goal position used for the next episode.

    Args:
        random: When truthy, draw the goal uniformly between the last two
            entries of the lower and upper rows of ``self.context_bounds``
            using ``self.np_random``. Otherwise return a fixed goal.

    Returns:
        A NumPy array holding the goal position; ``[-0.6, 0.4]`` in the
        non-random case.
    """
    if not random:
        # Deterministic fallback goal.
        return np.array([-0.6, 0.4])

    # The goal occupies the last two context dimensions.
    goal_low = self.context_bounds[0][-2:]
    goal_high = self.context_bounds[1][-2:]
    return self.np_random.uniform(low=goal_low, high=goal_high)
|
||||
|
||||
def _get_obs(self):
|
||||
obs = np.concatenate([
|
||||
@ -191,7 +196,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
x_pos = self.np_random.uniform(low=self.context_bounds[0][0], high=self.context_bounds[1][0])
|
||||
y_pos = self.np_random.uniform(low=self.context_bounds[0][1], high=self.context_bounds[1][1])
|
||||
if random_vel:
|
||||
x_vel = self.np_random.uniform(low=2.0, high=3.0, size=1)
|
||||
x_vel = self.np_random.uniform(low=2.0, high=3.0)
|
||||
init_ball_state = np.array([x_pos, y_pos, z_pos, x_vel, y_vel, z_vel])
|
||||
return init_ball_state
|
||||
|
||||
@ -201,12 +206,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||
return init_ball_state
|
||||
|
||||
def check_traj_validity(self, traj):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_invalid_steps(self, traj):
|
||||
penalty = -100
|
||||
return self._get_obs(), penalty, True, {}
|
||||
|
||||
if __name__ == "__main__":
|
||||
env = TableTennisEnv()
|
||||
|
@ -2,7 +2,7 @@ import numpy as np
|
||||
|
||||
jnt_pos_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
|
||||
jnt_pos_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
|
||||
delay_bound = [0.05, 0.2]
|
||||
delay_bound = [0.05, 0.15]
|
||||
tau_bound = [0.5, 1.5]
|
||||
|
||||
net_height = 0.1
|
||||
|
31
fancy_gym/examples/plotting.py
Normal file
31
fancy_gym/examples/plotting.py
Normal file
@ -0,0 +1,31 @@
|
||||
import fancy_gym
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Plotting helpers for comparing desired and actual trajectories.
|
||||
|
||||
|
||||
def plot_trajs(desired_traj, actual_traj, dim):
    """Plot one column of the desired and actual trajectories on the same axes.

    Args:
        desired_traj: Array-like indexed as ``desired_traj[:, dim]``.
        actual_traj: Array-like indexed as ``actual_traj[:, dim]``.
        dim: Column index selecting which dimension to plot.
    """
    _, axes = plt.subplots()
    # Draw both series in the same order and with the same legend labels.
    for series, tag in ((desired_traj, 'desired'), (actual_traj, 'actual')):
        axes.plot(series[:, dim], label=tag)
    axes.legend()
    plt.show()
|
||||
|
||||
|
||||
def compare_desired_and_actual(env_id: str = "TableTennis4DProMP-v0"):
    """Run one sampled action in an environment and plot desired vs. actual positions.

    The environment is created with ``fancy_gym.make(env_id, seed=0)``, its
    basis functions are shown, and a single randomly sampled action is
    stepped. The ``'desired_pos_traj'`` and ``'pos_traj'`` entries of the
    returned info dict are then plotted for dimension 0.

    Args:
        env_id: Identifier passed to ``fancy_gym.make``.
    """
    env = fancy_gym.make(env_id, seed=0)
    env.traj_gen.basis_gn.show_basis(plot=True)
    env.reset()

    num_rollouts = 1
    num_plot_dims = 1
    for _ in range(num_rollouts):
        env.render(mode=None)
        sampled_action = env.action_space.sample()
        _, _, done, step_info = env.step(sampled_action)

        for plot_dim in range(num_plot_dims):
            plot_trajs(step_info['desired_pos_traj'], step_info['pos_traj'], plot_dim)
            # Velocity variant, kept disabled as in the original:
            # plot_trajs(step_info['desired_vel_traj'], step_info['vel_traj'], plot_dim)

        if done:
            env.reset()
|
||||
|
||||
# Script entry point: run the comparison on the table-tennis ProMP task.
if __name__ == "__main__":
    compare_desired_and_actual('TableTennis4DProMP-v0')
|
Loading…
Reference in New Issue
Block a user