# fancy_gym/alr_envs/utils/mps/mp_wrapper.py
#
# 135 lines
# 5.0 KiB
# Python

from abc import ABC, abstractmethod
from typing import Union
import gym
import numpy as np
from alr_envs.utils.mps.alr_env import AlrEnv
from alr_envs.utils.policies import get_policy_class, BaseController
class MPWrapper(gym.Wrapper, ABC):
    """
    Base class for movement primitive based gym.Wrapper implementations.

    A single call to :meth:`step` rolls out one complete trajectory generated
    by the movement primitive on the wrapped environment.

    :param env: The (wrapped) environment this wrapper is applied on
    :param num_dof: Dimension of the action space of the wrapped env
    :param duration: Number of timesteps in the trajectory of the movement primitive
        (forwarded unchanged to ``initialize_mp``)
    :param post_traj_time: Time for which the last position of the trajectory is fed to the environment to continue
        simulation
    :param policy_type: Type or object defining the policy that is used to generate action based on the trajectory.
        A string is resolved through ``get_policy_class``; anything else is used as the policy object directly.
    :param weights_scale: Scaling parameter for the actions given to this wrapper. It is only stored on the
        instance here; this base class does not apply it itself.
    :param render_mode: Equivalent to gym render mode
    :param mp_kwargs: Extra keyword arguments forwarded to ``initialize_mp`` and, as a dict, to
        ``get_policy_class`` when ``policy_type`` is a string
    """

    def __init__(self,
                 env: AlrEnv,
                 num_dof: int,
                 duration: int = 1,
                 post_traj_time: float = 0.,
                 policy_type: Union[str, BaseController, None] = None,
                 weights_scale: float = 1.,
                 render_mode: Union[str, None] = None,
                 **mp_kwargs
                 ):
        super().__init__(env)

        # Expose only the "active" entries of the wrapped observation space.
        obs_sp = self.env.observation_space
        self.observation_space = gym.spaces.Box(low=obs_sp.low[self.env.active_obs],
                                                high=obs_sp.high[self.env.active_obs],
                                                dtype=obs_sp.dtype)

        # Number of extra environment steps appended after the MP trajectory ends.
        self.post_traj_steps = int(post_traj_time / env.dt)

        self.mp = self.initialize_mp(num_dof=num_dof, duration=duration, **mp_kwargs)
        self.weights_scale = weights_scale

        # isinstance (rather than an exact type comparison) also accepts str subclasses.
        if isinstance(policy_type, str):
            self.policy = get_policy_class(policy_type, env, mp_kwargs)
        else:
            self.policy = policy_type

        # rendering
        self.render_mode = render_mode
        self.render_kwargs = {}

    # TODO: @Max I think this should not be in this class, this functionality should be part of your sampler.
    def __call__(self, params, contexts=None):
        """
        Can be used to provide a batch of parameter sets.

        Each row of ``params`` (after ``np.atleast_2d``) is passed to :meth:`step` in order.

        :param params: A single parameter vector or a batch of parameter vectors
        :param contexts: Currently unused
        :return: Tuple ``(obs, rewards, dones, infos)`` where ``rewards`` is a numpy array and the
            others are lists with one entry per parameter set
        """
        params = np.atleast_2d(params)
        obs = []
        rewards = []
        dones = []
        infos = []
        # for p, c in zip(params, contexts):
        for p in params:
            # self.configure(c)
            ob, reward, done, info = self.step(p)
            obs.append(ob)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)

        return obs, np.array(rewards), dones, infos

    def reset(self):
        """Reset the wrapped environment and return only the active part of its observation."""
        return self.env.reset()[self.env.active_obs]

    def step(self, action: np.ndarray):
        """
        Generate a trajectory from the movement primitive and execute it on the wrapped environment.

        The environment is stepped once per trajectory point until the trajectory ends or the
        environment reports ``done``. The environment is not reset here.

        :param action: Parameters handed to ``mp_rollout``
        :return: Tuple ``(observation, trajectory_return, done, infos)``. ``observation`` is the active part
            of the last observation, ``trajectory_return`` the sum of all step rewards, ``done`` is always
            ``True`` and ``infos`` holds the per-step infos, actions, observations, rewards and the number
            of executed steps.
        :raises ValueError: If ``mp_rollout`` returns an empty trajectory or velocity profile
        """
        trajectory, velocity = self.mp_rollout(action)

        # Without at least one trajectory point there is nothing to execute and no observation
        # to return; fail with a clear message instead of an unbound loop variable further down.
        if len(trajectory) == 0 or len(velocity) == 0:
            raise ValueError("mp_rollout returned an empty trajectory; cannot step the environment.")

        if self.post_traj_steps > 0:
            # Hold the final position with zero velocity for the post-trajectory phase.
            trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dimensions))])

        trajectory_length = len(trajectory)
        actions = np.zeros(shape=(trajectory_length, self.mp.num_dimensions))
        # Buffer holds the wrapped env's full observations, not the reduced ones.
        observations = np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape)
        rewards = np.zeros(shape=(trajectory_length,))
        trajectory_return = 0

        infos = dict(step_infos=[])

        for t, (pos, vel) in enumerate(zip(trajectory, velocity)):
            actions[t, :] = self.policy.get_action(pos, vel)
            observations[t, :], rewards[t], done, info = self.env.step(actions[t, :])
            trajectory_return += rewards[t]
            infos['step_infos'].append(info)
            if self.render_mode:
                self.env.render(mode=self.render_mode, **self.render_kwargs)
            if done:
                break

        # Only the first t + 1 entries were filled if the environment terminated early.
        infos['step_actions'] = actions[:t + 1]
        infos['step_observations'] = observations[:t + 1]
        infos['step_rewards'] = rewards[:t + 1]
        infos['trajectory_length'] = t + 1

        # One wrapper step consumes a whole rollout, so it is always reported as finished.
        done = True

        return observations[t][self.env.active_obs], trajectory_return, done, infos

    def render(self, mode='human', **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once"""
        self.render_mode = mode
        self.render_kwargs = kwargs

    @abstractmethod
    def mp_rollout(self, action):
        """
        Generate trajectory and velocity based on the MP

        Returns:
            trajectory/positions, velocity
        """
        raise NotImplementedError()

    @abstractmethod
    def initialize_mp(self, num_dof: int, duration: float, **kwargs):
        """
        Create respective instance of MP

        Returns:
            MP instance
        """
        raise NotImplementedError()