diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 9619954..fd5032b 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -100,15 +100,10 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.traj_gen.set_initial_conditions(init_time, condition_pos, condition_vel) self.traj_gen.set_duration(duration, self.dt) - # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) + position = get_numpy(self.traj_gen.get_traj_pos()) velocity = get_numpy(self.traj_gen.get_traj_vel()) - # if self.do_replanning: - # # Remove first part of trajectory as this is already over - # position = position[self.current_traj_steps:] - # velocity = velocity[self.current_traj_steps:] - return position, velocity def _get_traj_gen_action_space(self): @@ -179,8 +174,9 @@ class BlackBoxWrapper(gym.ObservationWrapper): t + 1 + self.current_traj_steps) and self.plan_steps < self.max_planning_times): - self.condition_pos = pos if self.condition_on_desired else None - self.condition_vel = vel if self.condition_on_desired else None + if self.condition_on_desired: + self.condition_pos = pos + self.condition_vel = vel break @@ -207,6 +203,6 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.current_traj_steps = 0 self.plan_steps = 0 self.traj_gen.reset() - self.condition_vel = None self.condition_pos = None + self.condition_vel = None return super(BlackBoxWrapper, self).reset() diff --git a/fancy_gym/meta/__init__.py b/fancy_gym/meta/__init__.py index 9304c72..63b15c2 100644 --- a/fancy_gym/meta/__init__.py +++ b/fancy_gym/meta/__init__.py @@ -26,6 +26,9 @@ DEFAULT_BB_DICT_ProMP = { 'basis_generator_type': 'zero_rbf', 'num_basis': 5, 'num_basis_zero_start': 1 + }, + 'black_box_kwargs': { + 'condition_on_desired': False, } } @@ -50,7 +53,11 @@ DEFAULT_BB_DICT_ProDMP = { 'basis_generator_type': 'prodmp', 'num_basis': 3, 'alpha': 10 + }, + 'black_box_kwargs': { + 'condition_on_desired': False, } + } _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",