use auto scaling feature of MP_Pytorch

This commit is contained in:
Hongyi Zhou 2022-10-31 13:18:05 +01:00
parent 524bbf352e
commit 61c1b76e29
2 changed files with 8 additions and 3 deletions

View File

@ -2,6 +2,7 @@ from typing import Tuple, Optional, Callable
import gym
import numpy as np
import torch
from gym import spaces
from mp_pytorch.mp.mp_interfaces import MPInterface
@ -74,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
self.verbose = verbose
# condition value
self.desired_conditioning = False
self.desired_conditioning = True
self.condition_pos = None
self.condition_vel = None
@ -105,6 +106,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
if self.current_traj_steps == 0:
self.condition_pos = self.current_pos
self.condition_vel = self.current_vel
bc_time = torch.as_tensor(bc_time, dtype=torch.float32)
self.condition_pos = torch.as_tensor(self.condition_pos, dtype=torch.float32)
self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
self.traj_gen.set_duration(duration, self.dt)
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)

View File

@ -498,10 +498,10 @@ for _v in _versions:
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
register(