130 lines
4.4 KiB
Python
130 lines
4.4 KiB
Python
from typing import Tuple, Union
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
class BaseController:
|
|
def __init__(self, env, **kwargs):
|
|
self.env = env
|
|
|
|
def get_action(self, des_pos, des_vel):
|
|
raise NotImplementedError
|
|
|
|
|
|
class PosController(BaseController):
|
|
"""
|
|
A Position Controller. The controller calculates a response only based on the desired position.
|
|
"""
|
|
def get_action(self, des_pos, des_vel):
|
|
return des_pos
|
|
|
|
|
|
class VelController(BaseController):
|
|
"""
|
|
A Velocity Controller. The controller calculates a response only based on the desired velocity.
|
|
"""
|
|
def get_action(self, des_pos, des_vel):
|
|
return des_vel
|
|
|
|
|
|
class PDController(BaseController):
|
|
"""
|
|
A PD-Controller. Using position and velocity information from a provided environment,
|
|
the controller calculates a response based on the desired position and velocity
|
|
|
|
:param env: A position environment
|
|
:param p_gains: Factors for the proportional gains
|
|
:param d_gains: Factors for the differential gains
|
|
"""
|
|
|
|
def __init__(self,
|
|
env,
|
|
p_gains: Union[float, Tuple],
|
|
d_gains: Union[float, Tuple]):
|
|
self.p_gains = p_gains
|
|
self.d_gains = d_gains
|
|
super(PDController, self).__init__(env)
|
|
|
|
def get_action(self, des_pos, des_vel):
|
|
cur_pos = self.env.current_pos
|
|
cur_vel = self.env.current_vel
|
|
assert des_pos.shape == cur_pos.shape, \
|
|
f"Mismatch in dimension between desired position {des_pos.shape} and current position {cur_pos.shape}"
|
|
assert des_vel.shape == cur_vel.shape, \
|
|
f"Mismatch in dimension between desired velocity {des_vel.shape} and current velocity {cur_vel.shape}"
|
|
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
|
return trq
|
|
|
|
|
|
class MetaWorldController(BaseController):
|
|
"""
|
|
A Metaworld Controller. Using position and velocity information from a provided environment,
|
|
the controller calculates a response based on the desired position and velocity.
|
|
Unlike the other Controllers, this is a special controller for MetaWorld environments.
|
|
They use a position delta for the xyz coordinates and a raw position for the gripper opening.
|
|
|
|
:param env: A position environment
|
|
"""
|
|
|
|
def __init__(self,
|
|
env
|
|
):
|
|
super(MetaWorldController, self).__init__(env)
|
|
|
|
def get_action(self, des_pos, des_vel):
|
|
gripper_pos = des_pos[-1]
|
|
|
|
cur_pos = self.env.current_pos[:-1]
|
|
xyz_pos = des_pos[:-1]
|
|
|
|
assert xyz_pos.shape == cur_pos.shape, \
|
|
f"Mismatch in dimension between desired position {xyz_pos.shape} and current position {cur_pos.shape}"
|
|
trq = np.hstack([(xyz_pos - cur_pos), gripper_pos])
|
|
return trq
|
|
|
|
|
|
#TODO: Do we need this class?
|
|
class PDControllerExtend(BaseController):
|
|
"""
|
|
A PD-Controller. Using position and velocity information from a provided positional environment,
|
|
the controller calculates a response based on the desired position and velocity
|
|
|
|
:param env: A position environment
|
|
:param p_gains: Factors for the proportional gains
|
|
:param d_gains: Factors for the differential gains
|
|
"""
|
|
|
|
def __init__(self,
|
|
env,
|
|
p_gains: Union[float, Tuple],
|
|
d_gains: Union[float, Tuple]):
|
|
|
|
self.p_gains = p_gains
|
|
self.d_gains = d_gains
|
|
super(PDControllerExtend, self).__init__(env)
|
|
|
|
def get_action(self, des_pos, des_vel):
|
|
cur_pos = self.env.current_pos
|
|
cur_vel = self.env.current_vel
|
|
if len(des_pos) != len(cur_pos):
|
|
des_pos = self.env.extend_des_pos(des_pos)
|
|
if len(des_vel) != len(cur_vel):
|
|
des_vel = self.env.extend_des_vel(des_vel)
|
|
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
|
return trq
|
|
|
|
|
|
def get_policy_class(policy_type, env, **kwargs):
|
|
if policy_type == "motor":
|
|
return PDController(env, **kwargs)
|
|
elif policy_type == "velocity":
|
|
return VelController(env)
|
|
elif policy_type == "position":
|
|
return PosController(env)
|
|
elif policy_type == "metaworld":
|
|
return MetaWorldController(env)
|
|
else:
|
|
raise ValueError(f"Invalid controller type {policy_type} provided. Only 'motor', 'velocity', 'position' "
|
|
f"and 'metaworld are currently supported controllers.")
|