Merge branch 'master' into tt_cluster_debug
# Conflicts:
#	fancy_gym/black_box/black_box_wrapper.py
This commit is contained in:
commit
3a502ce831
@@ -22,7 +22,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
replanning_schedule: Optional[
|
replanning_schedule: Optional[
|
||||||
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
|
||||||
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
reward_aggregation: Callable[[np.ndarray], float] = np.sum,
|
||||||
max_planning_times: int = None,
|
max_planning_times: int = np.inf,
|
||||||
condition_on_desired: bool = False
|
condition_on_desired: bool = False
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -86,7 +86,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
|
|
||||||
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
def get_trajectory(self, action: np.ndarray) -> Tuple:
|
||||||
duration = self.duration
|
duration = self.duration
|
||||||
# duration = self.duration - self.current_traj_steps * self.dt
|
|
||||||
if self.learn_sub_trajectories:
|
if self.learn_sub_trajectories:
|
||||||
duration = None
|
duration = None
|
||||||
# reset with every new call as we need to set all arguments, such as tau, delay, again.
|
# reset with every new call as we need to set all arguments, such as tau, delay, again.
|
||||||
@@ -187,11 +186,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
if self.render_kwargs:
|
if self.render_kwargs:
|
||||||
self.env.render(**self.render_kwargs)
|
self.env.render(**self.render_kwargs)
|
||||||
|
|
||||||
if done or self.replanning_schedule(current_pos, current_vel, obs, c_action,
|
if done or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
|
||||||
t + 1 + self.current_traj_steps):
|
t + 1 + self.current_traj_steps)
|
||||||
|
and self.plan_steps < self.max_planning_times):
|
||||||
if not done and self.max_planning_times is not None and self.plan_steps >= self.max_planning_times:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.condition_pos = pos if self.condition_on_desired else None
|
self.condition_pos = pos if self.condition_on_desired else None
|
||||||
self.condition_vel = vel if self.condition_on_desired else None
|
self.condition_vel = vel if self.condition_on_desired else None
|
||||||
@@ -221,4 +218,6 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
|||||||
self.current_traj_steps = 0
|
self.current_traj_steps = 0
|
||||||
self.plan_steps = 0
|
self.plan_steps = 0
|
||||||
self.traj_gen.reset()
|
self.traj_gen.reset()
|
||||||
|
self.condition_vel = None
|
||||||
|
self.condition_pos = None
|
||||||
return super(BlackBoxWrapper, self).reset()
|
return super(BlackBoxWrapper, self).reset()
|
||||||
|
Loading…
Reference in New Issue
Block a user