diff --git a/alr_envs/dmc/suite/cartpole/__init__.py b/alr_envs/dmc/suite/cartpole/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/alr_envs/dmc/suite/cartpole/cartpole_mp_wrapper.py b/alr_envs/dmc/suite/cartpole/cartpole_mp_wrapper.py
new file mode 100644
index 0000000..d8f8493
--- /dev/null
+++ b/alr_envs/dmc/suite/cartpole/cartpole_mp_wrapper.py
@@ -0,0 +1,64 @@
+from typing import Tuple, Union
+
+import numpy as np
+
+from mp_env_api.interface_wrappers.mp_env_wrapper import MPEnvWrapper
+
+
+class DMCCartpoleMPWrapper(MPEnvWrapper):
+    """Movement-primitive wrapper for the dm_control suite cartpole tasks.
+
+    Exposes the cart (slider) position/velocity as the controllable degree
+    of freedom and marks the full cartpole observation vector as active.
+    """
+
+    def __init__(self, env, n_poles: int = 1):
+        # Number of poles attached to the cart; determines how many sin/cos
+        # angle and angular-velocity entries the observation vector has.
+        self.n_poles = n_poles
+        super().__init__(env)
+
+    @property
+    def active_obs(self):
+        # All observation entries are relevant for the cartpole tasks:
+        # slider position, sin/cos of each hinge angle, slider velocity,
+        # and the angular velocity of each hinge.
+        return np.hstack([
+            [True],  # slider position
+            [True] * 2 * self.n_poles,  # sin/cos hinge angles
+            [True],  # slider velocity
+            [True] * self.n_poles,  # hinge velocities
+        ])
+
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray]:
+        # Position of the cart along the track.
+        return self.env.physics.named.data.qpos["slider"]
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+        # Velocity of the cart along the track.
+        return self.env.physics.named.data.qvel["slider"]
+
+    @property
+    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
+        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
+
+    @property
+    def dt(self) -> Union[float, int]:
+        # Control timestep of the wrapped environment.
+        return self.env.dt
+
+
+class DMCCartpoleTwoPolesMPWrapper(DMCCartpoleMPWrapper):
+    """Cartpole wrapper preconfigured for the two-poles task."""
+
+    def __init__(self, env):
+        super().__init__(env, n_poles=2)
+
+
+class DMCCartpoleThreePolesMPWrapper(DMCCartpoleMPWrapper):
+    """Cartpole wrapper preconfigured for the three-poles task."""
+
+    def __init__(self, env):
+        super().__init__(env, n_poles=3)