47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
from alr_envs.mujoco.alr_mujoco_env import AlrMujocoEnv
|
|
|
|
|
|
class BaseController:
|
|
def __init__(self, env: AlrMujocoEnv):
|
|
self.env = env
|
|
|
|
def get_action(self, des_pos, des_vel):
|
|
raise NotImplementedError
|
|
|
|
|
|
class PosController(BaseController):
|
|
def get_action(self, des_pos, des_vel):
|
|
return des_pos
|
|
|
|
|
|
class VelController(BaseController):
|
|
def get_action(self, des_pos, des_vel):
|
|
return des_vel
|
|
|
|
|
|
class PDController(BaseController):
|
|
def __init__(self, env):
|
|
self.p_gains = env.p_gains
|
|
self.d_gains = env.d_gains
|
|
super(PDController, self).__init__(env)
|
|
|
|
def get_action(self, des_pos, des_vel):
|
|
# TODO: make standardized ALRenv such that all of them have current_pos/vel attributes
|
|
cur_pos = self.env.current_pos
|
|
cur_vel = self.env.current_vel
|
|
if len(des_pos) != len(cur_pos):
|
|
des_pos = self.env.extend_des_pos(des_pos)
|
|
if len(des_vel) != len(cur_vel):
|
|
des_vel = self.env.extend_des_vel(des_vel)
|
|
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
|
return trq
|
|
|
|
|
|
def get_policy_class(policy_type):
|
|
if policy_type == "motor":
|
|
return PDController
|
|
elif policy_type == "velocity":
|
|
return VelController
|
|
elif policy_type == "position":
|
|
return PosController
|