fancy_gym/alr_envs/utils/detpmp_env_wrapper.py

from alr_envs.utils.policies import get_policy_class
2021-02-05 17:10:03 +01:00
from mp_lib import det_promp
import numpy as np
import gym
2021-02-15 16:31:34 +01:00
class DetPMPEnvWrapper(gym.Wrapper):
2021-02-05 17:10:03 +01:00
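    """
    Wraps a gym environment and executes full rollouts driven by a deterministic ProMP (DetPMP).
    The weights passed to __call__/rollout parameterize the ProMP, which is unrolled into a desired
    trajectory; a low-level tracking policy converts it into environment actions step by step.
    """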
    def __init__(self,
                 env,
                 num_dof,
                 num_basis,
                 width,
                 start_pos=None,
                 duration=1,
                 dt=0.01,
                 post_traj_time=0.,
                 policy_type=None,
                 weights_scale=1,
                 zero_start=False,
                 zero_goal=False,
                 ):
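        """
        Args:
            env: the gym environment to wrap.
            num_dof: number of degrees of freedom controlled by the ProMP.
            num_basis: number of basis functions per degree of freedom.
            width: width of the ProMP basis functions.
            start_pos: start position that a zero-started trajectory is shifted to.
            duration: duration of the generated trajectory in seconds.
            dt: control time step in seconds.
            post_traj_time: extra time (seconds) for which the final position is held after the trajectory ends.
            policy_type: identifier of the low-level tracking policy, resolved via get_policy_class.
            weights_scale: scaling factor applied to the ProMP weights.
            zero_start: if True, the ProMP starts at zero and is offset by start_pos.
            zero_goal: if True, the ProMP ends at zero.
        """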
        super(DetPMPEnvWrapper, self).__init__(env)
        self.num_dof = num_dof
        self.num_basis = num_basis
        self.dim = num_dof * num_basis
        self.pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
                                                zero_start=zero_start, zero_goal=zero_goal)

        # Initialize with zero weights; the actual weights are set per rollout.
        weights = np.zeros(shape=(num_basis, num_dof))
        self.pmp.set_weights(duration, weights)

        self.weights_scale = weights_scale
        self.duration = duration
        self.dt = dt
        self.post_traj_steps = int(post_traj_time / dt)

        self.start_pos = start_pos
        self.zero_centered = zero_start

        # Low-level policy that turns desired positions/velocities into environment actions.
        policy_class = get_policy_class(policy_type)
        self.policy = policy_class(env)
    def __call__(self, params, contexts=None):
        params = np.atleast_2d(params)
        # Without contexts, fall back to one None context per parameter set so the zip below works.
        if contexts is None:
            contexts = [None] * len(params)
        rewards = []
        infos = []
        for p, c in zip(params, contexts):
            reward, info = self.rollout(p, c)
            rewards.append(reward)
            infos.append(info)

        return np.array(rewards), infos
    def rollout(self, params, context=None, render=False):
        """Generates a trajectory from the deterministic ProMP and then does the usual loop over reset and step."""
        params = np.reshape(params, newshape=(self.num_basis, self.num_dof)) * self.weights_scale
        self.pmp.set_weights(self.duration, params)
        t, des_pos, des_vel, des_acc = self.pmp.compute_trajectory(1 / self.dt, 1.)
        if self.zero_centered:
            # Shift the zero-started trajectory to the actual start position.
            des_pos += self.start_pos[None, :]

        if self.post_traj_steps > 0:
            # Hold the final position with zero velocity for the configured post-trajectory time.
            des_pos = np.vstack([des_pos, np.tile(des_pos[-1, :], [self.post_traj_steps, 1])])
            des_vel = np.vstack([des_vel, np.zeros(shape=(self.post_traj_steps, self.num_dof))])

        self._trajectory = des_pos
        self._velocity = des_vel

        rews = []
        infos = []

        self.env.configure(context)
        self.env.reset()
        for t, pos_vel in enumerate(zip(des_pos, des_vel)):
            # Track the desired position/velocity with the low-level policy.
            ac = self.policy.get_action(pos_vel[0], pos_vel[1])
            obs, rew, done, info = self.env.step(ac)
            rews.append(rew)
            infos.append(info)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)

        return reward, info
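

# A minimal usage sketch (not part of the original module). The environment id and the
# "velocity" policy type below are assumptions for illustration; the wrapped env is expected
# to provide a configure(context) method, as used in rollout() above.
if __name__ == "__main__":
    base_env = gym.make("SomeALREnv-v0")  # hypothetical environment id
    wrapped = DetPMPEnvWrapper(base_env,
                               num_dof=2,
                               num_basis=5,
                               width=0.025,
                               start_pos=np.zeros(2),
                               duration=1,
                               dt=0.01,
                               policy_type="velocity",  # assumed to be registered in get_policy_class
                               weights_scale=1,
                               zero_start=True)
    # A single rollout with all-zero ProMP weights.
    rewards, infos = wrapped(np.zeros(2 * 5))
    print(rewards)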