Merge branch 'master' into tt_cluster_debug

# Conflicts:
#	fancy_gym/black_box/black_box_wrapper.py
This commit is contained in:
Hongyi Zhou 2022-12-01 11:35:38 +01:00
commit 3a502ce831

View File

@ -22,7 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
replanning_schedule: Optional[
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
max_planning_times: int = None,
max_planning_times: int = np.inf,
condition_on_desired: bool = False
):
"""
@ -86,7 +86,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
def get_trajectory(self, action: np.ndarray) -> Tuple:
duration = self.duration
# duration = self.duration - self.current_traj_steps * self.dt
if self.learn_sub_trajectories:
duration = None
# reset with every new call as we need to set all arguments, such as tau, delay, again.
@ -187,11 +186,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if self.render_kwargs:
self.env.render(**self.render_kwargs)
if done or self.replanning_schedule(current_pos, current_vel, obs, c_action,
t + 1 + self.current_traj_steps):
if not done and self.max_planning_times is not None and self.plan_steps >= self.max_planning_times:
continue
if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps)
and self.plan_steps < self.max_planning_times):
self.condition_pos = pos if self.condition_on_desired else None
self.condition_vel = vel if self.condition_on_desired else None
@ -221,4 +218,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.current_traj_steps = 0
self.plan_steps = 0
self.traj_gen.reset()
self.condition_vel = None
self.condition_pos = None
return super(BlackBoxWrapper, self).reset()