# fancy_gym/alr_envs/utils/dmp_env_wrapper.py
from mp_lib.phase import ExpDecayPhaseGenerator
from mp_lib.basis import DMPBasisGenerator
from mp_lib import dmps
import numpy as np
import gym


class DmpEnvWrapperBase(gym.Wrapper):
    def __init__(self,
                 env,
                 num_dof,
                 num_basis,
                 start_pos=None,
                 final_pos=None,
                 duration=1,
                 alpha_phase=2,
                 dt=0.01,
                 learn_goal=False,
                 post_traj_time=0.,
                 policy=None):
        super(DmpEnvWrapperBase, self).__init__(env)
        self.num_dof = num_dof
        self.num_basis = num_basis
        # number of learnable parameters: one weight per basis function and DoF,
        # plus one goal position per DoF if the goal is learned as well
        self.dim = num_dof * num_basis
        if learn_goal:
            self.dim += num_dof
        self.learn_goal = learn_goal
        self.duration = duration  # seconds
        time_steps = int(duration / dt)
        self.t = np.linspace(0, duration, time_steps)
        self.post_traj_steps = int(post_traj_time / dt)

        phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
        basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=self.num_basis)

        self.dmp = dmps.DMP(num_dof=num_dof,
                            basis_generator=basis_generator,
                            phase_generator=phase_generator,
                            num_time_steps=time_steps,
                            dt=dt
                            )

        self.dmp.dmp_start_pos = start_pos.reshape((1, num_dof))

        # initialize the DMP with zero weights; the goal is either learned later
        # (as part of the parameter vector) or fixed to final_pos
        dmp_weights = np.zeros((num_basis, num_dof))
        if learn_goal:
            dmp_goal_pos = np.zeros(num_dof)
        else:
            dmp_goal_pos = final_pos

        self.dmp.set_weights(dmp_weights, dmp_goal_pos)

        self.policy = policy
    def __call__(self, params):
        params = np.atleast_2d(params)
        observations = []
        rewards = []
        dones = []
        infos = []
        for p in params:
            observation, reward, done, info = self.rollout(p)
            observations.append(observation)
            rewards.append(reward)
            dones.append(done)
            infos.append(info)

        return np.array(rewards), infos
    def goal_and_weights(self, params):
        if len(params.shape) > 1:
            assert params.shape[1] == self.dim
        else:
            assert len(params) == self.dim
            params = np.reshape(params, [1, self.dim])

        if self.learn_goal:
            # the last num_dof entries are the goal position, the rest are DMP weights
            goal_pos = params[0, -self.num_dof:]
            weight_matrix = np.reshape(params[:, :-self.num_dof], [self.num_basis, self.num_dof])
        else:
            goal_pos = None
            weight_matrix = np.reshape(params, [self.num_basis, self.num_dof])

        return goal_pos, weight_matrix
    def rollout(self, params, render=False):
        """Generate a trajectory from the DMP and execute it with the usual reset/step loop."""
        raise NotImplementedError
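

# Note on the parameter layout consumed by goal_and_weights above: with
# learn_goal=True the flat parameter vector is interpreted as
#   [w(basis_1, dof_1), ..., w(basis_1, dof_D), ..., w(basis_N, dof_D), g_1, ..., g_D],
# i.e. num_basis * num_dof DMP weights (reshaped row-major into
# [num_basis, num_dof]) followed by num_dof goal positions. With
# learn_goal=False the vector holds only the weights and the goal stays fixed
# at final_pos.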


class DmpEnvWrapperPos(DmpEnvWrapperBase):
    """
    Wrapper for gym environments which creates a trajectory in joint angle space
    """

    def rollout(self, action, render=False):
        goal_pos, weight_matrix = self.goal_and_weights(action)
        if hasattr(self.env, "weight_matrix_scale"):
            weight_matrix = weight_matrix * self.env.weight_matrix_scale
        self.dmp.set_weights(weight_matrix, goal_pos)
        trajectory, _ = self.dmp.reference_trajectory(self.t)

        if self.post_traj_steps > 0:
            # hold the final position for the post-trajectory phase
            trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])

        self._trajectory = trajectory

        rews = []

        self.env.reset()
        for t, traj in enumerate(trajectory):
            obs, rew, done, info = self.env.step(traj)
            rews.append(rew)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)
        return obs, reward, done, info


class DmpEnvWrapperVel(DmpEnvWrapperBase):
    """
    Wrapper for gym environments which creates a trajectory in joint velocity space
    """

    def rollout(self, action, render=False):
        goal_pos, weight_matrix = self.goal_and_weights(action)
        if hasattr(self.env, "weight_matrix_scale"):
            weight_matrix = weight_matrix * self.env.weight_matrix_scale
        self.dmp.set_weights(weight_matrix, goal_pos)
        _, velocities = self.dmp.reference_trajectory(self.t)

        rews = []
        infos = []

        self.env.reset()
        for t, vel in enumerate(velocities):
            obs, rew, done, info = self.env.step(vel)
            rews.append(rew)
            infos.append(info)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)
        return obs, reward, done, info


class DmpEnvWrapperPD(DmpEnvWrapperBase):
    """
    Wrapper for gym environments which creates a position and velocity reference
    trajectory and tracks it with the given policy (e.g. a PD controller)
    """

    def rollout(self, action, render=False):
        goal_pos, weight_matrix = self.goal_and_weights(action)
        if hasattr(self.env, "weight_matrix_scale"):
            weight_matrix = weight_matrix * self.env.weight_matrix_scale
        self.dmp.set_weights(weight_matrix, goal_pos)
        trajectory, velocity = self.dmp.reference_trajectory(self.t)

        if self.post_traj_steps > 0:
            # hold the final position and command zero velocity once the DMP has finished
            trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
            velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.num_dof))])

        self._trajectory = trajectory
        self._velocity = velocity

        rews = []
        infos = []

        self.env.reset()
        for t, pos_vel in enumerate(zip(trajectory, velocity)):
            # the policy converts the desired position/velocity into an environment action
            ac = self.policy.get_action(self.env, pos_vel[0], pos_vel[1])
            obs, rew, done, info = self.env.step(ac)
            rews.append(rew)
            infos.append(info)
            if render:
                self.env.render(mode="human")
            if done:
                break

        reward = np.sum(rews)
        return obs, reward, done, info
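

# Hedged usage sketch: how one of these wrappers might be driven, assuming a
# hypothetical joint-space environment id "SomeJointEnv-v0" whose actions are
# num_dof joint positions. The environment id and all parameter values below
# are illustrative assumptions, not part of the wrapped environments.
if __name__ == "__main__":
    num_dof = 5
    num_basis = 5

    env = gym.make("SomeJointEnv-v0")  # hypothetical environment id
    wrapped = DmpEnvWrapperPos(env,
                               num_dof=num_dof,
                               num_basis=num_basis,
                               start_pos=np.zeros(num_dof),
                               duration=2,
                               dt=0.02,
                               learn_goal=True)

    # one parameter vector per rollout: num_basis * num_dof DMP weights,
    # plus num_dof goal positions because learn_goal=True
    params = np.random.randn(2, num_basis * num_dof + num_dof)
    rewards, infos = wrapped(params)
    print(rewards)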