Adopt new interface structure
This commit is contained in:
parent
37c4ab8a41
commit
0046ade102
@ -573,7 +573,7 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
||||||
"wrappers": [continuous_mountain_car.PositionalWrapper, continuous_mountain_car.MPWrapper],
|
"wrappers": [continuous_mountain_car.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 1,
|
"num_dof": 1,
|
||||||
"num_basis": 4,
|
"num_basis": 4,
|
||||||
@ -594,7 +594,7 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.mujoco:Reacher-v2",
|
"name": "gym.envs.mujoco:Reacher-v2",
|
||||||
"wrappers": [reacher_v2.PositionalWrapper, reacher_v2.MPWrapper],
|
"wrappers": [reacher_v2.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 2,
|
"num_dof": 2,
|
||||||
"num_basis": 6,
|
"num_basis": 6,
|
||||||
@ -615,7 +615,7 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.robotics:FetchSlideDense-v1",
|
"name": "gym.envs.robotics:FetchSlideDense-v1",
|
||||||
"wrappers": [fetch.PositionalWrapper, fetch.MPWrapper],
|
"wrappers": [fetch.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
@ -632,7 +632,7 @@ register(
|
|||||||
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_detpmp_env_helper',
|
||||||
kwargs={
|
kwargs={
|
||||||
"name": "gym.envs.robotics:FetchReachDense-v1",
|
"name": "gym.envs.robotics:FetchReachDense-v1",
|
||||||
"wrappers": [fetch.PositionalWrapper, fetch.MPWrapper],
|
"wrappers": [fetch.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 4,
|
"num_dof": 4,
|
||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
|
@ -1,2 +1 @@
|
|||||||
from alr_envs.open_ai.continuous_mountain_car.positional_wrapper import PositionalWrapper
|
|
||||||
from alr_envs.open_ai.continuous_mountain_car.mp_wrapper import MPWrapper
|
from alr_envs.open_ai.continuous_mountain_car.mp_wrapper import MPWrapper
|
@ -1,12 +1,17 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(MPEnvWrapper):
|
class MPWrapper(MPEnvWrapper):
|
||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def current_vel(self) -> Union[float, int, np.ndarray]:
|
||||||
raise ValueError("Start position is not available")
|
return np.array([self.state[1]])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
return np.array([self.state[0]])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self):
|
def goal_pos(self):
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
import numpy as np
|
|
||||||
from mp_env_api.env_wrappers.positional_env_wrapper import PositionalEnvWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class PositionalWrapper(PositionalEnvWrapper):
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return np.array([self.state[1]])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return np.array([self.state[0]])
|
|
@ -1,2 +1 @@
|
|||||||
from alr_envs.open_ai.fetch.positional_wrapper import PositionalWrapper
|
|
||||||
from alr_envs.open_ai.fetch.mp_wrapper import MPWrapper
|
from alr_envs.open_ai.fetch.mp_wrapper import MPWrapper
|
@ -1,13 +1,17 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
from gym import spaces
|
import numpy as np
|
||||||
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(MPEnvWrapper):
|
class MPWrapper(MPEnvWrapper):
|
||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def current_vel(self) -> Union[float, int, np.ndarray]:
|
||||||
return self.initial_gripper_xpos
|
return self.unwrapped._get_obs()["observation"][-5:-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
return self.unwrapped._get_obs()["observation"][:4]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self):
|
def goal_pos(self):
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
import numpy as np
|
|
||||||
from mp_env_api.env_wrappers.positional_env_wrapper import PositionalEnvWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class PositionalWrapper(PositionalEnvWrapper):
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self.unwrapped._get_obs()["observation"][-5:-1]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self.unwrapped._get_obs()["observation"][:4]
|
|
@ -1,2 +1 @@
|
|||||||
from alr_envs.open_ai.reacher_v2.positional_wrapper import PositionalWrapper
|
|
||||||
from alr_envs.open_ai.reacher_v2.mp_wrapper import MPWrapper
|
from alr_envs.open_ai.reacher_v2.mp_wrapper import MPWrapper
|
@ -1,13 +1,18 @@
|
|||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
from mp_env_api.env_wrappers.mp_env_wrapper import MPEnvWrapper
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(MPEnvWrapper):
|
class MPWrapper(MPEnvWrapper):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def start_pos(self):
|
def current_vel(self) -> Union[float, int, np.ndarray]:
|
||||||
raise ValueError("Start position is not available")
|
return self.sim.data.qvel[:2]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
|
return self.sim.data.qpos[:2]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def goal_pos(self):
|
def goal_pos(self):
|
||||||
|
@ -1,13 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
import numpy as np
|
|
||||||
from mp_env_api.env_wrappers.positional_env_wrapper import PositionalEnvWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class PositionalWrapper(PositionalEnvWrapper):
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self.sim.data.qvel[:2]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return self.sim.data.qpos[:2]
|
|
Loading…
Reference in New Issue
Block a user