clean up open_ai envs
This commit is contained in:
parent
2706af0b77
commit
4a3134d7be
@ -1,85 +1,48 @@
|
|||||||
from gym import register
|
from gym import register
|
||||||
from gym.wrappers import FlattenObservation
|
from copy import deepcopy
|
||||||
|
|
||||||
from . import classic_control, mujoco, robotics
|
from . import mujoco
|
||||||
|
from .deprecated_needs_gym_robotics import robotics
|
||||||
|
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
|
||||||
|
|
||||||
# Short Continuous Mountain Car
|
DEFAULT_BB_DICT_ProMP = {
|
||||||
register(
|
"name": 'EnvName',
|
||||||
id="MountainCarContinuous-v1",
|
"wrappers": [],
|
||||||
entry_point="gym.envs.classic_control:Continuous_MountainCarEnv",
|
"trajectory_generator_kwargs": {
|
||||||
max_episode_steps=100,
|
'trajectory_generator_type': 'promp'
|
||||||
reward_threshold=90.0,
|
},
|
||||||
)
|
"phase_generator_kwargs": {
|
||||||
|
'phase_generator_type': 'linear'
|
||||||
# Open AI
|
},
|
||||||
# Classic Control
|
"controller_kwargs": {
|
||||||
register(
|
'controller_type': 'motor',
|
||||||
id='ContinuousMountainCarProMP-v1',
|
"p_gains": 1.0,
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
"d_gains": 0.1,
|
||||||
kwargs={
|
},
|
||||||
"name": "alr_envs:MountainCarContinuous-v1",
|
"basis_generator_kwargs": {
|
||||||
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
'basis_generator_type': 'zero_rbf',
|
||||||
"traj_gen_kwargs": {
|
'num_basis': 5,
|
||||||
"num_dof": 1,
|
'num_basis_zero_start': 1
|
||||||
"num_basis": 4,
|
|
||||||
"duration": 2,
|
|
||||||
"post_traj_time": 0,
|
|
||||||
"zero_start": True,
|
|
||||||
"policy_type": "motor",
|
|
||||||
"policy_kwargs": {
|
|
||||||
"p_gains": 1.,
|
|
||||||
"d_gains": 1.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)
|
}
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v1")
|
|
||||||
|
|
||||||
register(
|
|
||||||
id='ContinuousMountainCarProMP-v0',
|
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
|
||||||
kwargs={
|
|
||||||
"name": "gym.envs.classic_control:MountainCarContinuous-v0",
|
|
||||||
"wrappers": [classic_control.continuous_mountain_car.MPWrapper],
|
|
||||||
"traj_gen_kwargs": {
|
|
||||||
"num_dof": 1,
|
|
||||||
"num_basis": 4,
|
|
||||||
"duration": 19.98,
|
|
||||||
"post_traj_time": 0,
|
|
||||||
"zero_start": True,
|
|
||||||
"policy_type": "motor",
|
|
||||||
"policy_kwargs": {
|
|
||||||
"p_gains": 1.,
|
|
||||||
"d_gains": 1.
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ContinuousMountainCarProMP-v0")
|
|
||||||
|
|
||||||
|
kwargs_dict_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
|
||||||
|
kwargs_dict_reacher_promp['controller_kwargs']['p_gains'] = 0.6
|
||||||
|
kwargs_dict_reacher_promp['controller_kwargs']['d_gains'] = 0.075
|
||||||
|
kwargs_dict_reacher_promp['basis_generator_kwargs']['num_basis'] = 6
|
||||||
|
kwargs_dict_reacher_promp['name'] = "Reacher-v2"
|
||||||
|
kwargs_dict_reacher_promp['wrappers'].append(mujoco.reacher_v2.MPWrapper)
|
||||||
register(
|
register(
|
||||||
id='ReacherProMP-v2',
|
id='Reacher2dProMP-v2',
|
||||||
entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
|
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
|
||||||
kwargs={
|
kwargs=kwargs_dict_reacher_promp
|
||||||
"name": "gym.envs.mujoco:Reacher-v2",
|
|
||||||
"wrappers": [mujoco.reacher_v2.MPWrapper],
|
|
||||||
"traj_gen_kwargs": {
|
|
||||||
"num_dof": 2,
|
|
||||||
"num_basis": 6,
|
|
||||||
"duration": 1,
|
|
||||||
"post_traj_time": 0,
|
|
||||||
"zero_start": True,
|
|
||||||
"policy_type": "motor",
|
|
||||||
"policy_kwargs": {
|
|
||||||
"p_gains": .6,
|
|
||||||
"d_gains": .075
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ReacherProMP-v2")
|
# Keep this name identical to the id passed to register() for the reacher env; a mismatch lists an id that was never registered.
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("Reacher2dProMP-v2")
|
||||||
|
"""
|
||||||
|
The Fetch environments are not supported by gym anymore. A new repository (gym_robotics) is supporting the environments.
|
||||||
|
However, their usage and related details still need to be checked.
|
||||||
|
|
||||||
register(
|
register(
|
||||||
id='FetchSlideDenseProMP-v1',
|
id='FetchSlideDenseProMP-v1',
|
||||||
@ -152,3 +115,4 @@ register(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachProMP-v1")
|
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("FetchReachProMP-v1")
|
||||||
|
"""
|
||||||
|
@ -1 +0,0 @@
|
|||||||
from . import continuous_mountain_car
|
|
@ -1,23 +0,0 @@
|
|||||||
from typing import Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
|
||||||
|
|
||||||
|
|
||||||
class MPWrapper(RawInterfaceWrapper):
|
|
||||||
@property
|
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return np.array([self.state[1]])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
|
||||||
return np.array([self.state[0]])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def goal_pos(self):
|
|
||||||
raise ValueError("Goal position is not available and has to be learnt based on the environment.")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dt(self) -> Union[float, int]:
|
|
||||||
return 0.02
|
|
@ -14,3 +14,13 @@ class MPWrapper(RawInterfaceWrapper):
|
|||||||
@property
|
@property
|
||||||
def current_pos(self) -> Union[float, int, np.ndarray]:
|
def current_pos(self) -> Union[float, int, np.ndarray]:
|
||||||
return self.sim.data.qpos[:2]
|
return self.sim.data.qpos[:2]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context_mask(self):
|
||||||
|
return np.concatenate([
|
||||||
|
[False] * 2, # cos of two links
|
||||||
|
[False] * 2, # sin of two links
|
||||||
|
[True] * 2, # goal position
|
||||||
|
[False] * 2, # angular velocity
|
||||||
|
[False] * 3, # goal distance
|
||||||
|
])
|
||||||
|
@ -1 +0,0 @@
|
|||||||
from .mp_wrapper import MPWrapper
|
|
Loading…
Reference in New Issue
Block a user