diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py
index 3b3ac55..3bbb543 100644
--- a/fancy_gym/black_box/black_box_wrapper.py
+++ b/fancy_gym/black_box/black_box_wrapper.py
@@ -72,8 +72,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
         # self.action_space = spaces.Box(np.array(tricky_action_lowerbound), np.array(tricky_action_upperbound), dtype=np.float32)
         self.action_space.low[0] = 0.5
         self.action_space.high[0] = 1.5
-        self.action_space.low[1] = 0.05
-        self.action_space.high[1] = 0.2
+        self.action_space.low[1] = 0.02
+        self.action_space.high[1] = 0.15
         self.observation_space = self._get_observation_space()
 
         # rendering
diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py
index cb074e1..ffa478f 100644
--- a/fancy_gym/envs/__init__.py
+++ b/fancy_gym/envs/__init__.py
@@ -546,6 +546,7 @@ for _v in _versions:
     kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
     kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_start'] = 2
+    kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis_zero_goal'] = 1
     kwargs_dict_tt_promp['black_box_kwargs']['duration'] = 2.
     kwargs_dict_tt_promp['black_box_kwargs']['verbose'] = 2
     register(
diff --git a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
index ac40f17..3c9da7f 100644
--- a/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
+++ b/fancy_gym/envs/mujoco/table_tennis/mp_wrapper.py
@@ -37,8 +37,8 @@ class MPWrapper(RawInterfaceWrapper):
 
     def invalid_traj_callback(self, action, pos_traj: np.ndarray, vel_traj: np.ndarray) \
             -> Tuple[np.ndarray, float, bool, dict]:
-        tau_invalid_penalty = 0.3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
-        delay_invalid_penalty = 0.3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
+        tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
+        delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
         violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
         violate_low_bound_error = np.mean(np.maximum(jnt_pos_low - pos_traj, 0))
         invalid_penalty = tau_invalid_penalty + delay_invalid_penalty + \
diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py
index 93f1c29..eed0926 100644
--- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py
+++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_env.py
@@ -128,7 +128,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
     def reset_model(self):
         self._steps = 0
         self._init_ball_state = self._generate_valid_init_ball(random_pos=True, random_vel=False)
-        self._goal_pos = self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
+        self._goal_pos = self._generate_goal_pos(random=True)
         self.data.joint("tar_x").qpos = self._init_ball_state[0]
         self.data.joint("tar_y").qpos = self._init_ball_state[1]
         self.data.joint("tar_z").qpos = self._init_ball_state[2]
@@ -152,6 +152,11 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         self._racket_traj = []
         return self._get_obs()
 
+    def _generate_goal_pos(self, random=True):
+        if random:
+            return self.np_random.uniform(low=self.context_bounds[0][-2:], high=self.context_bounds[1][-2:])
+        else:
+            return np.array([-0.6, 0.4])
     def _get_obs(self):
         obs = np.concatenate([
@@ -191,7 +196,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
         x_pos = self.np_random.uniform(low=self.context_bounds[0][0], high=self.context_bounds[1][0])
         y_pos = self.np_random.uniform(low=self.context_bounds[0][1], high=self.context_bounds[1][1])
         if random_vel:
-            x_vel = self.np_random.uniform(low=2.0, high=3.0, size=1)
+            x_vel = self.np_random.uniform(low=2.0, high=3.0)
         init_ball_state = np.array([x_pos, y_pos, z_pos, x_vel, y_vel, z_vel])
         return init_ball_state
 
@@ -201,12 +206,6 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
             init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
         return init_ball_state
 
-    def check_traj_validity(self, traj):
-        raise NotImplementedError
-
-    def get_invalid_steps(self, traj):
-        penalty = -100
-        return self._get_obs(), penalty, True, {}
 
 if __name__ == "__main__":
     env = TableTennisEnv()
diff --git a/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py b/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py
index b14c160..66f68d2 100644
--- a/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py
+++ b/fancy_gym/envs/mujoco/table_tennis/table_tennis_utils.py
@@ -2,7 +2,7 @@ import numpy as np
 
 jnt_pos_low = np.array([-2.6, -2.0, -2.8, -0.9, -4.8, -1.6, -2.2])
 jnt_pos_high = np.array([2.6, 2.0, 2.8, 3.1, 1.3, 1.6, 2.2])
-delay_bound = [0.05, 0.2]
+delay_bound = [0.05, 0.15]
 tau_bound = [0.5, 1.5]
 
 net_height = 0.1
diff --git a/fancy_gym/examples/plotting.py b/fancy_gym/examples/plotting.py
new file mode 100644
index 0000000..cb5f866
--- /dev/null
+++ b/fancy_gym/examples/plotting.py
@@ -0,0 +1,31 @@
+import fancy_gym
+import numpy as np
+import matplotlib.pyplot as plt
+
+# This is the code that I am using to plot the data
+
+
+def plot_trajs(desired_traj, actual_traj, dim):
+    fig, ax = plt.subplots()
+    ax.plot(desired_traj[:, dim], label='desired')
+    ax.plot(actual_traj[:, dim], label='actual')
+    ax.legend()
+    plt.show()
+
+
+def compare_desired_and_actual(env_id: str = "TableTennis4DProMP-v0"):
+    env = fancy_gym.make(env_id, seed=0)
+    env.traj_gen.basis_gn.show_basis(plot=True)
+    env.reset()
+    for _ in range(1):
+        env.render(mode=None)
+        action = env.action_space.sample()
+        obs, reward, done, info = env.step(action)
+        for i in range(1):
+            plot_trajs(info['desired_pos_traj'], info['pos_traj'], i)
+            # plot_trajs(info['desired_vel_traj'], info['vel_traj'], i)
+        if done:
+            env.reset()
+
+if __name__ == "__main__":
+    compare_desired_and_actual(env_id='TableTennis4DProMP-v0')
\ No newline at end of file