use auto scaling feature of MP_Pytorch
This commit is contained in:
parent
524bbf352e
commit
61c1b76e29
@ -2,6 +2,7 @@ from typing import Tuple, Optional, Callable
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from gym import spaces
|
||||
from mp_pytorch.mp.mp_interfaces import MPInterface
|
||||
|
||||
@ -74,7 +75,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
self.verbose = verbose
|
||||
|
||||
# condition value
|
||||
self.desired_conditioning = False
|
||||
self.desired_conditioning = True
|
||||
self.condition_pos = None
|
||||
self.condition_vel = None
|
||||
|
||||
@ -105,6 +106,10 @@ class BlackBoxWrapper(gym.ObservationWrapper):
|
||||
if self.current_traj_steps == 0:
|
||||
self.condition_pos = self.current_pos
|
||||
self.condition_vel = self.current_vel
|
||||
|
||||
bc_time = torch.as_tensor(bc_time, dtype=torch.float32)
|
||||
self.condition_pos = torch.as_tensor(self.condition_pos, dtype=torch.float32)
|
||||
self.condition_vel = torch.as_tensor(self.condition_vel, dtype=torch.float32)
|
||||
self.traj_gen.set_boundary_conditions(bc_time, self.condition_pos, self.condition_vel)
|
||||
self.traj_gen.set_duration(duration, self.dt)
|
||||
# traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
|
||||
|
@ -498,10 +498,10 @@ for _v in _versions:
|
||||
kwargs_dict_box_pushing_prodmp['controller_kwargs']['d_gains'] = 0.01 * np.array([10., 10., 10., 10., 6., 5., 3.])
|
||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = np.array([3.4944e+01, 4.3734e+01, 9.6711e+01, 2.4429e+02, 5.8272e+02])
|
||||
kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 3.1264e-01
|
||||
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['weights_scale'] = 0.3 * np.array([100., 166., 500., 1000.])
|
||||
# kwargs_dict_box_pushing_prodmp['trajectory_generator_kwargs']['goal_scale'] = 0.3 * 1.
|
||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['num_basis'] = 5
|
||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['alpha'] = 10.
|
||||
kwargs_dict_box_pushing_prodmp['basis_generator_kwargs']['basis_bandwidth_factor'] = 3 # 3.5, 4 to try
|
||||
kwargs_dict_box_pushing_prodmp['phase_generator_kwargs']['alpha_phase'] = 3
|
||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['max_replan_times'] = 4
|
||||
kwargs_dict_box_pushing_prodmp['black_box_kwargs']['replanning_schedule'] = lambda pos, vel, obs, action, t : t % 25 == 0
|
||||
register(
|
||||
|
Loading…
Reference in New Issue
Block a user