from abc import ABC, abstractmethod
from typing import Optional

import gym
import numpy as np

from alr_envs.utils.mps.mp_environments import AlrEnv
from alr_envs.utils.policies import get_policy_class


class MPWrapper(gym.Wrapper, ABC):
    """
    Base wrapper that turns a step-based AlrEnv into an episodic environment:
    one action is a full movement primitive (MP) parameter set, which is rolled
    out as a trajectory in the underlying environment.
    """

    def __init__(self, env: AlrEnv, num_dof: int, dt: float, duration: float = 1., post_traj_time: float = 0.,
                 policy_type: Optional[str] = None, weights_scale: float = 1., render_mode: Optional[str] = None,
                 **mp_kwargs):
        super().__init__(env)

        # reduce the observation space to the entries the underlying env marks as active
        obs_sp = self.env.observation_space
        self.observation_space = gym.spaces.Box(low=obs_sp.low[self.env.active_obs],
                                                high=obs_sp.high[self.env.active_obs],
                                                dtype=obs_sp.dtype)

        # dt must be provided by the concrete subclass; MPWrapper itself is only a base class
        assert dt is not None
        self.post_traj_steps = int(post_traj_time / dt)

        self.mp = self.initialize_mp(num_dof, duration, dt, **mp_kwargs)
        self.weights_scale = weights_scale

        policy_class = get_policy_class(policy_type)
        self.policy = policy_class(env)

        # rendering
        self.render_mode = render_mode
        self.render_kwargs = {}

    # TODO: @Max I think this should not be in this class; this functionality should be part of your sampler.
    def __call__(self, params, contexts=None):
        """
        Can be used to provide a batch of parameter sets.
        """
        params = np.atleast_2d(params)
        obs = []
        rewards = []
        dones = []
        infos = []
        # for p, c in zip(params, contexts):
        for p in params:
            # self.configure(c)
            ob, reward, done, info = self.step(p)
            obs.append(ob)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)

        return obs, np.array(rewards), dones, infos
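
    # Usage sketch for the batch interface above (hypothetical names and shapes;
    # the actual parameter dimensionality depends on the concrete subclass' MP):
    #
    #   wrapper = SomeConcreteMPWrapper(env, num_dof=7, dt=0.02, duration=2.)
    #   params = np.random.randn(10, param_dim)       # batch of 10 parameter sets
    #   obs, rewards, dones, infos = wrapper(params)  # one full rollout per set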

    def reset(self):
        # reset the underlying env, but only expose the active observation entries
        return self.env.reset()[self.env.active_obs]

    def step(self, action: np.ndarray):
        """Generate a trajectory from the MP parameters and execute it step by step in the wrapped env."""
        trajectory, velocity = self.mp_rollout(action)

        if self.post_traj_steps > 0:
            # after the MP trajectory ends, hold the final position with zero velocity
            trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.n_dof))])

        # self._trajectory = trajectory
        # self._velocity = velocity

        rewards = 0
        info = {}
        # sample a placeholder obs so it is defined even for an empty trajectory;
        # the actual reset is called externally
        obs = self.env.observation_space.sample()

        for pos, vel in zip(trajectory, velocity):
            ac = self.policy.get_action(pos, vel)
            obs, rew, done, info = self.env.step(ac)
            rewards += rew
            # TODO return all dicts?
            # [infos[k].append(v) for k, v in info.items()]
            if self.render_mode:
                self.env.render(mode=self.render_mode, **self.render_kwargs)
            if done:
                break

        # from the outside, one MP rollout is a single episodic step
        done = True
        return obs[self.env.active_obs], rewards, done, info

    def render(self, mode='human', **kwargs):
        """Only set render options here, such that they can be used during the rollout.
        This only needs to be called once."""
        self.render_mode = mode
        self.render_kwargs = kwargs
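
    # Example (uses only the API above): configure rendering once before the
    # rollout; `step` then renders every inner environment step.
    #
    #   wrapper.render(mode='human')
    #   wrapper.step(params)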

    @abstractmethod
    def mp_rollout(self, action):
        """
        Generate a trajectory and the corresponding velocity profile based on the MP.

        Returns:
            trajectory/positions, velocity
        """
        raise NotImplementedError()

    @abstractmethod
    def initialize_mp(self, num_dof: int, duration: float, dt: float, **kwargs):
        """
        Create the respective instance of the MP.

        Returns:
            MP instance
        """
        raise NotImplementedError()
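
# A minimal sketch of a concrete subclass. `LinearMP` is a hypothetical MP
# implementation used only for illustration; the constructor and rollout
# signatures below are assumptions, not part of this module:
#
#   class LinearMPWrapper(MPWrapper):
#
#       def initialize_mp(self, num_dof, duration, dt, **kwargs):
#           return LinearMP(n_dof=num_dof, duration=duration, dt=dt, **kwargs)
#
#       def mp_rollout(self, action):
#           # scale the raw parameters before handing them to the MP
#           self.mp.set_weights(action * self.weights_scale)
#           return self.mp.rollout()  # positions [T, num_dof], velocities [T, num_dof]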