Fabian 2022-06-29 12:25:40 +02:00
parent 6e06e11cfa
commit 9b48fc9d48
4 changed files with 25 additions and 20 deletions

View File

@@ -24,13 +24,17 @@ ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
DEFAULT_MP_ENV_DICT = {
"name": 'EnvName',
"wrappers": [],
# TODO move scale to traj_gen
"ep_wrapper_kwargs": {
"weight_scale": 1
},
# TODO traj_gen_kwargs
# TODO remove action_dim
"movement_primitives_kwargs": {
'movement_primitives_type': 'promp',
'action_dim': 7
},
# TODO remove tau
"phase_generator_kwargs": {
'phase_generator_type': 'linear',
'delay': 0,
@@ -40,13 +44,13 @@ DEFAULT_MP_ENV_DICT = {
},
"controller_kwargs": {
'controller_type': 'motor',
"p_gains": np.ones(7),
"d_gains": np.ones(7) * 0.1,
"p_gains": 1.0,
"d_gains": 0.1,
},
"basis_generator_kwargs": {
'basis_generator_type': 'zero_rbf',
'num_basis': 5,
'num_basis_zero_start': 2
'num_basis_zero_start': 2 # TODO: Change to 1
}
}
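# Illustrative usage sketch (not part of this commit): DEFAULT_MP_ENV_DICT is meant to be
# deep-copied and selectively overridden per environment. The helper and the example
# values below are hypothetical and only show the intended merge pattern.
from copy import deepcopy

def make_mp_config(name: str, **overrides) -> dict:
    config = deepcopy(DEFAULT_MP_ENV_DICT)
    config["name"] = name
    for key, value in overrides.items():
        # nested *_kwargs dicts are merged, all other entries are replaced
        if isinstance(value, dict) and isinstance(config.get(key), dict):
            config[key].update(value)
        else:
            config[key] = value
    return config

# e.g. stiffer PD gains for a hypothetical 7-DoF reacher task
reacher_config = make_mp_config("Reacher7d-ProMP-v0",
                                controller_kwargs={"p_gains": 10.0, "d_gains": 1.0})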

View File

@@ -2,8 +2,6 @@ from typing import Tuple, Union
import numpy as np
from mp_env_api import MPEnvWrapper
from alr_envs.mp.raw_interface_wrapper import RawInterfaceWrapper

View File

@@ -32,18 +32,24 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
self.env = env
self.duration = duration
self.traj_steps = int(duration / self.dt)
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
self.sequencing = sequencing
# self.traj_steps = int(duration / self.dt)
# self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
# duration = self.env.max_episode_steps * self.dt
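# Worked example (not part of this commit), relating episode length, control dt and the
# MP duration as in the commented-out lines above; the concrete numbers are assumptions.
max_episode_steps, dt, duration = 200, 0.02, 2.0   # e.g. a 4 s episode at 50 Hz with a 2 s trajectory
traj_steps = int(duration / dt)                    # 100 control steps covered by the generated trajectory
post_traj_steps = max_episode_steps - traj_steps   # 100 remaining steps after the trajectory has been executed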
# trajectory generation
self.trajectory_generator = trajectory_generator
self.tracking_controller = tracking_controller
# self.weight_scale = weight_scale
self.time_steps = np.linspace(0, self.duration, self.traj_steps)
self.trajectory_generator.set_mp_times(self.time_steps)
# self.trajectory_generator.set_mp_duration(self.time_steps, dt)
# action_bounds = np.inf * np.ones((np.prod(self.trajectory_generator.num_params)))
# self.time_steps = np.linspace(0, self.duration, self.traj_steps)
# self.trajectory_generator.set_mp_times(self.time_steps)
if not sequencing:
self.trajectory_generator.set_mp_duration(np.array([self.duration]), np.array([self.dt]))
else:
# sequencing stuff
pass
# reward computation
self.reward_aggregation = reward_aggregation
# spaces
@@ -67,15 +73,12 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
return observation[self.env.context_mask]
def get_trajectory(self, action: np.ndarray) -> Tuple:
# TODO: this follows the implementation of the mp_pytorch library which includes the parameters tau and delay at
# the beginning of the array.
# ignore_indices = int(self.trajectory_generator.learn_tau) + int(self.trajectory_generator.learn_delay)
# scaled_mp_params = action.copy()
# scaled_mp_params[ignore_indices:] *= self.weight_scale
clipped_params = np.clip(action, self.mp_action_space.low, self.mp_action_space.high)
self.trajectory_generator.set_params(clipped_params)
self.trajectory_generator.set_boundary_conditions(bc_time=self.time_steps[:1], bc_pos=self.current_pos,
# if self.trajectory_generator.learn_tau:
# self.trajectory_generator.set_mp_duration(self.trajectory_generator.tau, np.array([self.dt]))
self.trajectory_generator.set_mp_duration(None if self.sequencing else self.duration, np.array([self.dt]))
# assumption: boundary conditions are anchored at t=0, matching the previous self.time_steps[:1]
self.trajectory_generator.set_boundary_conditions(bc_time=np.zeros(1), bc_pos=self.current_pos,
bc_vel=self.current_vel)
traj_dict = self.trajectory_generator.get_mp_trajs(get_pos=True, get_vel=True)
trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']
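# Illustrative sketch (not part of this commit): size of the black-box action consumed by
# get_trajectory() for the defaults above (promp, num_basis=5, action_dim=7), under the
# assumption that each basis function carries one learnable weight per DoF. mp_pytorch
# prepends tau and delay to the parameter vector when they are learned.
num_basis, action_dim = 5, 7
learn_tau, learn_delay = False, False
num_params = num_basis * action_dim + int(learn_tau) + int(learn_delay)   # 35
params = np.random.uniform(-1.0, 1.0, num_params)                         # one parameter vector per episode
# desired_pos, desired_vel = wrapper.get_trajectory(params)   # hypothetical call; shapes ~ (traj_steps, action_dim)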
@@ -152,7 +155,7 @@ class BlackBoxWrapper(gym.ObservationWrapper, ABC):
if self.render_mode is not None:
self.render(mode=self.render_mode, **self.render_kwargs)
if done or self.env.do_replanning(self.env.current_pos, self.env.current_vel, obs, c_action, t):
if done or self.env.do_replanning(self.current_pos, self.current_vel, obs, c_action, t + past_steps):
break
infos.update({k: v[:t + 1] for k, v in infos.items()})
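# Illustrative sketch (not part of this commit): a minimal time-based do_replanning
# predicate matching the call above and the commented-out example in RawInterfaceWrapper,
# i.e. interrupt the rollout and generate a new sub-trajectory every 100 control steps.
def do_replanning(self, pos, vel, obs, action, t):
    # replan on a fixed schedule, independent of the current state
    return t > 0 and t % 100 == 0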

View File

@@ -43,13 +43,13 @@ class RawInterfaceWrapper(gym.Wrapper):
raise NotImplementedError()
@property
@abstractmethod
def dt(self) -> float:
"""
Control time step (dt) of the environment in seconds, i.e. the inverse of the control frequency
Returns: float
"""
return self.env.dt
def do_replanning(self, pos, vel, s, a, t):
# return t % 100 == 0