Add interface for envs controlable by a PD Controller and add more infos to mp_wrapper info return value
This commit is contained in:
parent
746d408a76
commit
4279414656
@ -1,3 +1,4 @@
|
||||
import numpy as np
|
||||
from gym.envs.registration import register
|
||||
|
||||
from alr_envs.stochastic_search.functions.f_rosenbrock import Rosenbrock
|
||||
@ -321,7 +322,9 @@ register(
|
||||
"bandwidth_factor": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 100,
|
||||
"return_to_start": True
|
||||
"return_to_start": True,
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
)
|
||||
|
||||
@ -339,7 +342,9 @@ register(
|
||||
"bandwidth_factor": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 100,
|
||||
"return_to_start": True
|
||||
"return_to_start": True,
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
)
|
||||
|
||||
@ -357,7 +362,9 @@ register(
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True
|
||||
"zero_goal": True,
|
||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
)
|
||||
|
||||
@ -374,7 +381,9 @@ register(
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 0.2,
|
||||
"zero_start": True,
|
||||
"zero_goal": True
|
||||
"zero_goal": True,
|
||||
"p_gains": np.array([4./3., 2.4, 2.5, 5./3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
)
|
||||
|
||||
@ -392,7 +401,9 @@ register(
|
||||
"bandwidth_factor": 2.5,
|
||||
"policy_type": "motor",
|
||||
"weights_scale": 50,
|
||||
"goal_scale": 0.1
|
||||
"goal_scale": 0.1,
|
||||
"p_gains": np.array([4. / 3., 2.4, 2.5, 5. / 3., 2., 2., 1.25]),
|
||||
"d_gains": np.array([0.0466, 0.12, 0.125, 0.04166, 0.06, 0.06, 0.025])
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -7,7 +7,7 @@ from gym.utils import seeding
|
||||
from matplotlib import patches
|
||||
|
||||
from alr_envs.classic_control.utils import check_self_collision
|
||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
|
||||
|
||||
class HoleReacherEnv(AlrEnv):
|
||||
|
@ -5,7 +5,7 @@ import numpy as np
|
||||
from gym import spaces
|
||||
from gym.utils import seeding
|
||||
|
||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
|
||||
|
||||
class SimpleReacherEnv(AlrEnv):
|
||||
|
@ -6,7 +6,7 @@ import numpy as np
|
||||
from gym.utils import seeding
|
||||
|
||||
from alr_envs.classic_control.utils import check_self_collision
|
||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
|
||||
|
||||
class ViaPointReacher(AlrEnv):
|
||||
|
@ -7,7 +7,9 @@ from gym import error, spaces
|
||||
from gym.utils import seeding
|
||||
import numpy as np
|
||||
from os import path
|
||||
import gym
|
||||
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
from alr_envs.utils.positional_env import PositionalEnv
|
||||
|
||||
try:
|
||||
import mujoco_py
|
||||
@ -33,7 +35,7 @@ def convert_observation_to_space(observation):
|
||||
return space
|
||||
|
||||
|
||||
class AlrMujocoEnv(gym.Env):
|
||||
class AlrMujocoEnv(PositionalEnv, AlrEnv):
|
||||
"""
|
||||
Superclass for all MuJoCo environments.
|
||||
"""
|
||||
@ -44,7 +46,7 @@ class AlrMujocoEnv(gym.Env):
|
||||
Args:
|
||||
model_path: path to xml file
|
||||
n_substeps: how many steps mujoco does per call to env.step
|
||||
use_servo: use actuator defined in xml, use False for direct torque control
|
||||
apply_gravity_comp: Whether gravity compensation should be active
|
||||
"""
|
||||
if model_path.startswith("/"):
|
||||
fullpath = model_path
|
||||
@ -73,10 +75,6 @@ class AlrMujocoEnv(gym.Env):
|
||||
|
||||
self._set_action_space()
|
||||
|
||||
# action = self.action_space.sample()
|
||||
# observation, _reward, done, _info = self.step(action)
|
||||
# assert not done
|
||||
|
||||
observation = self._get_obs() # TODO: is calling get_obs enough? should we call reset, or even step?
|
||||
|
||||
self._set_observation_space(observation)
|
||||
@ -204,7 +202,7 @@ class AlrMujocoEnv(gym.Env):
|
||||
|
||||
try:
|
||||
self.sim.step()
|
||||
except mujoco_py.builder.MujocoException as e:
|
||||
except mujoco_py.builder.MujocoException:
|
||||
error_in_sim = True
|
||||
|
||||
return error_in_sim
|
||||
|
@ -16,8 +16,6 @@ class ALRBallInACupEnv(alr_mujoco_env.AlrMujocoEnv, utils.EzPickle):
|
||||
self._q_vel = []
|
||||
# self.weight_matrix_scale = 50
|
||||
self.max_ctrl = np.array([150., 125., 40., 60., 5., 5., 2.])
|
||||
self.p_gains = 1 / self.max_ctrl * np.array([200, 300, 100, 100, 10, 10, 2.5])
|
||||
self.d_gains = 1 / self.max_ctrl * np.array([7, 15, 5, 2.5, 0.3, 0.3, 0.05])
|
||||
|
||||
self.j_min = np.array([-2.6, -1.985, -2.8, -0.9, -4.55, -1.5707, -2.7])
|
||||
self.j_max = np.array([2.6, 1.985, 2.8, 3.14159, 1.25, 1.5707, 2.7])
|
||||
|
@ -1,14 +1,13 @@
|
||||
from abc import abstractmethod
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AlrEnv(gym.Env):
|
||||
class AlrEnv(gym.Env, ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def active_obs(self):
|
||||
"""Returns boolean mask for each observation entry
|
||||
whether the observation is returned for the contextual case or not.
|
||||
@ -31,3 +30,11 @@ class AlrEnv(gym.Env):
|
||||
By default this returns the starting position.
|
||||
"""
|
||||
return self.start_pos
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dt(self) -> Union[float, int]:
|
||||
"""
|
||||
Returns the time between two simulated steps of the environment
|
||||
"""
|
||||
raise NotImplementedError()
|
@ -2,29 +2,36 @@ import gym
|
||||
import numpy as np
|
||||
from mp_lib import det_promp
|
||||
|
||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
from alr_envs.utils.mps.mp_wrapper import MPWrapper
|
||||
|
||||
|
||||
class DetPMPWrapper(MPWrapper):
|
||||
def __init__(self, env: AlrEnv, num_dof: int, num_basis: int, width: float, duration: float = 1, dt: float = 0.01,
|
||||
post_traj_time: float = 0., policy_type: str = None, weights_scale: float = 1.,
|
||||
zero_start: bool = False, zero_goal: bool = False, **mp_kwargs):
|
||||
def __init__(self, env: AlrEnv, num_dof, num_basis, width, start_pos=None, duration=1, post_traj_time=0.,
|
||||
policy_type=None, weights_scale=1, zero_start=False, zero_goal=False, learn_mp_length: bool =True,
|
||||
**mp_kwargs):
|
||||
self.duration = duration # seconds
|
||||
|
||||
dt = env.dt if hasattr(env, "dt") else dt
|
||||
assert dt is not None
|
||||
self.dt = dt
|
||||
super().__init__(env=env, num_dof=num_dof, duration=duration, post_traj_time=post_traj_time,
|
||||
policy_type=policy_type, weights_scale=weights_scale, num_basis=num_basis,
|
||||
width=width, zero_start=zero_start, zero_goal=zero_goal,
|
||||
**mp_kwargs)
|
||||
|
||||
super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, num_basis=num_basis,
|
||||
width=width, zero_start=zero_start, zero_goal=zero_goal, **mp_kwargs)
|
||||
self.learn_mp_length = learn_mp_length
|
||||
if self.learn_mp_length:
|
||||
parameter_space_shape = (1+num_basis*num_dof,)
|
||||
else:
|
||||
parameter_space_shape = (num_basis * num_dof,)
|
||||
self.min_param = -np.inf
|
||||
self.max_param = np.inf
|
||||
self.parameterization_space = gym.spaces.Box(low=self.min_param, high=self.max_param,
|
||||
shape=parameter_space_shape, dtype=np.float32)
|
||||
|
||||
action_bounds = np.inf * np.ones((self.mp.n_basis * self.mp.n_dof))
|
||||
self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
|
||||
self.start_pos = start_pos
|
||||
|
||||
def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, width: float = None,
|
||||
off: float = 0.01, zero_start: bool = False, zero_goal: bool = False):
|
||||
pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=off,
|
||||
def initialize_mp(self, num_dof: int, duration: int, num_basis: int = 5, width: float = None,
|
||||
zero_start: bool = False, zero_goal: bool = False, **kwargs):
|
||||
pmp = det_promp.DeterministicProMP(n_basis=num_basis, n_dof=num_dof, width=width, off=0.01,
|
||||
zero_start=zero_start, zero_goal=zero_goal)
|
||||
|
||||
weights = np.zeros(shape=(num_basis, num_dof))
|
||||
@ -33,10 +40,15 @@ class DetPMPWrapper(MPWrapper):
|
||||
return pmp
|
||||
|
||||
def mp_rollout(self, action):
|
||||
params = np.reshape(action, newshape=(self.mp.n_basis, self.mp.n_dof)) * self.weights_scale
|
||||
self.mp.set_weights(self.duration, params)
|
||||
_, des_pos, des_vel, _ = self.mp.compute_trajectory(1 / self.dt, 1.)
|
||||
if self.learn_mp_length:
|
||||
duration = max(1, self.duration*abs(action[0]))
|
||||
params = np.reshape(action[1:], (self.mp.n_basis, -1)) * self.weights_scale # TODO: Fix Bug when zero_start is true
|
||||
else:
|
||||
duration = self.duration
|
||||
params = np.reshape(action, (self.mp.n_basis, -1)) * self.weights_scale # TODO: Fix Bug when zero_start is true
|
||||
self.mp.set_weights(1., params)
|
||||
_, des_pos, des_vel, _ = self.mp.compute_trajectory(frequency=max(1, duration))
|
||||
if self.mp.zero_start:
|
||||
des_pos += self.env.start_pos[None, :]
|
||||
des_pos += self.start_pos
|
||||
|
||||
return des_pos, des_vel
|
||||
|
@ -4,7 +4,7 @@ from mp_lib import dmps
|
||||
from mp_lib.basis import DMPBasisGenerator
|
||||
from mp_lib.phase import ExpDecayPhaseGenerator
|
||||
|
||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
from alr_envs.utils.mps.mp_wrapper import MPWrapper
|
||||
|
||||
|
||||
@ -32,26 +32,26 @@ class DmpWrapper(MPWrapper):
|
||||
goal_scale:
|
||||
"""
|
||||
self.learn_goal = learn_goal
|
||||
dt = env.dt if hasattr(env, "dt") else dt
|
||||
assert dt is not None
|
||||
|
||||
self.t = np.linspace(0, duration, int(duration / dt))
|
||||
self.goal_scale = goal_scale
|
||||
|
||||
super().__init__(env, num_dof, dt, duration, post_traj_time, policy_type, weights_scale, render_mode,
|
||||
super().__init__(env=env, num_dof=num_dof, duration=duration, post_traj_time=post_traj_time,
|
||||
policy_type=policy_type, weights_scale=weights_scale, render_mode=render_mode,
|
||||
num_basis=num_basis, alpha_phase=alpha_phase, bandwidth_factor=bandwidth_factor)
|
||||
|
||||
action_bounds = np.inf * np.ones((np.prod(self.mp.weights.shape) + (num_dof if learn_goal else 0)))
|
||||
self.action_space = gym.spaces.Box(low=-action_bounds, high=action_bounds, dtype=np.float32)
|
||||
|
||||
def initialize_mp(self, num_dof: int, duration: int, dt: float, num_basis: int = 5, alpha_phase: float = 2.,
|
||||
bandwidth_factor: int = 3):
|
||||
def initialize_mp(self, num_dof: int, duration: int, num_basis: int, alpha_phase: float = 2.,
|
||||
bandwidth_factor: int = 3, **kwargs):
|
||||
|
||||
phase_generator = ExpDecayPhaseGenerator(alpha_phase=alpha_phase, duration=duration)
|
||||
basis_generator = DMPBasisGenerator(phase_generator, duration=duration, num_basis=num_basis,
|
||||
basis_bandwidth_factor=bandwidth_factor)
|
||||
|
||||
dmp = dmps.DMP(num_dof=num_dof, basis_generator=basis_generator, phase_generator=phase_generator,
|
||||
duration=duration, dt=dt)
|
||||
dt=self.dt)
|
||||
|
||||
return dmp
|
||||
|
||||
|
@ -1,16 +1,36 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from alr_envs.utils.mps.mp_environments import AlrEnv
|
||||
from alr_envs.utils.policies import get_policy_class
|
||||
from alr_envs.utils.mps.alr_env import AlrEnv
|
||||
from alr_envs.utils.policies import get_policy_class, BaseController
|
||||
|
||||
|
||||
class MPWrapper(gym.Wrapper, ABC):
|
||||
"""
|
||||
Base class for movement primitive based gym.Wrapper implementations.
|
||||
|
||||
def __init__(self, env: AlrEnv, num_dof: int, dt: float, duration: float = 1, post_traj_time: float = 0.,
|
||||
policy_type: str = None, weights_scale: float = 1., render_mode: str = None, **mp_kwargs):
|
||||
:param env: The (wrapped) environment this wrapper is applied on
|
||||
:param num_dof: Dimension of the action space of the wrapped env
|
||||
:param duration: Number of timesteps in the trajectory of the movement primitive
|
||||
:param post_traj_time: Time for which the last position of the trajectory is fed to the environment to continue
|
||||
simulation
|
||||
:param policy_type: Type or object defining the policy that is used to generate action based on the trajectory
|
||||
:param weight_scale: Scaling parameter for the actions given to this wrapper
|
||||
:param render_mode: Equivalent to gym render mode
|
||||
"""
|
||||
def __init__(self,
|
||||
env: AlrEnv,
|
||||
num_dof: int,
|
||||
duration: int = 1,
|
||||
post_traj_time: float = 0.,
|
||||
policy_type: Union[str, BaseController] = None,
|
||||
weights_scale: float = 1.,
|
||||
render_mode: str = None,
|
||||
**mp_kwargs
|
||||
):
|
||||
super().__init__(env)
|
||||
|
||||
# adjust observation space to reduce version
|
||||
@ -19,14 +39,15 @@ class MPWrapper(gym.Wrapper, ABC):
|
||||
high=obs_sp.high[self.env.active_obs],
|
||||
dtype=obs_sp.dtype)
|
||||
|
||||
assert dt is not None # this should never happen as MPWrapper is a base class
|
||||
self.post_traj_steps = int(post_traj_time / dt)
|
||||
self.post_traj_steps = int(post_traj_time / env.dt)
|
||||
|
||||
self.mp = self.initialize_mp(num_dof, duration, dt, **mp_kwargs)
|
||||
self.mp = self.initialize_mp(num_dof=num_dof, duration=duration, **mp_kwargs)
|
||||
self.weights_scale = weights_scale
|
||||
|
||||
policy_class = get_policy_class(policy_type)
|
||||
self.policy = policy_class(env)
|
||||
if type(policy_type) is str:
|
||||
self.policy = get_policy_class(policy_type, env, mp_kwargs)
|
||||
else:
|
||||
self.policy = policy_type
|
||||
|
||||
# rendering
|
||||
self.render_mode = render_mode
|
||||
@ -62,29 +83,30 @@ class MPWrapper(gym.Wrapper, ABC):
|
||||
|
||||
if self.post_traj_steps > 0:
|
||||
trajectory = np.vstack([trajectory, np.tile(trajectory[-1, :], [self.post_traj_steps, 1])])
|
||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.n_dof))])
|
||||
|
||||
# self._trajectory = trajectory
|
||||
# self._velocity = velocity
|
||||
|
||||
rewards = 0
|
||||
info = {}
|
||||
# create random obs as the reset function is called externally
|
||||
obs = self.env.observation_space.sample()
|
||||
velocity = np.vstack([velocity, np.zeros(shape=(self.post_traj_steps, self.mp.num_dimensions))])
|
||||
|
||||
trajectory_length = len(trajectory)
|
||||
actions = np.zeros(shape=(trajectory_length, self.mp.num_dimensions))
|
||||
observations= np.zeros(shape=(trajectory_length,) + self.env.observation_space.shape)
|
||||
rewards = np.zeros(shape=(trajectory_length,))
|
||||
trajectory_return = 0
|
||||
infos = dict(step_infos =[])
|
||||
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
||||
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
||||
obs, rew, done, info = self.env.step(ac)
|
||||
rewards += rew
|
||||
# TODO return all dicts?
|
||||
# [infos[k].append(v) for k, v in info.items()]
|
||||
actions[t,:] = self.policy.get_action(pos_vel[0], pos_vel[1])
|
||||
observations[t,:], rewards[t], done, info = self.env.step(actions[t,:])
|
||||
trajectory_return += rewards[t]
|
||||
infos['step_infos'].append(info)
|
||||
if self.render_mode:
|
||||
self.env.render(mode=self.render_mode, **self.render_kwargs)
|
||||
if done:
|
||||
break
|
||||
|
||||
infos['step_actions'] = actions[:t+1]
|
||||
infos['step_observations'] = observations[:t+1]
|
||||
infos['step_rewards'] = rewards[:t+1]
|
||||
infos['trajectory_length'] = t+1
|
||||
done = True
|
||||
return obs[self.env.active_obs], rewards, done, info
|
||||
return observations[t][self.env.active_obs], trajectory_return, done, infos
|
||||
|
||||
def render(self, mode='human', **kwargs):
|
||||
"""Only set render options here, such that they can be used during the rollout.
|
||||
@ -102,7 +124,7 @@ class MPWrapper(gym.Wrapper, ABC):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def initialize_mp(self, num_dof: int, duration: float, dt: float, **kwargs):
|
||||
def initialize_mp(self, num_dof: int, duration: float, **kwargs):
|
||||
"""
|
||||
Create respective instance of MP
|
||||
Returns:
|
||||
|
@ -1,10 +1,12 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
from gym import Env
|
||||
|
||||
from alr_envs.mujoco.alr_mujoco_env import AlrMujocoEnv
|
||||
from alr_envs.utils.positional_env import PositionalEnv
|
||||
|
||||
|
||||
class BaseController:
|
||||
def __init__(self, env: Env):
|
||||
def __init__(self, env: Env, **kwargs):
|
||||
self.env = env
|
||||
|
||||
def get_action(self, des_pos, des_vel):
|
||||
@ -22,27 +24,38 @@ class VelController(BaseController):
|
||||
|
||||
|
||||
class PDController(BaseController):
|
||||
def __init__(self, env: AlrMujocoEnv):
|
||||
self.p_gains = env.p_gains
|
||||
self.d_gains = env.d_gains
|
||||
super(PDController, self).__init__(env)
|
||||
"""
|
||||
A PD-Controller. Using position and velocity information from a provided positional environment,
|
||||
the controller calculates a response based on the desired position and velocity
|
||||
|
||||
:param env: A position environment
|
||||
:param p_gains: Factors for the proportional gains
|
||||
:param d_gains: Factors for the differential gains
|
||||
"""
|
||||
def __init__(self,
|
||||
env: PositionalEnv,
|
||||
p_gains: Union[float, Tuple],
|
||||
d_gains: Union[float, Tuple]):
|
||||
self.p_gains = p_gains
|
||||
self.d_gains = d_gains
|
||||
super(PDController, self).__init__(env, )
|
||||
|
||||
def get_action(self, des_pos, des_vel):
|
||||
# TODO: make standardized ALRenv such that all of them have current_pos/vel attributes
|
||||
cur_pos = self.env.current_pos
|
||||
cur_vel = self.env.current_vel
|
||||
if len(des_pos) != len(cur_pos):
|
||||
des_pos = self.env.extend_des_pos(des_pos)
|
||||
if len(des_vel) != len(cur_vel):
|
||||
des_vel = self.env.extend_des_vel(des_vel)
|
||||
assert des_pos.shape != cur_pos.shape, \
|
||||
"Mismatch in dimension between desired position {} and current position {}".format(des_pos.shape, cur_pos.shape)
|
||||
assert des_vel.shape != cur_vel.shape, \
|
||||
"Mismatch in dimension between desired velocity {} and current velocity {}".format(des_vel.shape,
|
||||
cur_vel.shape)
|
||||
trq = self.p_gains * (des_pos - cur_pos) + self.d_gains * (des_vel - cur_vel)
|
||||
return trq
|
||||
|
||||
|
||||
def get_policy_class(policy_type):
|
||||
def get_policy_class(policy_type, env, mp_kwargs, **kwargs):
|
||||
if policy_type == "motor":
|
||||
return PDController
|
||||
return PDController(env, p_gains=mp_kwargs['p_gains'], d_gains=mp_kwargs['d_gains'])
|
||||
elif policy_type == "velocity":
|
||||
return VelController
|
||||
return VelController(env)
|
||||
elif policy_type == "position":
|
||||
return PosController
|
||||
return PosController(env)
|
||||
|
22
alr_envs/utils/positional_env.py
Normal file
22
alr_envs/utils/positional_env.py
Normal file
@ -0,0 +1,22 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gym import Env
|
||||
|
||||
class PositionalEnv(Env):
|
||||
"""A position and velocity based environment. It functions just as any regular OpenAI Gym
|
||||
environment but it provides position, velocity and acceleration information. This usually means that the
|
||||
corresponding information from the agent is forwarded via the properties.
|
||||
PD-Controller based policies require this environment to calculate the state dependent actions for example.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
|
||||
raise NotImplementedError
|
Loading…
Reference in New Issue
Block a user