fancy_gym/fancy_gym/black_box/black_box_wrapper.py

from typing import Tuple, Optional
import gym
import numpy as np
from gym import spaces
from mp_pytorch.mp.mp_interfaces import MPInterface
from fancy_gym.black_box.controller.base_controller import BaseController
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.utils.utils import get_numpy


class BlackBoxWrapper(gym.ObservationWrapper):
    def __init__(self,
                 env: RawInterfaceWrapper,
                 trajectory_generator: MPInterface,
                 tracking_controller: BaseController,
                 duration: float,
                 verbose: int = 1,
                 learn_sub_trajectories: bool = False,
                 replanning_schedule: Optional[callable] = None,
                 reward_aggregation: callable = np.sum
                 ):
"""
gym.Wrapper for leveraging a black box approach with a trajectory generator.
Args:
env: The (wrapped) environment this wrapper is applied on
trajectory_generator: Generates the full or partial trajectory
tracking_controller: Translates the desired trajectory to raw action sequences
duration: Length of the trajectory of the movement primitive in seconds
verbose: level of detail for returned values in info dict.
2022-06-30 14:08:54 +02:00
learn_sub_trajectories: Transforms full episode learning into learning sub-trajectories, similar to
step-based learning
replanning_schedule: callable that receives
2022-06-29 09:37:18 +02:00
reward_aggregation: function that takes the np.ndarray of step rewards as input and returns the trajectory
reward, default summation over all values.
"""
        super().__init__(env)

        self.duration = duration
        self.learn_sub_trajectories = learn_sub_trajectories
        self.do_replanning = replanning_schedule is not None
        self.replanning_schedule = replanning_schedule or (lambda *x: False)
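        # Illustrative example (not part of this module): a replanning schedule
        # receives (current_pos, current_vel, observation, action, step_count),
        # mirroring the call in step() below, e.g. replan every 25 env steps:
        #   replanning_schedule = lambda pos, vel, obs, action, t: t % 25 == 0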
        self.current_traj_steps = 0

        # trajectory generation
        self.traj_gen = trajectory_generator
        self.tracking_controller = tracking_controller
        # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
        # self.traj_gen.set_mp_times(self.time_steps)
        self.traj_gen.set_duration(self.duration, self.dt)
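        # self.dt is expected to be exposed by the wrapped RawInterfaceWrapper
        # (the control time step of the underlying environment).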
        # reward computation
        self.reward_aggregation = reward_aggregation

        # spaces
        self.return_context_observation = not (learn_sub_trajectories or self.do_replanning)
        self.traj_gen_action_space = self._get_traj_gen_action_space()
        self.action_space = self._get_action_space()
        self.observation_space = self._get_observation_space()

        # rendering
        self.render_kwargs = {}
        self.verbose = verbose

    def observation(self, observation):
        # return only the context part of the observation if the full
        # trajectory is learned in one shot (no sub-trajectories, no replanning)
        if self.return_context_observation:
            observation = observation[self.env.context_mask]
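            # Illustration (hypothetical values): context_mask is a boolean mask
            # over the observation provided by the RawInterfaceWrapper, e.g.
            #   context_mask = np.array([False] * n_joint_dims + [True] * n_goal_dims)
            # so that only the task context (here: the goal) is kept.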
        # cast dtype because metaworld returns an incorrect dtype that triggers a gym error
        return observation.astype(self.observation_space.dtype)

    def get_trajectory(self, action: np.ndarray) -> Tuple:
        clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
        self.traj_gen.set_params(clipped_params)
        bc_time = np.array(0 if not self.do_replanning else self.current_traj_steps * self.dt)
        # TODO we could think about initializing with the previous desired value in order to have a smooth transition
        #  at least from the planning point of view.
        self.traj_gen.set_boundary_conditions(bc_time, self.current_pos, self.current_vel)
        duration = None if self.learn_sub_trajectories else self.duration
        self.traj_gen.set_duration(duration, self.dt)
        # traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
        trajectory = get_numpy(self.traj_gen.get_traj_pos())
        velocity = get_numpy(self.traj_gen.get_traj_vel())

        if self.do_replanning:
            # Remove first part of trajectory as this is already over
            trajectory = trajectory[self.current_traj_steps:]
            velocity = velocity[self.current_traj_steps:]

        return trajectory, velocity
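    # get_trajectory maps one parameter vector to a desired position/velocity
    # trajectory; step() below then tracks that trajectory point by point with
    # the tracking_controller.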

    def _get_traj_gen_action_space(self):
        """This function can be used to set up an individual space for the parameters of the traj_gen."""
        min_action_bounds, max_action_bounds = self.traj_gen.get_params_bounds()
        action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
                                      dtype=self.env.action_space.dtype)
        return action_space

    def _get_action_space(self):
        """
        This function can be used to modify the action space to include actions that are not learned via movement
        primitives, e.g. the ball release time for the beer pong task. By default, it is the parameter space of the
        movement primitive.
        Only needs to be overwritten if the action space needs to be modified.
        """
        try:
            return self.traj_gen_action_space
        except AttributeError:
            return self._get_traj_gen_action_space()

    def _get_observation_space(self):
        if self.return_context_observation:
            mask = self.env.context_mask
            # return only the masked (context) part of the observation space
            min_obs_bound = self.env.observation_space.low[mask]
            max_obs_bound = self.env.observation_space.high[mask]
            return spaces.Box(low=min_obs_bound, high=max_obs_bound, dtype=self.env.observation_space.dtype)
        # otherwise return the full observation space of the wrapped env
        return self.env.observation_space

    def step(self, action: np.ndarray):
        """This function generates a trajectory based on an MP and then executes it step by step in the wrapped environment."""
        # TODO remove this part, right now only needed for beer pong
        mp_params, env_spec_params = self.env.episode_callback(action, self.traj_gen)
        trajectory, velocity = self.get_trajectory(mp_params)

        trajectory_length = len(trajectory)
        rewards = np.zeros(shape=(trajectory_length,))
        if self.verbose >= 2:
            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
                                    dtype=self.env.observation_space.dtype)

        infos = dict()
        done = False

        for t, (pos, vel) in enumerate(zip(trajectory, velocity)):
            step_action = self.tracking_controller.get_action(pos, vel, self.current_pos, self.current_vel)
            c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
            obs, c_reward, done, info = self.env.step(c_action)
            rewards[t] = c_reward

            if self.verbose >= 2:
                actions[t, :] = c_action
                observations[t, :] = obs

            for k, v in info.items():
                elems = infos.get(k, [None] * trajectory_length)
                elems[t] = v
                infos[k] = elems
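            # `infos` now maps every info key seen so far to a list with one
            # slot per trajectory step (None where the key was not reported).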

            if self.render_kwargs:
                self.env.render(**self.render_kwargs)

            if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                t + 1 + self.current_traj_steps):
                break

        # truncate the collected info lists to the number of executed steps
        infos.update({k: v[:t + 1] for k, v in infos.items()})
        self.current_traj_steps += t + 1

        if self.verbose >= 2:
            infos['positions'] = trajectory
            infos['velocities'] = velocity
            infos['step_actions'] = actions[:t + 1]
            infos['step_observations'] = observations[:t + 1]
            infos['step_rewards'] = rewards[:t + 1]

        infos['trajectory_length'] = t + 1
        trajectory_return = self.reward_aggregation(rewards[:t + 1])
        return self.observation(obs), trajectory_return, done, infos

    def render(self, **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once"""
        self.render_kwargs = kwargs

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
        self.current_traj_steps = 0
        return super().reset()
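

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the env id, trajectory generator and
# controller constructors below are placeholders, not part of this module):
#
#   raw_env = MyRawInterfaceWrapper(gym.make("SomeEnv-v0"))   # hypothetical wrapper
#   traj_gen = make_trajectory_generator(...)                 # e.g. a ProMP/DMP from mp_pytorch
#   controller = make_tracking_controller(...)                # e.g. a PD tracking controller
#   env = BlackBoxWrapper(raw_env, traj_gen, controller, duration=2.0, verbose=2)
#
#   obs = env.reset()
#   done = False
#   while not done:
#       # one "step" executes a full (sub-)trajectory in the wrapped env
#       obs, reward, done, info = env.step(env.action_space.sample())
# ---------------------------------------------------------------------------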