# fancy_gym/alr_envs/utils/policies.py

from typing import Tuple, Union
from gym import Env
from alr_envs.utils.positional_env import PositionalEnv


class BaseController:
    def __init__(self, env: Env, **kwargs):
        self.env = env

    def get_action(self, des_pos, des_vel):
        raise NotImplementedError


class PosController(BaseController):
    def get_action(self, des_pos, des_vel):
        return des_pos


class VelController(BaseController):
    def get_action(self, des_pos, des_vel):
        return des_vel


class PDController(BaseController):
    """
    A PD controller. Using the position and velocity information of a provided positional
    environment, the controller calculates a response based on the desired position and velocity.

    :param env: A positional environment
    :param p_gains: Factors for the proportional gains
    :param d_gains: Factors for the differential gains
    """

    def __init__(self,
                 env: PositionalEnv,
                 p_gains: Union[float, Tuple],
                 d_gains: Union[float, Tuple]):
        self.p_gains = p_gains
        self.d_gains = d_gains
        super(PDController, self).__init__(env)

    def get_action(self, des_pos, des_vel):
        cur_pos = self.env.current_pos
        cur_vel = self.env.current_vel
        assert des_pos.shape == cur_pos.shape, \
            "Mismatch in dimension between desired position {} and current position {}".format(des_pos.shape,
                                                                                                cur_pos.shape)
        assert des_vel.shape == cur_vel.shape, \
            "Mismatch in dimension between desired velocity {} and current velocity {}".format(des_vel.shape,
                                                                                                cur_vel.shape)
        trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
        return trq


def get_policy_class(policy_type, env, mp_kwargs, **kwargs):
    """Return the controller instance matching the given policy type string."""
    if policy_type == "motor":
        return PDController(env, p_gains=mp_kwargs['p_gains'], d_gains=mp_kwargs['d_gains'])
    elif policy_type == "velocity":
        return VelController(env)
    elif policy_type == "position":
        return PosController(env)
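

# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal, hypothetical example of driving PDController directly: the mock
# environment below only mimics the `current_pos`/`current_vel` attributes that
# a PositionalEnv is expected to expose, and the gains, shapes, and the
# `_MockPositionalEnv` name are arbitrary assumptions for demonstration.
if __name__ == "__main__":
    import numpy as np

    class _MockPositionalEnv:
        """Stand-in exposing the positional interface used by PDController."""
        current_pos = np.zeros(3)
        current_vel = np.zeros(3)

    controller = PDController(_MockPositionalEnv(), p_gains=1.0, d_gains=0.1)
    # PD law: p_gains * position error + d_gains * velocity error.
    trq = controller.get_action(des_pos=np.ones(3), des_vel=np.zeros(3))
    print(trq)  # -> [1. 1. 1.] with the gains above

    # The same controller can also be obtained through the factory helper.
    controller = get_policy_class("motor", _MockPositionalEnv(),
                                  mp_kwargs={"p_gains": 1.0, "d_gains": 0.1})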