2021-02-15 16:31:34 +01:00
|
|
|
from alr_envs.utils.policies import get_policy_class
|
2021-02-05 17:10:03 +01:00
|
|
|
from mp_lib import det_promp
|
|
|
|
import numpy as np
|
|
|
|
import gym
|
|
|
|
|
|
|
|
|
2021-02-15 16:31:34 +01:00
|
|
|
class DetPMPEnvWrapper(gym.Wrapper):
    """Episodic wrapper that executes deterministic ProMP trajectories in an environment.

    A flat parameter vector is reshaped into ProMP weights of shape
    ``(num_basis, num_dof)``, turned into a desired position/velocity
    trajectory, and tracked step by step through ``self.policy`` in the
    wrapped environment.
    """

    def __init__(self,
                 env,
                 num_dof,
                 num_basis,
                 width,
                 start_pos=None,
                 duration=1,
                 dt=0.01,
                 post_traj_time=0.,
                 policy_type=None,
                 weights_scale=1,
                 zero_start=False,
                 zero_goal=False,
                 ):
        """Build the ProMP and the tracking policy for ``env``.

        Args:
            env: Environment to wrap; it is also handed to the policy class.
            num_dof: Number of degrees of freedom of the ProMP.
            num_basis: Number of basis functions of the ProMP.
            width: Forwarded to ``DeterministicProMP`` as ``width``.
            start_pos: Offset added to the desired positions when
                ``zero_start`` is true; required in that case.
            duration: Duration passed to ``pmp.set_weights``.
            dt: Step size; the trajectory is computed at ``1 / dt``.
            post_traj_time: Extra time after the trajectory during which the
                last desired position is held (converted to steps via ``dt``).
            policy_type: Key passed to ``get_policy_class`` to select the
                tracking policy.
            weights_scale: Factor multiplied onto the weights in ``rollout``.
            zero_start: Forwarded to ``DeterministicProMP``; also enables the
                ``start_pos`` offset in ``rollout``.
            zero_goal: Forwarded to ``DeterministicProMP``.
        """
        super().__init__(env)
        self.num_dof = num_dof
        self.num_basis = num_basis
        # Length of the flat parameter vector expected by ``rollout``.
        self.dim = num_dof * num_basis
        self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
                                                zero_start=zero_start, zero_goal=zero_goal)
        # Initialise the ProMP with all-zero weights; real weights are set per rollout.
        weights = np.zeros(shape=(num_basis, num_dof))
        self.pmp.set_weights(duration, weights)
        self.weights_scale = weights_scale

        self.duration = duration
        self.dt = dt
        self.post_traj_steps = int(post_traj_time / dt)

        self.start_pos = start_pos
        self.zero_start = zero_start

        policy_class = get_policy_class(policy_type)
        self.policy = policy_class(env)

    def __call__(self, params, contexts=None):
        """Run one rollout per parameter vector.

        Args:
            params: A single parameter vector or a 2-D batch of them
                (promoted with ``np.atleast_2d``).
            contexts: Optional iterable with one context per parameter
                vector. When omitted, every rollout receives ``None``.

        Returns:
            Tuple ``(rewards, infos)``: an array with one summed reward per
            rollout and a list with the info object returned by each rollout.
        """
        params = np.atleast_2d(params)
        if contexts is None:
            # zip() cannot iterate over None; pair every sample with a None context,
            # which matches the default of ``rollout``.
            contexts = [None] * len(params)

        rewards = []
        infos = []
        for p, c in zip(params, contexts):
            reward, info = self.rollout(p, c)
            rewards.append(reward)
            infos.append(info)

        return np.array(rewards), infos

    def rollout(self, params, context=None, render=False):
        """Generate a ProMP trajectory from ``params`` and execute it in the environment.

        Args:
            params: Flat weights of size ``num_basis * num_dof``; reshaped to
                ``(num_basis, num_dof)`` and scaled by ``weights_scale``.
            context: Passed to ``self.env.configure`` before the reset.
            render: If true, render every step with ``mode="human"``.

        Returns:
            Tuple ``(reward, info)``: the sum of all step rewards and the info
            object of the last executed step (an empty dict if no step ran).

        Raises:
            ValueError: If ``zero_start`` is set but ``start_pos`` is ``None``.
        """
        params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
        self.pmp.set_weights(self.duration, params)
        # NOTE(review): second argument is kept as in the original call; its meaning
        # is defined by DeterministicProMP.compute_trajectory - confirm there.
        _, des_pos, des_vel, _ = self.pmp.compute_trajectory(1 / self.dt, 1.)

        if self.zero_start:
            if self.start_pos is None:
                raise ValueError("start_pos must be provided when zero_start is True")
            # Out-of-place add so the array returned by the ProMP is not modified.
            des_pos = des_pos + np.asarray(self.start_pos)[None, :]

        if self.post_traj_steps > 0:
            # Hold the final position with zero velocity for the post-trajectory phase.
            des_pos = np.vstack([des_pos, np.tile(des_pos[-1, :], [self.post_traj_steps, 1])])
            des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])

        self._trajectory = des_pos
        self._velocity = des_vel

        rews = []
        # Defined up front so the return value exists even if the loop never runs.
        info = {}

        self.env.configure(context)
        self.env.reset()

        for pos, vel in zip(des_pos, des_vel):
            ac = self.policy.get_action(pos, vel)
            obs, rew, done, info = self.env.step(ac)
            rews.append(rew)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)

        return reward, info
|
|
|
|
|