from gym import Env

from alr_envs.mujoco.alr_mujoco_env import AlrMujocoEnv


class BaseController:
    def __init__(self, env: Env):
        self.env = env

    def get_action(self, des_pos, des_vel):
        raise NotImplementedError


class PosController(BaseController):
    def get_action(self, des_pos, des_vel):
        return des_pos


class VelController(BaseController):
    def get_action(self, des_pos, des_vel):
        return des_vel


class PDController(BaseController):
    def __init__(self, env: AlrMujocoEnv):
        self.p_gains = env.p_gains
        self.d_gains = env.d_gains
        super(PDController, self).__init__(env)

    def get_action(self, des_pos, des_vel):
        # TODO: make standardized ALRenv such that all of them have current_pos/vel attributes
        cur_pos = self.env.current_pos
        cur_vel = self.env.current_vel
        if len(des_pos) != len(cur_pos):
            des_pos = self.env.extend_des_pos(des_pos)
        if len(des_vel) != len(cur_vel):
            des_vel = self.env.extend_des_vel(des_vel)
        trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
        return trq


def get_policy_class(policy_type):
    if policy_type == "motor":
        return PDController
    elif policy_type == "velocity":
        return VelController
    elif policy_type == "position":
        return PosController