fancy_gym/alr_envs/black_box/black_box_wrapper.py
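"""Black-box wrapper that turns a step-based gym environment into an episode-based movement primitive environment:
an MP parameter vector is mapped to a desired trajectory, which a tracking controller follows while the wrapped
environment is stepped along it."""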

from typing import Tuple, Union
import gym
import numpy as np
from gym import spaces
from mp_pytorch.mp.mp_interfaces import MPInterface
from alr_envs.black_box.controller.base_controller import BaseController
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
from alr_envs.utils.utils import get_numpy
class BlackBoxWrapper(gym.ObservationWrapper):
    def __init__(self,
                 env: RawInterfaceWrapper,
                 trajectory_generator: MPInterface,
                 tracking_controller: BaseController,
                 duration: float,
                 verbose: int = 1,
                 learn_sub_trajectories: bool = False,
                 replanning_schedule: Union[None, callable] = None,
                 reward_aggregation: callable = np.sum
                 ):
        """
        gym.Wrapper for leveraging a black-box approach with a trajectory generator.

        Args:
            env: The (wrapped) environment this wrapper is applied on
            trajectory_generator: Generates the full or partial trajectory
            tracking_controller: Translates the desired trajectory into raw action sequences
            duration: Length of the trajectory of the movement primitive in seconds
            verbose: Level of detail for returned values in the info dict.
            learn_sub_trajectories: Transforms full episode learning into learning sub-trajectories, similar to
                step-based learning
            replanning_schedule: Callable that receives the current position, velocity, observation, action, and
                time step and returns True whenever a new sub-trajectory should be generated; None disables
                replanning.
            reward_aggregation: Function that takes the np.ndarray of step rewards as input and returns the
                trajectory reward; default is summation over all values.
        """
        super().__init__(env)

        self.duration = duration
        self.learn_sub_trajectories = learn_sub_trajectories
        self.replanning_schedule = replanning_schedule
        self.current_traj_steps = 0

        # trajectory generation
        self.traj_gen = trajectory_generator
        self.tracking_controller = tracking_controller
        # self.time_steps = np.linspace(0, self.duration, self.traj_steps)
        # self.traj_gen.set_mp_times(self.time_steps)
        self.traj_gen.set_duration(np.array([self.duration]), np.array([self.dt]))

        # reward computation
        self.reward_aggregation = reward_aggregation

        # spaces
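        # without sub-trajectory learning and replanning, a single MP parameter vector is queried per episode and
        # only the constant context features are exposed as observation (see observation())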
self.return_context_observation = not (self.learn_sub_trajectories or replanning_schedule)
        self.traj_gen_action_space = self.get_traj_gen_action_space()
        self.action_space = self.get_action_space()
        self.observation_space = spaces.Box(low=self.env.observation_space.low[self.env.context_mask],
                                            high=self.env.observation_space.high[self.env.context_mask],
                                            dtype=self.env.observation_space.dtype)

        # rendering
        self.render_kwargs = {}
        self.verbose = verbose

    def observation(self, observation):
        # return only the context-masked observation if we are learning one full trajectory per episode
        return observation[self.env.context_mask] if self.return_context_observation else observation

    def get_trajectory(self, action: np.ndarray) -> Tuple:
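        """Generate the desired position and velocity trajectory from a (clipped) set of trajectory generator
        parameters."""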
        clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
        self.traj_gen.set_params(clipped_params)
        # TODO: Bruce said DMP, ProMP, ProDMP can have 0 bc_time for sequencing
        # TODO: Check with Bruce for replanning
        self.traj_gen.set_boundary_conditions(
            bc_time=np.zeros((1,)) if not self.replanning_schedule else self.current_traj_steps * self.dt,
            bc_pos=self.current_pos, bc_vel=self.current_vel)
        # TODO: is this correct for replanning? Do we need to adjust anything here?
        self.traj_gen.set_duration(None if self.learn_sub_trajectories else np.array([self.duration]),
                                   np.array([self.dt]))
        traj_dict = self.traj_gen.get_trajs(get_pos=True, get_vel=True)
        trajectory_tensor, velocity_tensor = traj_dict['pos'], traj_dict['vel']

        return get_numpy(trajectory_tensor), get_numpy(velocity_tensor)

    def get_traj_gen_action_space(self):
        """This function can be used to set up an individual space for the parameters of the traj_gen."""
        min_action_bounds, max_action_bounds = self.traj_gen.get_param_bounds()
        mp_action_space = gym.spaces.Box(low=min_action_bounds.numpy(), high=max_action_bounds.numpy(),
                                         dtype=np.float32)
        return mp_action_space

    def get_action_space(self):
        """
        This function can be used to modify the action space for considering actions which are not learned via
        movement primitives, e.g. the ball release time for the beer pong task. By default, it is the parameter
        space of the movement primitive.
        Only needs to be overwritten if the action space needs to be modified.
        """
        try:
            return self.traj_gen_action_space
        except AttributeError:
            return self.get_traj_gen_action_space()

    def step(self, action: np.ndarray):
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
2022-06-29 16:30:36 +02:00
# agent to learn when to release the ball
mp_params, env_spec_params = self.env._episode_callback(action, self.traj_gen)
2022-05-03 19:51:54 +02:00
trajectory, velocity = self.get_trajectory(mp_params)
        trajectory_length = len(trajectory)
        rewards = np.zeros(shape=(trajectory_length,))
        if self.verbose >= 2:
            actions = np.zeros(shape=(trajectory_length,) + self.env.action_space.shape)
            observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape,
                                    dtype=self.env.observation_space.dtype)

        infos = dict()
        done = False

        for t, pos_vel in enumerate(zip(trajectory, velocity)):
            step_action = self.tracking_controller.get_action(pos_vel[0], pos_vel[1], self.current_pos,
                                                              self.current_vel)
            step_action = self._step_callback(t, env_spec_params, step_action)  # include possible callback info
            c_action = np.clip(step_action, self.env.action_space.low, self.env.action_space.high)
            # print('step/clipped action ratio: ', step_action/c_action)
            obs, c_reward, done, info = self.env.step(c_action)
            rewards[t] = c_reward

            if self.verbose >= 2:
                actions[t, :] = c_action
                observations[t, :] = obs

            for k, v in info.items():
                elems = infos.get(k, [None] * trajectory_length)
                elems[t] = v
                infos[k] = elems

            if self.render_kwargs:
                self.render(**self.render_kwargs)

            if done:
                break
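            # check whether the replanning schedule requests a new sub-trajectory at this point and stop early if so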
if self.replanning_schedule and self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
t + 1 + self.current_traj_steps):
break
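        # truncate the collected info values to the number of steps that were actually executed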
infos.update({k: v[:t + 1] for k, v in infos.items()})
        self.current_traj_steps += t + 1

        if self.verbose >= 2:
            infos['trajectory'] = trajectory
            infos['step_actions'] = actions[:t + 1]
            infos['step_observations'] = observations[:t + 1]
            infos['step_rewards'] = rewards[:t + 1]

        infos['trajectory_length'] = t + 1
        trajectory_return = self.reward_aggregation(rewards[:t + 1])
        return obs, trajectory_return, done, infos

    def render(self, **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once."""
        self.render_kwargs = kwargs
        # self.env.render(mode=self.render_mode, **self.render_kwargs)
        self.env.render(**kwargs)

    def reset(self, **kwargs):
        self.current_traj_steps = 0
        return super(BlackBoxWrapper, self).reset(**kwargs)

    def plot_trajs(self, des_trajs, des_vels):
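        """Debugging helper: plot each DoF of the desired trajectory and velocity together with the current state."""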
        import matplotlib
        matplotlib.use('TkAgg')
        import matplotlib.pyplot as plt
pos_fig = plt.figure('positions')
vel_fig = plt.figure('velocities')
for i in range(des_trajs.shape[1]):
plt.figure(pos_fig.number)
plt.subplot(des_trajs.shape[1], 1, i + 1)
plt.plot(np.ones(des_trajs.shape[0]) * self.current_pos[i])
plt.plot(des_trajs[:, i])
plt.figure(vel_fig.number)
plt.subplot(des_vels.shape[1], 1, i + 1)
plt.plot(np.ones(des_trajs.shape[0]) * self.current_vel[i])
plt.plot(des_vels[:, i])
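

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the wrapper): how the pieces are
# typically wired together. `MyRawWrapper`, `my_traj_gen` and `my_controller`
# are hypothetical placeholders for a concrete RawInterfaceWrapper subclass, an
# MPInterface instance and a BaseController implementation.
#
#   base_env = MyRawWrapper(gym.make("SomeEnv-v0"))
#   env = BlackBoxWrapper(base_env,
#                         trajectory_generator=my_traj_gen,
#                         tracking_controller=my_controller,
#                         duration=2.0,
#                         verbose=1)
#   obs = env.reset()                       # context-only observation by default
#   params = env.action_space.sample()      # one MP parameter vector per episode
#   obs, trajectory_return, done, info = env.step(params)
# ---------------------------------------------------------------------------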