2021-02-11 10:49:57 +01:00
|
|
|
from alr_envs.mujoco.alr_mujoco_env import AlrMujocoEnv
|
2021-02-05 17:10:03 +01:00
|
|
|
|
2021-02-11 10:49:57 +01:00
|
|
|
|
|
|
|
class BaseController:
    """Common interface for controllers that turn a desired state into an action.

    The controller keeps a reference to the environment it acts on; concrete
    subclasses override :meth:`get_action`.
    """

    def __init__(self, env: AlrMujocoEnv):
        """Store the environment this controller is attached to.

        Args:
            env: Environment instance, kept as ``self.env``.
        """
        self.env = env

    def get_action(self, des_pos, des_vel):
        """Compute the action for the given desired position and velocity.

        Args:
            des_pos: Desired position.
            des_vel: Desired velocity.

        Raises:
            NotImplementedError: Always; subclasses must provide the logic.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class PosController(BaseController):
    """Controller whose action is the desired position itself."""

    def get_action(self, des_pos, des_vel):
        """Return ``des_pos`` unchanged.

        Args:
            des_pos: Desired position; returned as the action.
            des_vel: Desired velocity; not used by this controller.

        Returns:
            The ``des_pos`` argument, unmodified.
        """
        return des_pos
|
|
|
|
|
|
|
|
|
|
|
|
class VelController(BaseController):
    """Controller whose action is the desired velocity itself."""

    def get_action(self, des_pos, des_vel):
        """Return ``des_vel`` unchanged.

        Args:
            des_pos: Desired position; not used by this controller.
            des_vel: Desired velocity; returned as the action.

        Returns:
            The ``des_vel`` argument, unmodified.
        """
        return des_vel
|
|
|
|
|
|
|
|
|
|
|
|
class PDController(BaseController):
    """Controller combining a position-error term and a velocity-error term.

    The gains are read from the environment once, at construction time; the
    current position and velocity are read from the environment on every call
    to :meth:`get_action`.
    """

    def __init__(self, env):
        """Copy the gains from ``env`` and attach the controller to it.

        Args:
            env: Environment exposing ``p_gains`` and ``d_gains`` attributes.
        """
        self.p_gains = env.p_gains
        self.d_gains = env.d_gains
        super().__init__(env)

    def get_action(self, des_pos, des_vel):
        """Return ``p_gains * (des_pos - pos) + d_gains * (des_vel - vel)``.

        Args:
            des_pos: Desired position. If its length differs from the
                environment's current position, it is first passed through
                ``env.extend_des_pos``.
            des_vel: Desired velocity. If its length differs from the
                environment's current velocity, it is first passed through
                ``env.extend_des_vel``.

        Returns:
            The weighted sum of the position and velocity errors.
            NOTE(review): presumably a joint torque command, judging by the
            original local name ``trq`` - confirm against the environment.
        """
        # TODO: make standardized ALRenv such that all of them have current_pos/vel attributes
        pos_now = self.env.current_pos
        vel_now = self.env.current_vel

        # A length mismatch is delegated to the environment, which returns a
        # replacement target (assumed to match the current state's length).
        if len(des_pos) != len(pos_now):
            des_pos = self.env.extend_des_pos(des_pos)
        if len(des_vel) != len(vel_now):
            des_vel = self.env.extend_des_vel(des_vel)

        pos_error = des_pos - pos_now
        vel_error = des_vel - vel_now
        return self.p_gains * pos_error + self.d_gains * vel_error
|
2021-02-15 16:31:34 +01:00
|
|
|
|
|
|
|
|
|
|
|
def get_policy_class(policy_type):
    """Return the controller class associated with ``policy_type``.

    Args:
        policy_type: Name of the controller; one of ``"motor"``,
            ``"velocity"`` or ``"position"``.

    Returns:
        ``PDController`` for ``"motor"``, ``VelController`` for
        ``"velocity"`` and ``PosController`` for ``"position"``. The class
        itself is returned, not an instance.

    Raises:
        ValueError: If ``policy_type`` is not one of the supported names.
    """
    if policy_type == "motor":
        return PDController
    if policy_type == "velocity":
        return VelController
    if policy_type == "position":
        return PosController

    # Previously an unknown name fell through and returned None, which only
    # failed later as "'NoneType' object is not callable". Fail here instead,
    # naming the offending value.
    raise ValueError(
        "Unknown policy type {!r}; expected 'motor', 'velocity' or "
        "'position'.".format(policy_type)
    )
|