diff --git a/fancy_gym/black_box/black_box_wrapper.py b/fancy_gym/black_box/black_box_wrapper.py index 0cd4ae6..45a92f5 100644 --- a/fancy_gym/black_box/black_box_wrapper.py +++ b/fancy_gym/black_box/black_box_wrapper.py @@ -2,6 +2,7 @@ from typing import Tuple, Optional, Callable import gym import numpy as np +import torch from gym import spaces from mp_pytorch.mp.mp_interfaces import MPInterface @@ -74,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper): self.verbose = verbose # condition value - self.desired_conditioning = False + self.desired_conditioning = True self.condition_pos = None self.condition_vel = None @@ -105,6 +106,10 @@ class BlackBoxWrapper(gym.ObservationWrapper): if self.current_traj_steps == 0: self.condition_pos = self.current_pos self.condition_vel = self.current_vel + + bc_time = torch.as_tensor(bc_time, dtype=torch.float32) + self.condition_pos = torch.as_tensor(self.condition_pos, dtype=torch.float32) + self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32) self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel) self.traj_gen.set_duration(duration, self.dt) # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True) diff --git a/fancy_gym/envs/__init__.py b/fancy_gym/envs/__init__.py index 1ad261f..f8c4597 100644 --- a/fancy_gym/envs/__init__.py +++ b/fancy_gym/envs/__init__.py @@ -498,10 +498,10 @@ for _v in _versions: kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.]) kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02]) kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01 - # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.]) - # kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1. kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5 kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10. + kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try + kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4 kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0 register(