from typing import Tuple, Optional, Callable, Dict, Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.core import ObsType
from mp_pytorch.mp.mp_interfaces import MPInterface

from fancy_gym.black_box.controller.base_controller import BaseController
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.utils.utils import get_numpy


class BlackBoxWrapper(gym.ObservationWrapper):

    def __init__(self,
                 env: RawInterfaceWrapper,
                 trajectory_generator: MPInterface,
                 tracking_controller: BaseController,
                 duration: float,
                 verbose: int = 1,
                 learn_sub_trajectories: bool = False,
                 replanning_schedule: Optional[
                     Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int], bool]] = None,
                 reward_aggregation: Callable[[np.ndarray], float] = np.sum,
                 max_planning_times: int = np.inf,
                 condition_on_desired: bool = False
                 ):
        """
        gym.Wrapper for leveraging a black-box approach with a trajectory generator.

        Args:
            env: The (wrapped) environment this wrapper is applied on
            trajectory_generator: Generates the full or partial trajectory
            tracking_controller: Translates the desired trajectory to raw action sequences
            duration: Length of the trajectory of the movement primitive in seconds
            verbose: Level of detail for returned values in the info dict
            learn_sub_trajectories: Transforms full-episode learning into learning sub-trajectories, similar to
                step-based learning
            replanning_schedule: Callable that receives the current position, velocity, observation, action, and
                trajectory time step, and returns True whenever a new (sub-)trajectory should be planned
            reward_aggregation: Function that takes the np.ndarray of step rewards as input and returns the
                trajectory reward, default summation over all values
            max_planning_times: Maximum number of (sub-)trajectories planned per episode
            condition_on_desired: Whether replanned trajectories are conditioned on the previously desired
                position/velocity instead of the currently measured state
        """
        super().__init__(env)

        self.duration = duration
        self.learn_sub_trajectories = learn_sub_trajectories
        self.do_replanning = replanning_schedule is not None
        self.replanning_schedule = replanning_schedule or (lambda *x: False)
        self.current_traj_steps = 0

        # trajectory generation
        self.traj_gen = trajectory_generator
        self.tracking_controller = tracking_controller
        # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
        # self.traj_gen.set_mp_times(self.time_steps)
        self.traj_gen.set_duration(self.duration, self.dt)

        # reward computation
        self.reward_aggregation = reward_aggregation

        # spaces
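        # only the fixed context is returned as observation when a single full trajectory is
        # learned; sub-trajectory learning and replanning require the full step observation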
        self.return_context_observation = not (
            learn_sub_trajectories or self.do_replanning)
        self.traj_gen_action_space = self._get_traj_gen_action_space()
        self.action_space = self._get_action_space()
        self.observation_space = self._get_observation_space()

        # rendering
        self.render_kwargs = {}
        self.verbose = verbose

        # condition value
        self.condition_on_desired = condition_on_desired
        self.condition_pos = None
        self.condition_vel = None

        self.max_planning_times = max_planning_times
        self.plan_steps = 0

    def observation(self, observation):
        # return only the context dimensions if a single full trajectory is learned
        if self.return_context_observation:
            observation = observation[self.env.context_mask]
        # cast dtype because metaworld returns an incorrect dtype that throws a gym error
        return observation.astype(self.observation_space.dtype)

    def get_trajectory(self, action: np.ndarray) -> Tuple:
        duration = self.duration
        if self.learn_sub_trajectories:
            duration = None
        # Reset with every new call, as all arguments, such as tau and delay, need to be set again.
        # Otherwise, the traj_gen assumes we are continuing the current trajectory.
        self.traj_gen.reset()

        clipped_params = np.clip(
            action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
        self.traj_gen.set_params(clipped_params)
        init_time = np.array(
            0 if not self.do_replanning else self.current_traj_steps * self.dt)
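        # condition on the previously desired boundary values when replanning with
        # condition_on_desired enabled; otherwise fall back to the measured state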
        condition_pos = self.condition_pos if self.condition_pos is not None else self.current_pos
        condition_vel = self.condition_vel if self.condition_vel is not None else self.current_vel

        self.traj_gen.set_initial_conditions(
            init_time, condition_pos, condition_vel)
        self.traj_gen.set_duration(duration, self.dt)

        position = get_numpy(self.traj_gen.get_traj_pos())
        velocity = get_numpy(self.traj_gen.get_traj_vel())

        return position, velocity

    def _get_traj_gen_action_space(self):
        """This function can be used to set up an individual space for the parameters of the traj_gen."""
        min_action_bounds, max_action_bounds = self.traj_gen.get_params_bounds()
        action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
                                      dtype=self.env.action_space.dtype)
        return action_space

    def _get_action_space(self):
        """
        This function can be used to modify the action space to include actions which are not learned via movement
        primitives, e.g. the ball release time for the beer pong task. By default, it is the parameter space of the
        movement primitive.
        Only needs to be overwritten if the action space needs to be modified.
        """
        try:
            return self.traj_gen_action_space
        except AttributeError:
            return self._get_traj_gen_action_space()

    def _get_observation_space(self):
        if self.return_context_observation:
            # return only the masked context observation
            mask = self.env.context_mask
            min_obs_bound = self.env.observation_space.low[mask]
            max_obs_bound = self.env.observation_space.high[mask]
            return spaces.Box(low=min_obs_bound, high=max_obs_bound, dtype=self.env.observation_space.dtype)
        # otherwise, return the full observation space
        return self.env.observation_space

    def step(self, action: np.ndarray):
        """This function generates a trajectory based on an MP and then does the usual loop over reset and step."""

        # TODO remove this part, right now only needed for beer pong
        mp_params, env_spec_params = self.env.episode_callback(
            action, self.traj_gen)
        position, velocity = self.get_trajectory(mp_params)

        trajectory_length = len(position)
        rewards = np.zeros(shape=(trajectory_length,))
        if self.verbose >= 2:
            actions = np.zeros(shape=(trajectory_length,) +
                               self.env.action_space.shape)
            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
                                    dtype=self.env.observation_space.dtype)

        infos = dict()
        done = False

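        # roll out the desired trajectory step by step: the tracking controller turns each
        # desired (pos, vel) pair into a raw action, clipped to the env's action space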
        self.plan_steps += 1
        for t, (pos, vel) in enumerate(zip(position, velocity)):
            step_action = self.tracking_controller.get_action(
                pos, vel, self.current_pos, self.current_vel)
            c_action = np.clip(
                step_action, self.env.action_space.low, self.env.action_space.high)
            obs, c_reward, terminated, truncated, info = self.env.step(
                c_action)
            rewards[t] = c_reward

            if self.verbose >= 2:
                actions[t, :] = c_action
                observations[t, :] = obs

            for k, v in info.items():
                elems = infos.get(k, [None] * trajectory_length)
                elems[t] = v
                infos[k] = elems

            if self.render_kwargs:
                self.env.render(**self.render_kwargs)

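            # stop the rollout early if the episode ended or the replanning schedule triggers
            # (bounded by max_planning_times)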
            if terminated or truncated or (self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                                    t + 1 + self.current_traj_steps)
                                           and self.plan_steps < self.max_planning_times):

                if self.condition_on_desired:
                    self.condition_pos = pos
                    self.condition_vel = vel

                break

        infos.update({k: v[:t + 1] for k, v in infos.items()})
        self.current_traj_steps += t + 1

        if self.verbose >= 2:
            infos['positions'] = position
            infos['velocities'] = velocity
            infos['step_actions'] = actions[:t + 1]
            infos['step_observations'] = observations[:t + 1]
            infos['step_rewards'] = rewards[:t + 1]

        infos['trajectory_length'] = t + 1
        trajectory_return = self.reward_aggregation(rewards[:t + 1])
        return self.observation(obs), trajectory_return, terminated, truncated, infos

    def render(self, **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once."""
        self.render_kwargs = kwargs

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
            -> Tuple[ObsType, Dict[str, Any]]:
        self.current_traj_steps = 0
        self.plan_steps = 0
        self.traj_gen.reset()
        self.condition_pos = None
        self.condition_vel = None
        return super(BlackBoxWrapper, self).reset(seed=seed, options=options)
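

# The two helpers below are illustrative sketches only; they are not used by the wrapper and
# their names are made up. They show the call signatures expected of replanning_schedule and
# reward_aggregation.

def example_replanning_schedule(pos: np.ndarray, vel: np.ndarray, obs: np.ndarray,
                                action: np.ndarray, t: int) -> bool:
    """Sketch of a replanning schedule: receives the current position, velocity, observation,
    action, and trajectory time step (as passed in BlackBoxWrapper.step) and returns True
    whenever a new sub-trajectory should be planned, here every 25 environment steps."""
    return t % 25 == 0


def example_reward_aggregation(step_rewards: np.ndarray) -> float:
    """Sketch of a reward aggregation: a discounted sum over the step rewards instead of the
    default np.sum; any Callable[[np.ndarray], float] is valid. The discount factor is an
    assumption for illustration."""
    gamma = 0.99
    return float(np.sum(step_rewards * gamma ** np.arange(len(step_rewards))))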
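
# A minimal usage sketch, assuming fancy_gym's registered MP variants construct this wrapper
# internally; the environment id below is hypothetical and may differ between versions. A
# single env.step call consumes one set of MP parameters and rolls out the whole
# (sub-)trajectory:
#
#     import fancy_gym  # registers the MP environment variants
#     env = gym.make('fancy_ProMP/Reacher5d-v0')  # hypothetical id, check your installation
#     obs, info = env.reset(seed=0)
#     params = env.action_space.sample()  # MP parameters, not per-step raw actions
#     obs, traj_return, terminated, truncated, info = env.step(params)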