fix tau bound and init bound bug
This commit is contained in:
parent
053a17889f
commit
bd7e811a64
@ -55,6 +55,14 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
# self.traj_gen.set_mp_times(self.time_steps)
|
||||
self.traj_gen.set_duration(self.duration, self.dt)
|
||||
|
||||
# check
|
||||
self.tau_bound = [-np.inf, np.inf]
|
||||
self.delay_bound = [-np.inf, np.inf]
|
||||
if hasattr(self.traj_gen.phase_gn, "tau_bound"):
|
||||
self.tau_bound = self.traj_gen.phase_gn.tau_bound
|
||||
if hasattr(self.traj_gen.phase_gn, "delay_bound"):
|
||||
self.delay_bound = self.traj_gen.phase_gn.delay_bound
|
||||
|
||||
# reward computation
|
||||
self.reward_aggregation = reward_aggregation
|
||||
|
||||
@ -139,7 +147,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
position, velocity = self.get_trajectory(action)
|
||||
position, velocity = self.env.set_episode_arguments(action, position, velocity)
|
||||
traj_is_valid, position, velocity = self.env.preprocessing_and_validity_callback(action, position, velocity)
|
||||
traj_is_valid, position, velocity = self.env.preprocessing_and_validity_callback(action, position, velocity,
|
||||
self.tau_bound, self.delay_bound)
|
||||
|
||||
trajectory_length = len(position)
|
||||
rewards = np.zeros(shape=(trajectory_length,))
|
||||
@ -153,7 +162,8 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
|
||||
if not traj_is_valid:
|
||||
obs, trajectory_return, done, infos = self.env.invalid_traj_callback(action, position, velocity,
|
||||
self.return_context_observation)
|
||||
self.return_context_observation,
|
||||
self.tau_bound, self.delay_bound)
|
||||
return self.observation(obs), trajectory_return, done, infos
|
||||
|
||||
self.plan_steps += 1
|
||||
|
@ -52,7 +52,8 @@ class RawInterfaceWrapper(gym.Wrapper):
|
||||
"""
|
||||
return self.env.dt
|
||||
|
||||
def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) \
|
||||
def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||
tau_bound: list = None, delay_bound: list = None ) \
|
||||
-> Tuple[bool, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Used to preprocess the action and check if the desired trajectory is valid.
|
||||
@ -61,6 +62,8 @@ class RawInterfaceWrapper(gym.Wrapper):
|
||||
specified, else only traj_gen parameters
|
||||
pos_traj: a vector instance of the raw position trajectory
|
||||
vel_traj: a vector instance of the raw velocity trajectory
|
||||
tau_bound: a list of two elements, the lower and upper bound of the trajectory length scaling factor
|
||||
delay_bound: a list of two elements, the lower and upper bound of the time to wait before execute
|
||||
Returns:
|
||||
validity flag: bool, True if the raw trajectory is valid, False if not
|
||||
pos_traj: a vector instance of the preprocessed position trajectory
|
||||
@ -97,7 +100,8 @@ class RawInterfaceWrapper(gym.Wrapper):
|
||||
"""
|
||||
return True
|
||||
|
||||
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray) -> Tuple[np.ndarray, float, bool, dict]:
|
||||
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||
tau_bound: list, delay_bound: list) -> Tuple[np.ndarray, float, bool, dict]:
|
||||
"""
|
||||
Used to return a artificial return from the env if the desired trajectory is invalid.
|
||||
Args:
|
||||
@ -105,6 +109,8 @@ class RawInterfaceWrapper(gym.Wrapper):
|
||||
specified, else only traj_gen parameters
|
||||
pos_traj: a vector instance of the raw position trajectory
|
||||
vel_traj: a vector instance of the raw velocity trajectory
|
||||
tau_bound: a list of two elements, the lower and upper bound of the trajectory length scaling factor
|
||||
delay_bound: a list of two elements, the lower and upper bound of the time to wait before execute
|
||||
Returns:
|
||||
obs: artificial observation if the trajectory is invalid, by default a zero vector
|
||||
reward: artificial reward if the trajectory is invalid, by default 0
|
||||
|
@ -557,8 +557,8 @@ for _v in _versions:
|
||||
kwargs_dict_tt_promp['name'] = _v
|
||||
kwargs_dict_tt_promp['controller_kwargs']['p_gains'] = 0.5 * np.array([1.0, 4.0, 2.0, 4.0, 1.0, 4.0, 1.0])
|
||||
kwargs_dict_tt_promp['controller_kwargs']['d_gains'] = 0.5 * np.array([0.1, 0.4, 0.2, 0.4, 0.1, 0.4, 0.1])
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = False
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = False
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_tau'] = True
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['learn_delay'] = True
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['tau_bound'] = [0.8, 1.5]
|
||||
kwargs_dict_tt_promp['phase_generator_kwargs']['delay_bound'] = [0.05, 0.15]
|
||||
kwargs_dict_tt_promp['basis_generator_kwargs']['num_basis'] = 3
|
||||
|
@ -29,15 +29,16 @@ class TT_MPWrapper(RawInterfaceWrapper):
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
return self.data.qvel[:7].copy()
|
||||
|
||||
def preprocessing_and_validity_callback(self, action, pos_traj, vel_traj):
|
||||
return self.check_traj_validity(action, pos_traj, vel_traj)
|
||||
def preprocessing_and_validity_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||
tau_bound: list, delay_bound:list):
|
||||
return self.check_traj_validity(action, pos_traj, vel_traj, tau_bound, delay_bound)
|
||||
|
||||
def set_episode_arguments(self, action, pos_traj, vel_traj):
|
||||
return pos_traj, vel_traj
|
||||
|
||||
def invalid_traj_callback(self, action: np.ndarray, pos_traj: np.ndarray, vel_traj: np.ndarray,
|
||||
return_contextual_obs: bool) -> Tuple[np.ndarray, float, bool, dict]:
|
||||
return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs)
|
||||
return_contextual_obs: bool, tau_bound:list, delay_bound:list) -> Tuple[np.ndarray, float, bool, dict]:
|
||||
return self.get_invalid_traj_step_return(action, pos_traj, return_contextual_obs, tau_bound, delay_bound)
|
||||
|
||||
class TTVelObs_MPWrapper(TT_MPWrapper):
|
||||
|
||||
|
@ -5,7 +5,7 @@ from gym import utils, spaces
|
||||
from gym.envs.mujoco import MujocoEnv
|
||||
|
||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import is_init_state_valid, magnus_force
|
||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high, delay_bound, tau_bound
|
||||
from fancy_gym.envs.mujoco.table_tennis.table_tennis_utils import jnt_pos_low, jnt_pos_high
|
||||
|
||||
import mujoco
|
||||
|
||||
@ -225,7 +225,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
init_ball_state = self._generate_random_ball(random_pos=random_pos, random_vel=random_vel)
|
||||
return init_ball_state
|
||||
|
||||
def _get_traj_invalid_penalty(self, action, pos_traj):
|
||||
def _get_traj_invalid_penalty(self, action, pos_traj, tau_bound, delay_bound):
|
||||
tau_invalid_penalty = 3 * (np.max([0, action[0] - tau_bound[1]]) + np.max([0, tau_bound[0] - action[0]]))
|
||||
delay_invalid_penalty = 3 * (np.max([0, action[1] - delay_bound[1]]) + np.max([0, delay_bound[0] - action[1]]))
|
||||
violate_high_bound_error = np.mean(np.maximum(pos_traj - jnt_pos_high, 0))
|
||||
@ -234,9 +234,9 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
violate_high_bound_error + violate_low_bound_error
|
||||
return -invalid_penalty
|
||||
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs):
|
||||
def get_invalid_traj_step_return(self, action, pos_traj, contextual_obs, tau_bound, delay_bound):
|
||||
obs = self._get_obs() if contextual_obs else np.concatenate([self._get_obs(), np.array([0])]) # 0 for invalid traj
|
||||
penalty = self._get_traj_invalid_penalty(action, pos_traj)
|
||||
penalty = self._get_traj_invalid_penalty(action, pos_traj, tau_bound, delay_bound)
|
||||
return obs, penalty, True, {
|
||||
"hit_ball": [False],
|
||||
"ball_returned_success": [False],
|
||||
@ -247,7 +247,7 @@ class TableTennisEnv(MujocoEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def check_traj_validity(action, pos_traj, vel_traj):
|
||||
def check_traj_validity(action, pos_traj, vel_traj, tau_bound, delay_bound):
|
||||
time_invalid = action[0] > tau_bound[1] or action[0] < tau_bound[0] \
|
||||
or action[1] > delay_bound[1] or action[1] < delay_bound[0]
|
||||
if time_invalid or np.any(pos_traj > jnt_pos_high) or np.any(pos_traj < jnt_pos_low):
|
||||
|
Loading…
Reference in New Issue
Block a user