todos

commit 9b48fc9d48
parent 6e06e11cfa
@@ -24,13 +24,17 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 DEFAULT_MP_ENV_DICT = {
     "name": 'EnvName',
     "wrappers": [],
+    # TODO move scale to traj_gen
     "ep_wrapper_kwargs": {
         "weight_scale": 1
     },
+    # TODO traj_gen_kwargs
+    # TODO remove action_dim
     "movement_primitives_kwargs": {
         'movement_primitives_type': 'promp',
         'action_dim': 7
     },
+    # TODO remove tau
     "phase_generator_kwargs": {
         'phase_generator_type': 'linear',
         'delay': 0,
@@ -40,13 +44,13 @@ DEFAULT_MP_ENV_DICT = {
     },
     "controller_kwargs": {
         'controller_type': 'motor',
-        "p_gains": np.ones(7),
-        "d_gains": np.ones(7) * 0.1,
+        "p_gains": 1.0,
+        "d_gains": 0.1,
     },
     "basis_generator_kwargs": {
         'basis_generator_type': 'zero_rbf',
         'num_basis': 5,
-        'num_basis_zero_start': 2
+        'num_basis_zero_start': 2  # TODO: Change to 1
     }
 }
 
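For context, these defaults are typically specialized per environment by merging task-specific overrides into the shared dict. Below is a minimal sketch of that pattern; the `make_default_config` helper and the override values are illustrative assumptions, not part of this commit.

import copy

def make_default_config(name, **overrides):
    # Hypothetical helper: copy the shared defaults, then update the
    # matching sub-dicts with per-environment values.
    config = copy.deepcopy(DEFAULT_MP_ENV_DICT)
    config["name"] = name
    for key, value in overrides.items():
        if isinstance(value, dict):
            config.setdefault(key, {}).update(value)
        else:
            config[key] = value
    return config

# Scalar gains now suffice, per the p_gains/d_gains change above.
example_config = make_default_config(
    "EnvName-v0",
    controller_kwargs={"p_gains": 0.6, "d_gains": 0.075},
    basis_generator_kwargs={"num_basis": 8},
)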
@@ -2,8 +2,6 @@ from typing import Tuple, Union
 
 import numpy as np
 
-from mp_env_api import MPEnvWrapper
-
 from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper
 
 
@@ -32,18 +32,24 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
 
         self.env = env
         self.duration = duration
-        self.traj_steps = int(duration / self.dt)
-        self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
+        self.sequencing = sequencing
+        # self.traj_steps = int(duration / self.dt)
+        # self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
         # duration = self.env.max_episode_steps * self.dt
 
         # trajectory generation
         self.trajectory_generator = trajectory_generator
         self.tracking_controller = tracking_controller
         # self.weight_scale = weight_scale
-        self.time_steps = np.linspace(0, self.duration, self.traj_steps)
-        self.trajectory_generator.set_mp_times(self.time_steps)
-        # self.trajectory_generator.set_mp_duration(self.time_steps, dt)
-        # action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
+        # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
+        # self.trajectory_generator.set_mp_times(self.time_steps)
+        if not sequencing:
+            self.trajectory_generator.set_mp_duration(np.array([self.duration]), np.array([self.dt]))
+        else:
+            # sequencing stuff
+            pass
+
+        # reward computation
         self.reward_aggregation = reward_aggregation
 
         # spaces
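The constructor now hands `duration` and `dt` to the trajectory generator via `set_mp_duration` instead of precomputing an explicit time grid with `set_mp_times`. A small numeric illustration of how the two interfaces relate (the concrete values are assumed for illustration):

import numpy as np

duration, dt = 2.0, 0.02                             # 2 s rollout at 50 Hz
traj_steps = int(duration / dt)                      # 100 support points
time_steps = np.linspace(0, duration, traj_steps)    # old: grid built here

# New: the generator derives the grid internally from duration and dt, i.e.
# trajectory_generator.set_mp_duration(np.array([duration]), np.array([dt]))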
@@ -67,15 +73,12 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
         return observation[self.env.context_mask]
 
     def get_trajectory(self, action: np.ndarray) -> Tuple:
-        # TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
-        # the beginning of the array.
-        # ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
-        # scaled_mp_params = action.copy()
-        # scaled_mp_params[ignore_indices:] *= self.weight_scale
-
         clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
         self.trajectory_generator.set_params(clipped_params)
-        self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
+        # if self.trajectory_generator.learn_tau:
+        #     self.trajectory_generator.set_mp_duration(self.trajectory_generator.tau, np.array([self.dt]))
+        self.trajectory_generator.set_mp_duration(None if self.sequencing else self.duration, np.array([self.dt]))
+        self.trajectory_generator.set_boundary_conditions(bc_time=np.zeros(1), bc_pos=self.current_pos,
                                                           bc_vel=self.current_vel)
         traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
         trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
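In the reconstructed `set_boundary_conditions` call, the start time `np.zeros(1)` is an assumption that preserves the old `self.time_steps[:1]` value (a grid starting at 0); the original line was incomplete in this work-in-progress commit. Downstream, the generated position/velocity trajectories are typically tracked step by step by the low-level controller. A sketch of that consumption, where the `get_action` signature on `tracking_controller` is an assumption not shown in this diff:

def rollout_mp_action(env, action):
    # Sketch: execute one MP action by tracking the generated trajectory
    # with the low-level controller, assuming the BlackBoxWrapper
    # interface used in this diff.
    trajectory, velocity = env.get_trajectory(action)
    obs, rewards = None, []
    for des_pos, des_vel in zip(trajectory, velocity):
        c_action = env.tracking_controller.get_action(  # assumed method name
            des_pos, des_vel, env.current_pos, env.current_vel)
        obs, reward, done, info = env.env.step(c_action)
        rewards.append(reward)
        if done:
            break
    return obs, rewards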
@@ -152,7 +155,7 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
             if self.render_mode is not None:
                 self.render(mode=self.render_mode, **self.render_kwargs)
 
-            if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
+            if done or self.env.do_replanning(self.current_pos, self.current_vel, obs, c_action, t + past_steps):
                 break
 
         infos.update({k: v[:t + 1] for k, v in infos.items()})
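Passing `t + past_steps` makes the replanning check see the absolute episode step rather than the local step of the current sub-trajectory. A toy illustration of the bookkeeping this implies (variable names assumed):

# With replanning every 100 steps, three sub-trajectories of 40, 80 and
# 100 steps replan at global steps 100 and 200, not at every local t == 0.
past_steps = 0
for segment_len in (40, 80, 100):
    for t in range(segment_len):
        if (t + past_steps) % 100 == 0 and (t + past_steps) > 0:
            print("replan at global step", t + past_steps)
    past_steps += segment_len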
@@ -43,13 +43,13 @@ class RawInterfaceWrapper(gym.Wrapper):
         raise NotImplementedError()
 
     @property
-    @abstractmethod
     def dt(self) -> float:
         """
         Control frequency of the environment
         Returns: float
 
         """
+        return self.env.dt
 
     def do_replanning(self, pos, vel, s, a, t):
         # return t % 100 == 0
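For completeness, a minimal concrete subclass of `RawInterfaceWrapper` for a 7-DoF arm; the MuJoCo-style `sim.data` access and the 14-dimensional observation layout are assumptions for illustration:

import numpy as np

class SevenDofInterface(RawInterfaceWrapper):

    @property
    def context_mask(self) -> np.ndarray:
        # expose joint positions to the policy, hide joint velocities
        return np.hstack([np.ones(7, dtype=bool), np.zeros(7, dtype=bool)])

    @property
    def current_pos(self) -> np.ndarray:
        return self.env.sim.data.qpos[:7].copy()

    @property
    def current_vel(self) -> np.ndarray:
        return self.env.sim.data.qvel[:7].copy()

Since `dt` is no longer abstract and defaults to `self.env.dt`, only environments with a non-standard control frequency need to override it.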