todos

parent 6e06e11cfa
commit 9b48fc9d48
@@ -24,13 +24,17 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}

DEFAULT_MP_ENV_DICT = {
    "name": 'EnvName',
    "wrappers": [],
    # TODO move scale to traj_gen
    "ep_wrapper_kwargs": {
        "weight_scale": 1
    },
    # TODO traj_gen_kwargs
    # TODO remove action_dim
    "movement_primitives_kwargs": {
        'movement_primitives_type': 'promp',
        'action_dim': 7
    },
    # TODO remove tau
    "phase_generator_kwargs": {
        'phase_generator_type': 'linear',
        'delay': 0,
@@ -40,13 +44,13 @@ DEFAULT_MP_ENV_DICT = {
    },
    "controller_kwargs": {
        'controller_type': 'motor',
        "p_gains": np.ones(7),
        "d_gains": np.ones(7) * 0.1,
        "p_gains": 1.0,
        "d_gains": 0.1,
    },
    "basis_generator_kwargs": {
        'basis_generator_type': 'zero_rbf',
        'num_basis': 5,
        'num_basis_zero_start': 2
        'num_basis_zero_start': 2  # TODO: Change to 1
    }
}
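For orientation, a minimal sketch of how a concrete environment entry might specialize this default dict; the make_env_config helper and the 'ALRReacher-v0' example name are illustrative assumptions, not part of this commit:

import copy

def make_env_config(name, **overrides):
    # start from a copy of the defaults and apply per-environment overrides (illustrative only)
    cfg = copy.deepcopy(DEFAULT_MP_ENV_DICT)
    cfg["name"] = name
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(cfg.get(key), dict):
            cfg[key].update(value)
        else:
            cfg[key] = value
    return cfg

# e.g. a ProMP variant with more basis functions
reacher_cfg = make_env_config("ALRReacher-v0", basis_generator_kwargs={'num_basis': 8})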
@@ -2,8 +2,6 @@ from typing import Tuple, Union

import numpy as np

from mp_env_api import MPEnvWrapper
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
@@ -32,18 +32,24 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):

        self.env = env
        self.duration = duration
        self.traj_steps = int(duration / self.dt)
        self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
        self.sequencing = sequencing
        # self.traj_steps = int(duration / self.dt)
        # self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
        # duration = self.env.max_episode_steps * self.dt

        # trajectory generation
        self.trajectory_generator = trajectory_generator
        self.tracking_controller = tracking_controller
        # self.weight_scale = weight_scale
        self.time_steps = np.linspace(0, self.duration, self.traj_steps)
        self.trajectory_generator.set_mp_times(self.time_steps)
        # self.trajectory_generator.set_mp_duration(self.time_steps, dt)
        # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
        # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
        # self.trajectory_generator.set_mp_times(self.time_steps)
        if not sequencing:
            self.trajectory_generator.set_mp_duration(np.array([self.duration]), np.array([self.dt]))
        else:
            # sequencing stuff
            pass

        # reward computation
        self.reward_aggregation = reward_aggregation

        # spaces
@@ -67,15 +73,12 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):

        return observation[self.env.context_mask]

    def get_trajectory(self, action: np.ndarray) -> Tuple:
        # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay
        #  at the beginning of the array.
        # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
        # scaled_mp_params = action.copy()
        # scaled_mp_params[ignore_indices:] *= self.weight_scale

        clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
        self.trajectory_generator.set_params(clipped_params)
        self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
        # if self.trajectory_generator.learn_tau:
        #     self.trajectory_generator.set_mp_duration(self.trajectory_generator.tau, np.array([self.dt]))
        self.trajectory_generator.set_mp_duration(None if self.sequencing else self.duration, np.array([self.dt]))
        self.trajectory_generator.set_boundary_conditions(bc_time=np.zeros(1), bc_pos=self.current_pos,  # NOTE: bc_time value of np.zeros(1) is an assumed placeholder
                                                           bc_vel=self.current_vel)
        traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
        trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
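For context, a rough sketch of how the generated position/velocity trajectories are typically consumed downstream; the rollout helper and the get_action signature of the tracking controller are assumptions, not taken from this diff (np is already imported in this module):

def run_trajectory(self, trajectory, velocity):
    # step the wrapped env along the desired trajectory (illustrative only)
    rewards = []
    for des_pos, des_vel in zip(trajectory, velocity):
        # assumed controller interface: desired pos/vel plus current pos/vel -> low-level action
        c_action = self.tracking_controller.get_action(des_pos, des_vel, self.current_pos, self.current_vel)
        obs, reward, done, info = self.env.step(c_action)
        rewards.append(reward)
        if done:
            break
    return obs, self.reward_aggregation(np.array(rewards)), done, info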
@@ -152,7 +155,7 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):

            if self.render_mode is not None:
                self.render(mode=self.render_mode, **self.render_kwargs)

            if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
            if done or self.env.do_replanning(self.current_pos, self.current_vel, obs, c_action, t + past_steps):
                break

        infos.update({k: v[:t + 1] for k, v in infos.items()})
@@ -43,13 +43,13 @@ class RawInterfaceWrapper(gym.Wrapper):

        raise NotImplementedError()

    @property
    @abstractmethod
    def dt(self) -> float:
        """
        Control frequency of the environment
        Returns: float

        """
        return self.env.dt

    def do_replanning(self, pos, vel, s, a, t):
        # return t % 100 == 0
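Spelled out, the commented hint above corresponds to a simple time-based replanning rule; this is only a sketch of what the method body might look like, with the 100-step interval taken directly from the comment:

def do_replanning(self, pos, vel, s, a, t):
    # replan every 100 environment steps; pos, vel, s (observation) and a (action)
    # are available for state-dependent replanning conditions as well
    return t % 100 == 0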