2021-03-26 14:05:16 +01:00
|
|
|
import gym
|
2021-05-12 09:52:25 +02:00
|
|
|
import numpy as np
|
|
|
|
from mp_lib import dmps
|
|
|
|
from mp_lib.basis import DMPBasisGenerator
|
|
|
|
from mp_lib.phase import ExpDecayPhaseGenerator
|
2021-03-26 14:05:16 +01:00
|
|
|
|
2021-05-11 06:19:30 +02:00
|
|
|
from alr_envs.utils.mps.alr_env import AlrEnv
|
2021-05-12 09:52:25 +02:00
|
|
|
from alr_envs.utils.mps.mp_wrapper import MPWrapper
|
2021-03-26 14:05:16 +01:00
|
|
|
|
|
|
|
|
|
|
|
class DmpWrapper(MPWrapper):
    """Episodic wrapper that turns a flat parameter vector into a DMP reference trajectory.

    The action given to this wrapper is the flattened DMP weight matrix, optionally followed
    by the DMP goal position (when ``learn_goal`` is True). ``mp_rollout`` decodes it and
    evaluates the DMP on the time grid ``self.t``.
    """

    def __init__(self, env: AlrEnv, num_dof: int, num_basis: int,
                 duration: int = 1, alpha_phase: float = 2., dt: float = None,
                 learn_goal: bool = False, post_traj_time: float = 0.,
                 weights_scale: float = 1., goal_scale: float = 1., bandwidth_factor: float = 3.,
                 policy_type: str = None, render_mode: str = None):
        """
        This Wrapper generates a trajectory based on a DMP and will only return episodic performances.

        Args:
            env: Environment to wrap. ``mp_rollout`` reads ``env.start_pos``; when ``learn_goal``
                is False, ``goal_and_weights`` additionally requires a non-None ``env.goal_pos``.
            num_dof: Number of degrees of freedom of the DMP.
            num_basis: Number of basis functions, forwarded to ``initialize_mp``.
            duration: Duration of the trajectory; the time grid ``self.t`` spans ``[0, duration]``.
            alpha_phase: Decay parameter of the ``ExpDecayPhaseGenerator``.
            dt: Step size used to size the time grid ``self.t``. If None, ``env.dt`` is used.
            learn_goal: If True, the last ``num_dof`` action entries are the DMP goal position;
                otherwise ``env.goal_pos`` is used as the goal.
            post_traj_time: Forwarded to ``MPWrapper``.
            weights_scale: Forwarded to ``MPWrapper``; ``goal_and_weights`` multiplies the decoded
                weights by ``self.weights_scale``.
            goal_scale: Factor the goal position is multiplied by in ``goal_and_weights``.
            bandwidth_factor: Basis bandwidth factor passed to ``DMPBasisGenerator``.
            policy_type: Forwarded to ``MPWrapper``.
            render_mode: Forwarded to ``MPWrapper``.

        Raises:
            ValueError: If ``dt`` is None and ``env`` does not provide a ``dt`` attribute.
        """
        self.learn_goal = learn_goal
        self.goal_scale = goal_scale

        # The default dt=None used to crash with a TypeError in `duration / dt`; fall back to the
        # environment's step size (presumably the same dt MPWrapper uses -- TODO confirm).
        if dt is None:
            dt = getattr(env, "dt", None)
        if dt is None:
            raise ValueError("dt was not given and the wrapped environment does not define `dt`.")

        # Round off floating-point noise (e.g. 0.7 / 0.1 == 6.999999999999999) before truncating,
        # so exact multiples of dt are not silently shortened by one step.
        num_steps = int(np.round(duration / dt, 8))
        # NOTE(review): linspace with num_steps points has spacing duration / (num_steps - 1),
        # which is not exactly dt; kept as-is because downstream code may rely on this length.
        self.t = np.linspace(0, duration, num_steps)

        super().__init__(env=env, num_dof=num_dof, duration=duration, post_traj_time=post_traj_time,
                         policy_type=policy_type, weights_scale=weights_scale, render_mode=render_mode,
                         num_basis=num_basis, alpha_phase=alpha_phase, bandwidth_factor=bandwidth_factor)

        # One unbounded action entry per DMP weight, plus num_dof goal entries if the goal is learned.
        num_params = int(np.prod(self.mp.weights.shape)) + (num_dof if learn_goal else 0)
        action_bounds = np.inf * np.ones(num_params)
        self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)

    def initialize_mp(self, num_dof: int, duration: int, num_basis: int, alpha_phase: float = 2.,
                      bandwidth_factor: float = 3, **kwargs):
        """Build the DMP used by this wrapper.

        Args:
            num_dof: Number of degrees of freedom of the DMP.
            duration: Duration passed to the phase and basis generators.
            num_basis: Number of basis functions.
            alpha_phase: Decay parameter of the ``ExpDecayPhaseGenerator``.
            bandwidth_factor: Basis bandwidth factor for ``DMPBasisGenerator``.
            **kwargs: Ignored; accepted so that extra forwarded options do not raise.

        Returns:
            A ``dmps.DMP`` built from an exponential-decay phase generator and a DMP basis
            generator, using ``self.dt`` as its step size.
        """
        phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
        basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=num_basis,
                                            basis_bandwidth_factor=bandwidth_factor)

        dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
                       dt=self.dt)

        return dmp

    def goal_and_weights(self, params):
        """Split a flat parameter vector into a scaled goal position and weight matrix.

        Args:
            params: Array-like whose last axis has ``self.action_space.shape[0]`` entries: the
                flattened DMP weights, followed by ``self.mp.num_dimensions`` goal entries when
                ``self.learn_goal`` is True.

        Returns:
            Tuple ``(goal_pos * self.goal_scale, weight_matrix * self.weights_scale)``, where
            ``weight_matrix`` is reshaped to ``self.mp.weights.shape``. If the goal is not learned,
            ``goal_pos`` is ``self.env.goal_pos``.
        """
        # Convert first so plain lists/tuples work too (previously `.shape` raised AttributeError).
        params = np.asarray(params)
        assert params.shape[-1] == self.action_space.shape[0]
        params = np.atleast_2d(params)

        if self.learn_goal:
            num_goal = self.mp.num_dimensions
            goal_pos = params[0, -num_goal:]  # goal taken from the first row: [num_dimensions]
            params = params[:, :-num_goal]  # remaining flattened weights: [rows, num_weights]
        else:
            goal_pos = self.env.goal_pos
            assert goal_pos is not None

        weight_matrix = np.reshape(params, self.mp.weights.shape)
        return goal_pos * self.goal_scale, weight_matrix * self.weights_scale

    def mp_rollout(self, action):
        """Generate the DMP reference trajectory for ``action``.

        Sets the DMP start position from ``self.env.start_pos``, loads the goal and weights
        decoded by ``goal_and_weights`` and evaluates the DMP on the time grid ``self.t``.

        Returns:
            The result of ``self.mp.reference_trajectory(self.t)``.
        """
        self.mp.dmp_start_pos = self.env.start_pos
        goal_pos, weight_matrix = self.goal_and_weights(action)
        self.mp.set_weights(weight_matrix, goal_pos)
        return self.mp.reference_trajectory(self.t)
|