# fancy_gym/alr_envs/utils/policies.py
# (49 lines, 1.4 KiB, Python)

from gym import Env
from alr_envs.mujoco.alr_mujoco_env import AlrMujocoEnv
# history-view timestamp: 2021-02-05 17:10:03 +01:00
class BaseController:
    """Common interface for controllers that turn a desired state into an action.

    The controller keeps a reference to the environment it was created for;
    concrete subclasses override :meth:`get_action`.
    """

    def __init__(self, env: Env):
        # Stored so that subclasses can read state from the environment.
        self.env = env

    def get_action(self, des_pos, des_vel):
        """Return the action for the desired position and velocity.

        This base implementation is a placeholder only.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError()
class PosController(BaseController):
    """Controller whose action is the desired position, passed through as-is."""

    def get_action(self, des_pos, des_vel):
        """Return ``des_pos`` unchanged; ``des_vel`` is not used."""
        action = des_pos
        return action
class VelController(BaseController):
    """Controller whose action is the desired velocity, passed through as-is."""

    def get_action(self, des_pos, des_vel):
        """Return ``des_vel`` unchanged; ``des_pos`` is not used."""
        action = des_vel
        return action
class PDController(BaseController):
    """Proportional-derivative controller driven by the environment's gains.

    The gains are read once from ``env.p_gains`` and ``env.d_gains`` when the
    controller is constructed.
    """

    def __init__(self, env: AlrMujocoEnv):
        # Initialise the base class first so ``self.env`` is set before any
        # subclass-specific attributes.
        super().__init__(env)
        self.p_gains = env.p_gains
        self.d_gains = env.d_gains

    def get_action(self, des_pos, des_vel):
        """Compute the PD command for the desired position and velocity.

        Args:
            des_pos: Desired positions. If its length differs from
                ``env.current_pos`` it is first passed through
                ``env.extend_des_pos``.
            des_vel: Desired velocities. If its length differs from
                ``env.current_vel`` it is first passed through
                ``env.extend_des_vel``.

        Returns:
            ``p_gains * (des_pos - cur_pos) + d_gains * (des_vel - cur_vel)``.
        """
        # TODO: make standardized ALRenv such that all of them have current_pos/vel attributes
        cur_pos = self.env.current_pos
        cur_vel = self.env.current_vel

        # Length mismatch: let the environment extend the desired values to
        # match its own state before the element-wise subtraction below.
        if len(des_pos) != len(cur_pos):
            des_pos = self.env.extend_des_pos(des_pos)

        if len(des_vel) != len(cur_vel):
            des_vel = self.env.extend_des_vel(des_vel)

        trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
        return trq
# history-view timestamp: 2021-02-15 16:31:34 +01:00
def get_policy_class(policy_type):
    """Return the controller class that corresponds to ``policy_type``.

    Args:
        policy_type: One of ``"motor"``, ``"velocity"`` or ``"position"``.

    Returns:
        ``PDController`` for ``"motor"``, ``VelController`` for
        ``"velocity"`` and ``PosController`` for ``"position"``.

    Raises:
        ValueError: If ``policy_type`` is not one of the supported names.
    """
    if policy_type == "motor":
        return PDController
    if policy_type == "velocity":
        return VelController
    if policy_type == "position":
        return PosController
    # Fail loudly here instead of returning None, which would only surface
    # later as "'NoneType' object is not callable" at the call site.
    raise ValueError(
        f"Invalid policy type {policy_type!r}; "
        "expected one of 'motor', 'velocity', 'position'."
    )