commit fc00cf8a87
parent 4a3134d7be

bugfixes
@@ -157,60 +157,36 @@ register(
     id='ALRAntJump-v0',
     entry_point='alr_envs.alr.mujoco:AntJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
-    kwargs={
-        "max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP,
-        "context": True
-    }
 )
 
 register(
     id='ALRHalfCheetahJump-v0',
     entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
-    kwargs={
-        "max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
-        "context": True
-    }
 )
 
 register(
     id='HopperJumpOnBox-v0',
-    entry_point='alr_envs.alr.mujoco:ALRHopperJumpOnBoxEnv',
+    entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
-    kwargs={
-        "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
-        "context": True
-    }
 )
 
 register(
     id='ALRHopperThrow-v0',
     entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
-    kwargs={
-        "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
-        "context": True
-    }
 )
 
 register(
     id='ALRHopperThrowInBasket-v0',
     entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
-    kwargs={
-        "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
-        "context": True
-    }
 )
 
 register(
     id='ALRWalker2DJump-v0',
     entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
-    kwargs={
-        "max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,
-        "context": True
-    }
 )
 
 register(
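Note: after this change the step-based environments above are plain gym registrations without extra kwargs. A minimal usage sketch (assumes the package and its MuJoCo extras are installed and the old 4-tuple gym step API of this era):

    import gym
    import alr_envs  # noqa: F401 -- importing the package runs the register(...) calls above

    env = gym.make('HopperJumpOnBox-v0')
    obs = env.reset()
    for _ in range(10):
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()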
@@ -403,46 +379,48 @@ for _v in _versions:
 
 ## Table Tennis needs to be fixed according to Zhou's implementation
 
-########################################################################################################################
-
-## AntJump
-_versions = ['ALRAntJump-v0']
-for _v in _versions:
-    _name = _v.split("-")
-    _env_id = f'{_name[0]}ProMP-{_name[1]}'
-    kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
-    kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
-    kwargs_dict_ant_jump_promp['name'] = _v
-    register(
-        id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
-        kwargs=kwargs_dict_ant_jump_promp
-    )
-    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
-
-########################################################################################################################
-
-## HalfCheetahJump
-_versions = ['ALRHalfCheetahJump-v0']
-for _v in _versions:
-    _name = _v.split("-")
-    _env_id = f'{_name[0]}ProMP-{_name[1]}'
-    kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
-    kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
-    kwargs_dict_halfcheetah_jump_promp['name'] = _v
-    register(
-        id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
-        kwargs=kwargs_dict_halfcheetah_jump_promp
-    )
-    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
-
-########################################################################################################################
-
+# TODO: Add later when finished
+# ########################################################################################################################
+#
+# ## AntJump
+# _versions = ['ALRAntJump-v0']
+# for _v in _versions:
+#     _name = _v.split("-")
+#     _env_id = f'{_name[0]}ProMP-{_name[1]}'
+#     kwargs_dict_ant_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+#     kwargs_dict_ant_jump_promp['wrappers'].append(mujoco.ant_jump.MPWrapper)
+#     kwargs_dict_ant_jump_promp['name'] = _v
+#     register(
+#         id=_env_id,
+#         entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+#         kwargs=kwargs_dict_ant_jump_promp
+#     )
+#     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+#
+# ########################################################################################################################
+#
+# ## HalfCheetahJump
+# _versions = ['ALRHalfCheetahJump-v0']
+# for _v in _versions:
+#     _name = _v.split("-")
+#     _env_id = f'{_name[0]}ProMP-{_name[1]}'
+#     kwargs_dict_halfcheetah_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+#     kwargs_dict_halfcheetah_jump_promp['wrappers'].append(mujoco.half_cheetah_jump.MPWrapper)
+#     kwargs_dict_halfcheetah_jump_promp['name'] = _v
+#     register(
+#         id=_env_id,
+#         entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+#         kwargs=kwargs_dict_halfcheetah_jump_promp
+#     )
+#     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+#
+# ########################################################################################################################
 
 
 ## HopperJump
-_versions = ['HopperJump-v0', 'HopperJumpSparse-v0', 'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0',
-             'ALRHopperThrowInBasket-v0']
+_versions = ['HopperJump-v0', 'HopperJumpSparse-v0',
+             # 'ALRHopperJumpOnBox-v0', 'ALRHopperThrow-v0', 'ALRHopperThrowInBasket-v0'
+             ]
 # TODO: Check if all environments work with the same MPWrapper
 for _v in _versions:
     _name = _v.split("-")
@@ -457,23 +435,23 @@ for _v in _versions:
     )
     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
-########################################################################################################################
-
-
-## Walker2DJump
-_versions = ['ALRWalker2DJump-v0']
-for _v in _versions:
-    _name = _v.split("-")
-    _env_id = f'{_name[0]}ProMP-{_name[1]}'
-    kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
-    kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
-    kwargs_dict_walker2d_jump_promp['name'] = _v
-    register(
-        id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
-        kwargs=kwargs_dict_walker2d_jump_promp
-    )
-    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
+# ########################################################################################################################
+#
+#
+# ## Walker2DJump
+# _versions = ['ALRWalker2DJump-v0']
+# for _v in _versions:
+#     _name = _v.split("-")
+#     _env_id = f'{_name[0]}ProMP-{_name[1]}'
+#     kwargs_dict_walker2d_jump_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+#     kwargs_dict_walker2d_jump_promp['wrappers'].append(mujoco.walker_2d_jump.MPWrapper)
+#     kwargs_dict_walker2d_jump_promp['name'] = _v
+#     register(
+#         id=_env_id,
+#         entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+#         kwargs=kwargs_dict_walker2d_jump_promp
+#     )
+#     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
 ### Depricated, we will not provide non random starts anymore
 """
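Note: the ProMP ids for the remaining HopperJump versions are derived from the step-based id, so 'HopperJumpSparse-v0' becomes 'HopperJumpSparseProMP-v0'. A quick sketch of the naming scheme used by the loop above (plain string handling, no environment needed):

    for _v in ['HopperJump-v0', 'HopperJumpSparse-v0']:
        _name = _v.split("-")
        print(f'{_name[0]}ProMP-{_name[1]}')
    # HopperJumpProMP-v0
    # HopperJumpSparseProMP-v0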
@@ -639,7 +617,7 @@ for i in _vs:
 
 register(
     id='ALRHopperJumpOnBox-v0',
-    entry_point='alr_envs.alr.mujoco:ALRHopperJumpOnBoxEnv',
+    entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
@@ -1,8 +1,9 @@
 from .beerpong.beerpong import BeerPongEnv, BeerPongEnvFixedReleaseStep, BeerPongEnvStepBasedEpisodicReward
 from .ant_jump.ant_jump import AntJumpEnv
 from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv
-from .hopper_jump.hopper_jump_on_box import ALRHopperJumpOnBoxEnv
+from .hopper_jump.hopper_jump_on_box import HopperJumpOnBoxEnv
 from .hopper_throw.hopper_throw import ALRHopperThrowEnv
 from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv
 from .reacher.reacher import ReacherEnv
 from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv
+from .hopper_jump.hopper_jump import HopperJumpEnv
@@ -7,7 +7,8 @@ from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
 
 class MPWrapper(RawInterfaceWrapper):
 
-    def get_context_mask(self):
+    @property
+    def context_mask(self) -> np.ndarray:
         return np.hstack([
             [False] * 7,  # cos
             [False] * 7,  # sin
@@ -15,16 +16,16 @@ class MPWrapper(RawInterfaceWrapper):
             [False] * 3,  # cup_goal_diff_final
             [False] * 3,  # cup_goal_diff_top
             [True] * 2,  # xy position of cup
-            [False]  # env steps
+            # [False]  # env steps
         ])
 
     @property
     def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
-        return self.env.sim.data.qpos[0:7].copy()
+        return self.env.data.qpos[0:7].copy()
 
     @property
     def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        return self.env.sim.data.qvel[0:7].copy()
+        return self.env.data.qvel[0:7].copy()
 
     # TODO: Fix this
     def _episode_callback(self, action: np.ndarray, mp) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
@@ -69,7 +69,7 @@ class ALRHalfCheetahJumpEnv(HalfCheetahEnv):
               options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
         self.max_height = 0
         self.current_step = 0
-        self.goal = np.random.uniform(1.1, 1.6, 1)  # 1.1 1.6
+        self.goal = self.np_random.uniform(1.1, 1.6, 1)  # 1.1 1.6
         return super().reset()
 
     # overwrite reset_model to make it deterministic
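Note: switching from np.random to self.np_random ties the goal sampling to the environment's own seeded RNG instead of NumPy's global state. A minimal illustration of the difference (uses gym's seeding helper, no MuJoCo needed):

    import numpy as np
    from gym.utils import seeding

    rng, _ = seeding.np_random(42)
    print(rng.uniform(1.1, 1.6, 1))        # reproducible for a fixed env seed
    print(np.random.uniform(1.1, 1.6, 1))  # global RNG, unaffected by the env seed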
@@ -1,2 +1 @@
 from .mp_wrapper import MPWrapper
-
@@ -1,10 +1,9 @@
 import copy
-from typing import Optional
-
-from gym.envs.mujoco.hopper_v3 import HopperEnv
-import numpy as np
 import os
 
+import numpy as np
+from gym.envs.mujoco.hopper_v3 import HopperEnv
+
 MAX_EPISODE_STEPS_HOPPERJUMP = 250
 
 
@@ -23,10 +22,10 @@ class HopperJumpEnv(HopperEnv):
                  xml_file='hopper_jump.xml',
                  forward_reward_weight=1.0,
                  ctrl_cost_weight=1e-3,
-                 healthy_reward=2.0,  # 1 step
-                 contact_weight=2.0,  # 0 step
-                 height_weight=10.0,  # 3 step
-                 dist_weight=3.0,  # 3 step
+                 healthy_reward=2.0,
+                 contact_weight=2.0,
+                 height_weight=10.0,
+                 dist_weight=3.0,
                  terminate_when_unhealthy=False,
                  healthy_state_range=(-100.0, 100.0),
                  healthy_z_range=(0.5, float('inf')),
@@ -42,7 +41,7 @@ class HopperJumpEnv(HopperEnv):
         self._contact_weight = contact_weight
 
         self.max_height = 0
-        self.goal = 0
+        self.goal = np.zeros(3, )
 
         self._steps = 0
         self.contact_with_floor = False
@@ -58,6 +57,10 @@ class HopperJumpEnv(HopperEnv):
         # increase initial height
         self.init_qpos[1] = 1.5
 
+    @property
+    def exclude_current_positions_from_observation(self):
+        return self._exclude_current_positions_from_observation
+
     def step(self, action):
         self._steps += 1
 
@@ -80,7 +83,7 @@ class HopperJumpEnv(HopperEnv):
         costs = ctrl_cost
         done = False
 
-        goal_dist = np.linalg.norm(site_pos_after - np.array([self.goal, 0, 0]))
+        goal_dist = np.linalg.norm(site_pos_after - self.goal)
         if self.contact_dist is None and self.contact_with_floor:
             self.contact_dist = goal_dist
 
@@ -99,7 +102,7 @@ class HopperJumpEnv(HopperEnv):
             height=height_after,
             x_pos=site_pos_after,
             max_height=self.max_height,
-            goal=self.goal,
+            goal=self.goal[:1],
             goal_dist=goal_dist,
             height_rew=self.max_height,
             healthy_reward=self.healthy_reward * 2,
@@ -109,14 +112,15 @@ class HopperJumpEnv(HopperEnv):
         return observation, reward, done, info
 
     def _get_obs(self):
-        goal_dist = self.data.get_site_xpos('foot_site') - np.array([self.goal, 0, 0])
-        return np.concatenate((super(HopperJumpEnv, self)._get_obs(), goal_dist.copy(), self.goal.copy()))
+        goal_dist = self.data.get_site_xpos('foot_site') - self.goal
+        return np.concatenate((super(HopperJumpEnv, self)._get_obs(), goal_dist.copy(), self.goal[:1]))
 
     def reset_model(self):
         super(HopperJumpEnv, self).reset_model()
 
-        self.goal = self.np_random.uniform(0.3, 1.35, 1)[0]
-        self.sim.model.body_pos[self.sim.model.body_name2id('goal_site_body')] = np.array([self.goal, 0, 0])
+        # self.goal = self.np_random.uniform(0.3, 1.35, 1)[0]
+        self.goal = np.concatenate([self.np_random.uniform(0.3, 1.35, 1), np.zeros(2, )])
+        self.sim.model.body_pos[self.sim.model.body_name2id('goal_site_body')] = self.goal
         self.max_height = 0
         self._steps = 0
 
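Note: the goal is now stored as a 3D site position with only the x coordinate sampled, so body_pos and the distance computation can consume it directly while the observation only exposes the x component. Standalone sketch of that layout (numpy only):

    import numpy as np

    rng = np.random.default_rng(0)
    goal = np.concatenate([rng.uniform(0.3, 1.35, 1), np.zeros(2, )])  # shape (3,): [x, 0, 0]
    foot_site = np.array([0.5, 0.0, 0.1])
    goal_dist = np.linalg.norm(foot_site - goal)  # scalar distance to the goal site
    obs_goal_part = goal[:1]                      # only x enters the observation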
@@ -6,7 +6,7 @@ import os
 MAX_EPISODE_STEPS_HOPPERJUMPONBOX = 250
 
 
-class ALRHopperJumpOnBoxEnv(HopperEnv):
+class HopperJumpOnBoxEnv(HopperEnv):
     """
     Initialization changes to normal Hopper:
     - healthy_reward: 1.0 -> 0.01 -> 0.001
@@ -153,7 +153,7 @@ class ALRHopperJumpOnBoxEnv(HopperEnv):
 
 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRHopperJumpOnBoxEnv()
+    env = HopperJumpOnBoxEnv()
     obs = env.reset()
 
     for i in range(2000):
@@ -14,7 +14,8 @@ class MPWrapper(RawInterfaceWrapper):
             [False] * (2 + int(not self.exclude_current_positions_from_observation)),  # position
             [True] * 3,  # set to true if randomize initial pos
             [False] * 6,  # velocity
-            [True]
+            [True] * 3,  # goal distance
+            [True]  # goal
         ])
 
     @property
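Note: the boolean mask above is what the black-box wrapper uses to cut the context (goal-related) part out of the full step observation on reset. A toy illustration of that selection (sizes are illustrative, numpy only):

    import numpy as np

    mask = np.hstack([[False] * 3, [True] * 3, [False] * 6, [True] * 3, [True]])
    obs = np.arange(mask.size, dtype=np.float64)  # stand-in for a full step observation
    context = obs[mask]                           # what the episode-based view exposes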
@@ -67,7 +67,9 @@ class BlackBoxWrapper(gym.ObservationWrapper):
 
     def observation(self, observation):
         # return context space if we are
-        return observation[self.env.context_mask] if self.return_context_observation else observation
+        obs = observation[self.env.context_mask] if self.return_context_observation else observation
+        # cast dtype because metaworld returns incorrect that throws gym error
+        return obs.astype(self.observation_space.dtype)
 
     def get_trajectory(self, action: np.ndarray) -> Tuple:
         clipped_params = np.clip(action, self.traj_gen_action_space.low, self.traj_gen_action_space.high)
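Note: the added cast works around MetaWorld observations coming back as float64 while the declared space is float32; recent gym versions reject that mismatch. A minimal repro of the check (assumes gym's Box API with the strict can_cast dtype test):

    import numpy as np
    from gym.spaces import Box

    space = Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
    obs = np.zeros(3, dtype=np.float64)
    print(space.contains(obs))                      # False: float64 cannot be safely cast to float32
    print(space.contains(obs.astype(space.dtype)))  # True after the explicit cast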
@@ -147,7 +149,7 @@ class BlackBoxWrapper(gym.ObservationWrapper):
                     infos[k] = elems
 
             if self.render_kwargs:
-                self.render(**self.render_kwargs)
+                self.env.render(**self.render_kwargs)
 
             if done or self.replanning_schedule(self.current_pos, self.current_vel, obs, c_action,
                                                 t + 1 + self.current_traj_steps):
@@ -170,13 +172,13 @@ class BlackBoxWrapper(gym.ObservationWrapper):
     def render(self, **kwargs):
         """Only set render options here, such that they can be used during the rollout.
         This only needs to be called once"""
-        self.render_kwargs = kwargs or self.render_kwargs
+        self.render_kwargs = kwargs
         # self.env.render(mode=self.render_mode, **self.render_kwargs)
-        self.env.render(**self.render_kwargs)
+        # self.env.render(**self.render_kwargs)
 
     def reset(self, *, seed: Optional[int] = None, return_info: bool = False, options: Optional[dict] = None):
         self.current_traj_steps = 0
-        return super(BlackBoxWrapper, self).reset(seed=seed, return_info=return_info, options=options)
+        return super(BlackBoxWrapper, self).reset()
 
     def plot_trajs(self, des_trajs, des_vels):
         import matplotlib.pyplot as plt
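Note: with this change render() only stores the options; the wrapped environment is actually drawn inside the step loop above. Usage sketch (env id and make helper as registered in this repo; requires the MuJoCo extras to be installed):

    import alr_envs

    env = alr_envs.make('HopperJumpSparseProMP-v0', seed=1)
    env.render(mode="human")  # stored once, replayed on every inner simulation step
    obs = env.reset()
    obs, reward, done, info = env.step(env.action_space.sample())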
@@ -10,13 +10,12 @@ class MetaWorldController(BaseController):
     Unlike the other Controllers, this is a special tracking_controller for MetaWorld environments.
     They use a position delta for the xyz coordinates and a raw position for the gripper opening.
 
-    :param env: A position environment
     """
 
     def get_action(self, des_pos, des_vel, c_pos, c_vel):
         gripper_pos = des_pos[-1]
 
-        cur_pos = env.current_pos[:-1]
+        cur_pos = c_pos[:-1]
         xyz_pos = des_pos[:-1]
 
         assert xyz_pos.shape == cur_pos.shape, \
@@ -18,4 +18,4 @@ def get_controller(controller_type: str, **kwargs):
         return MetaWorldController()
     else:
         raise ValueError(f"Specified controller type {controller_type} not supported, "
                          f"please choose one of {ALL_TYPES}.")
@@ -63,16 +63,16 @@ def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render
     # mp_dict.update({'black_box_kwargs': {'learn_sub_trajectories': True}})
     # mp_dict.update({'black_box_kwargs': {'do_replanning': lambda pos, vel, t: lambda t: t % 100}})
 
+    rewards = 0
+    obs = env.reset()
+
     # This time rendering every trajectory
     if render:
         env.render(mode="human")
 
-    rewards = 0
-    obs = env.reset()
-
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
-        ac = env.action_space.sample()
+        ac = env.action_space.sample() * 1000
         obs, reward, done, info = env.step(ac)
         rewards += reward
 
@@ -139,7 +139,7 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
 
 
 if __name__ == '__main__':
-    render = False
+    render = True
     # # DMP
     # example_mp("alr_envs:HoleReacherDMP-v1", seed=10, iterations=1, render=render)
     #
@@ -150,7 +150,7 @@ if __name__ == '__main__':
     # example_mp("alr_envs:HoleReacherDetPMP-v1", seed=10, iterations=1, render=render)
 
     # Altered basis functions
-    example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=10, render=render)
+    example_custom_mp("HopperJumpSparseProMP-v0", seed=10, iterations=10, render=render)
 
     # Custom MP
     # example_fully_custom_mp(seed=10, iterations=1, render=render)
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 from gym import register
 
 from . import goal_object_change_mp_wrapper, goal_change_mp_wrapper, goal_endeffector_change_mp_wrapper, \
@@ -7,27 +9,39 @@ ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
 
 # MetaWorld
 
+DEFAULT_BB_DICT_ProMP = {
+    "name": 'EnvName',
+    "wrappers": [],
+    "trajectory_generator_kwargs": {
+        'trajectory_generator_type': 'promp'
+    },
+    "phase_generator_kwargs": {
+        'phase_generator_type': 'linear'
+    },
+    "controller_kwargs": {
+        'controller_type': 'metaworld',
+    },
+    "basis_generator_kwargs": {
+        'basis_generator_type': 'zero_rbf',
+        'num_basis': 5,
+        'num_basis_zero_start': 1
+    }
+}
+
 _goal_change_envs = ["assembly-v2", "pick-out-of-hole-v2", "plate-slide-v2", "plate-slide-back-v2",
                      "plate-slide-side-v2", "plate-slide-back-side-v2"]
 for _task in _goal_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
+    kwargs_dict_goal_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+    kwargs_dict_goal_change_promp['wrappers'].append(goal_change_mp_wrapper.MPWrapper)
+    kwargs_dict_goal_change_promp['name'] = _task
+
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
-        kwargs={
-            "name": _task,
-            "wrappers": [goal_change_mp_wrapper.MPWrapper],
-            "traj_gen_kwargs": {
-                "num_dof": 4,
-                "num_basis": 5,
-                "duration": 6.25,
-                "post_traj_time": 0,
-                "zero_start": True,
-                "policy_type": "metaworld",
-            }
-        }
+        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_goal_change_promp
     )
     ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
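Note: the ProMP ids are derived from the MetaWorld task names, e.g. 'assembly-v2' registers as 'AssemblyProMP-v2'. Sketch of the derivation used above (pure string handling):

    _task = "assembly-v2"
    task_id_split = _task.split("-")
    name = "".join([s.capitalize() for s in task_id_split[:-1]])
    print(f'{name}ProMP-{task_id_split[-1]}')  # -> AssemblyProMP-v2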
@@ -36,21 +50,13 @@ for _task in _object_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
+    kwargs_dict_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+    kwargs_dict_object_change_promp['wrappers'].append(object_change_mp_wrapper.MPWrapper)
+    kwargs_dict_object_change_promp['name'] = _task
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
-        kwargs={
-            "name": _task,
-            "wrappers": [object_change_mp_wrapper.MPWrapper],
-            "traj_gen_kwargs": {
-                "num_dof": 4,
-                "num_basis": 5,
-                "duration": 6.25,
-                "post_traj_time": 0,
-                "zero_start": True,
-                "policy_type": "metaworld",
-            }
-        }
+        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_object_change_promp
     )
     ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
@@ -69,21 +75,14 @@ for _task in _goal_and_object_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
+    kwargs_dict_goal_and_object_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+    kwargs_dict_goal_and_object_change_promp['wrappers'].append(goal_object_change_mp_wrapper.MPWrapper)
+    kwargs_dict_goal_and_object_change_promp['name'] = _task
+
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
-        kwargs={
-            "name": _task,
-            "wrappers": [goal_object_change_mp_wrapper.MPWrapper],
-            "traj_gen_kwargs": {
-                "num_dof": 4,
-                "num_basis": 5,
-                "duration": 6.25,
-                "post_traj_time": 0,
-                "zero_start": True,
-                "policy_type": "metaworld",
-            }
-        }
+        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_goal_and_object_change_promp
    )
     ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
@@ -92,20 +91,13 @@ for _task in _goal_and_endeffector_change_envs:
     task_id_split = _task.split("-")
     name = "".join([s.capitalize() for s in task_id_split[:-1]])
     _env_id = f'{name}ProMP-{task_id_split[-1]}'
+    kwargs_dict_goal_and_endeffector_change_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
+    kwargs_dict_goal_and_endeffector_change_promp['wrappers'].append(goal_endeffector_change_mp_wrapper.MPWrapper)
+    kwargs_dict_goal_and_endeffector_change_promp['name'] = _task
+
     register(
         id=_env_id,
-        entry_point='alr_envs.utils.make_env_helpers:make_promp_env_helper',
-        kwargs={
-            "name": _task,
-            "wrappers": [goal_endeffector_change_mp_wrapper.MPWrapper],
-            "traj_gen_kwargs": {
-                "num_dof": 4,
-                "num_basis": 5,
-                "duration": 6.25,
-                "post_traj_time": 0,
-                "zero_start": True,
-                "policy_type": "metaworld",
-            }
-        }
+        entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
+        kwargs=kwargs_dict_goal_and_endeffector_change_promp
     )
     ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
 
alr_envs/meta/base_metaworld_mp_wrapper.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from abc import ABC
+from typing import Tuple, Union
+
+import numpy as np
+
+from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
+
+
+class BaseMetaworldMPWrapper(RawInterfaceWrapper, ABC):
+    @property
+    def current_pos(self) -> Union[float, int, np.ndarray]:
+        r_close = self.env.data.get_joint_qpos("r_close")
+        # TODO check if this is correct
+        # return np.hstack([self.env.data.get_body_xpos('hand').flatten() / self.env.action_scale, r_close])
+        return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
+
+    @property
+    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
+        # TODO check if this is correct
+        return np.zeros(4, )
+        # raise NotImplementedError("Velocity cannot be retrieved.")
@@ -1,11 +1,9 @@
-from typing import Tuple, Union
-
 import numpy as np
 
-from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
 
 
-class MPWrapper(RawInterfaceWrapper):
+class MPWrapper(BaseMetaworldMPWrapper):
     """
     This Wrapper is for environments where merely the goal changes in the beginning
     and no secondary objects or end effectors are altered at the start of an episode.
@@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
             # Goal
             [True] * 3,  # goal position
         ])
-
-    @property
-    def current_pos(self) -> Union[float, int, np.ndarray]:
-        r_close = self.env.data.get_joint_qpos("r_close")
-        return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
-
-    @property
-    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise NotImplementedError("Velocity cannot be retrieved.")
-
-    @property
-    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
@@ -1,11 +1,9 @@
-from typing import Tuple, Union
-
 import numpy as np
 
-from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
 
 
-class MPWrapper(RawInterfaceWrapper):
+class MPWrapper(BaseMetaworldMPWrapper):
     """
     This Wrapper is for environments where merely the goal changes in the beginning
     and no secondary objects or end effectors are altered at the start of an episode.
@@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
             # Goal
             [True] * 3,  # goal position
         ])
-
-    @property
-    def current_pos(self) -> Union[float, int, np.ndarray]:
-        r_close = self.env.data.get_joint_qpos("r_close")
-        return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
-
-    @property
-    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise NotImplementedError("Velocity cannot be retrieved.")
-
-    @property
-    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
@@ -1,11 +1,9 @@
-from typing import Tuple, Union
-
 import numpy as np
 
-from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
 
 
-class MPWrapper(RawInterfaceWrapper):
+class MPWrapper(BaseMetaworldMPWrapper):
     """
     This Wrapper is for environments where merely the goal changes in the beginning
     and no secondary objects or end effectors are altered at the start of an episode.
@@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
             # Goal
             [True] * 3,  # goal position
         ])
-
-    @property
-    def current_pos(self) -> Union[float, int, np.ndarray]:
-        r_close = self.env.data.get_joint_qpos("r_close")
-        return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
-
-    @property
-    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise NotImplementedError("Velocity cannot be retrieved.")
-
-    @property
-    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
@@ -1,11 +1,9 @@
-from typing import Tuple, Union
-
 import numpy as np
 
-from alr_envs.black_box.raw_interface_wrapper import RawInterfaceWrapper
+from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper
 
 
-class MPWrapper(RawInterfaceWrapper):
+class MPWrapper(BaseMetaworldMPWrapper):
     """
     This Wrapper is for environments where merely the goal changes in the beginning
     and no secondary objects or end effectors are altered at the start of an episode.
@@ -49,20 +47,3 @@ class MPWrapper(RawInterfaceWrapper):
             # Goal
             [True] * 3,  # goal position
         ])
-
-    @property
-    def current_pos(self) -> Union[float, int, np.ndarray]:
-        r_close = self.env.data.get_joint_qpos("r_close")
-        return np.hstack([self.env.data.mocap_pos.flatten() / self.env.action_scale, r_close])
-
-    @property
-    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise NotImplementedError("Velocity cannot be retrieved.")
-
-    @property
-    def goal_pos(self) -> Union[float, int, np.ndarray, Tuple]:
-        raise ValueError("Goal position is not available and has to be learnt based on the environment.")
-
-    @property
-    def dt(self) -> Union[float, int]:
-        return self.env.dt
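Note: after this refactor the per-task MetaWorld wrappers only declare which observation entries form the context; position and velocity handling lives in the shared base class. A sketch of what such a subclass now looks like (the mask length is illustrative, not taken from a specific task):

    import numpy as np
    from alr_envs.meta.base_metaworld_mp_wrapper import BaseMetaworldMPWrapper

    class MPWrapper(BaseMetaworldMPWrapper):
        @property
        def context_mask(self) -> np.ndarray:
            return np.hstack([
                [False] * 18,  # robot and object state (illustrative length)
                [True] * 3,    # goal position
            ])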
@@ -20,7 +20,7 @@ def make_dmc(
         environment_kwargs: dict = {},
         time_limit: Union[None, float] = None,
         channels_first: bool = True
 ):
     # Adopted from: https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/__init__.py
     # License: MIT
     # Copyright (c) 2020 Denis Yarats
@@ -32,12 +32,10 @@ def make_dmc(
     env_id = f'dmc_{domain_name}_{task_name}_{seed}-v1'
 
     if from_pixels:
-        assert not visualize_reward, 'cannot use visualize reward when learning from pixels'
+        assert not visualize_reward, 'Cannot use visualize reward when learning from pixels.'
 
-    # shorten episode length
-    if episode_length is None:
-        # Default lengths for benchmarking suite is 1000 and for manipulation tasks 250
-        episode_length = 250 if domain_name == "manipulation" else 1000
+    # Default lengths for benchmarking suite is 1000 and for manipulation tasks 250
+    episode_length = episode_length or (250 if domain_name == "manipulation" else 1000)
 
     max_episode_steps = (episode_length + frame_skip - 1) // frame_skip
     if env_id not in gym.envs.registry.env_specs:
@@ -61,7 +59,7 @@ def make_dmc(
             camera_id=camera_id,
             frame_skip=frame_skip,
             channels_first=channels_first,
         ),
         max_episode_steps=max_episode_steps,
     )
     return gym.make(env_id)
@@ -8,7 +8,7 @@ from gym.envs.registration import EnvSpec, registry
 from gym.wrappers import TimeAwareObservation
 
 from alr_envs.black_box.black_box_wrapper import BlackBoxWrapper
-from alr_envs.black_box.controller.controller_factory import get_controller
+from alr_envs.black_box.factory.controller_factory import get_controller
 from alr_envs.black_box.factory.basis_generator_factory import get_basis_generator
 from alr_envs.black_box.factory.phase_generator_factory import get_phase_generator
 from alr_envs.black_box.factory.trajectory_generator_factory import get_trajectory_generator
@@ -43,11 +43,7 @@ def make_rank(env_id: str, seed: int, rank: int = 0, return_callable=True, **kwa
 
 
 def make(env_id, seed, **kwargs):
-    # This access is required to allow for nested dict updates
-    spec = registry.get(env_id)
-    all_kwargs = deepcopy(spec.kwargs)
-    nested_update(all_kwargs, kwargs)
-    return _make(env_id, seed, **all_kwargs)
+    return _make(env_id, seed, **kwargs)
 
 
 def _make(env_id: str, seed, **kwargs):
@@ -62,12 +58,25 @@ def _make(env_id: str, seed, **kwargs):
     Returns: Gym environment
 
     """
-    if any(deprec in env_id for deprec in ["DetPMP", "detpmp"]):
-        warnings.warn("DetPMP is deprecated and converted to ProMP")
-        env_id = env_id.replace("DetPMP", "ProMP")
-        env_id = env_id.replace("detpmp", "promp")
+
+    # 'dmc:domain-task'
+    # 'gym:name-vX'
+    # 'meta:name-vX'
+    # 'meta:bb:name-vX'
+    # 'hand:name-vX'
+    # 'name-vX'
+    # 'bb:name-vX'
+    #
+    # env_id.split(':')
+    # if 'dmc' :
 
     try:
+        # This access is required to allow for nested dict updates for BB envs
+        spec = registry.get(env_id)
+        all_kwargs = deepcopy(spec.kwargs)
+        nested_update(all_kwargs, kwargs)
+        kwargs = all_kwargs
+
         # Add seed to kwargs in case it is a predefined gym+dmc hybrid environment.
         if env_id.startswith("dmc"):
             kwargs.update({"seed": seed})
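Note: with the spec/nested_update block moved into _make, keyword overrides passed at call time are merged into the registered kwargs dict instead of replacing it wholesale. Usage sketch (names as registered in this repo; the override key is only an example):

    import alr_envs

    # Only 'num_basis' inside 'basis_generator_kwargs' changes; every other
    # registered option for the environment is kept as-is.
    env = alr_envs.make('HopperJumpSparseProMP-v0', seed=1,
                        basis_generator_kwargs={'num_basis': 10})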
@@ -77,22 +86,25 @@ def _make(env_id: str, seed, **kwargs):
         env.seed(seed)
         env.action_space.seed(seed)
         env.observation_space.seed(seed)
-    except gym.error.Error:
+    except (gym.error.Error, AttributeError):
 
         # MetaWorld env
         import metaworld
         if env_id in metaworld.ML1.ENV_NAMES:
             env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[env_id + "-goal-observable"](seed=seed, **kwargs)
 
             # setting this avoids generating the same initialization after each reset
             env._freeze_rand_vec = False
+            env.seeded_rand_vec = True
 
             # Manually set spec, as metaworld environments are not registered via gym
             env.unwrapped.spec = EnvSpec(env_id)
             # Set Timelimit based on the maximum allowed path length of the environment
             env = gym.wrappers.TimeLimit(env, max_episode_steps=env.max_path_length)
-            env.seed(seed)
-            env.action_space.seed(seed)
-            env.observation_space.seed(seed)
-            env.goal_space.seed(seed)
+            # env.seed(seed)
+            # env.action_space.seed(seed)
+            # env.observation_space.seed(seed)
+            # env.goal_space.seed(seed)
 
         else:
             # DMC
setup.py
@@ -1,10 +1,10 @@
 import itertools
 
-from setuptools import setup
+from setuptools import setup, find_packages
 
 # Environment-specific dependencies for dmc and metaworld
 extras = {
-    "dmc": ["dm_control"],
+    "dmc": ["dm_control==1.0.1"],
     "meta": ["metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld"],
     "mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
 }
@@ -16,12 +16,28 @@ extras["all"] = list(set(itertools.chain.from_iterable(map(lambda group: extras[
 setup(
     author='Fabian Otto, Onur Celik, Marcel Sandermann, Maximilian Huettenrauch',
     name='simple_gym',
-    version='0.0.1',
-    packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'],
+    version='0.1',
+    classifiers=[
+        # Python 3.6 is minimally supported (only with basic gym environments and API)
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+    ],
+    extras_require=extras,
     install_requires=[
-        'gym',
+        'gym>=0.24.0',
         "mujoco_py<2.2,>=2.1",
     ],
+    packages=[package for package in find_packages() if package.startswith("alr_envs")],
+    # packages=['alr_envs', 'alr_envs.alr', 'alr_envs.open_ai', 'alr_envs.dmc', 'alr_envs.meta', 'alr_envs.utils'],
+    package_data={
+        "alr_envs": [
+            "alr/mujoco/*/assets/*.xml",
+        ]
+    },
+    python_requires=">=3.6",
     url='https://github.com/ALRhub/alr_envs/',
     # license='AGPL-3.0 license',
     author_email='',
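Note: with extras_require wired up, the optional dependency groups defined in extras can be pulled in through pip's extras syntax, e.g. pip install -e .[dmc] or pip install -e .[all]; the base install only requires gym>=0.24.0 and mujoco_py.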
@@ -34,12 +34,7 @@ class TestMPEnvironments(unittest.TestCase):
         obs = env.reset()
         self._verify_observations(obs, env.observation_space, "reset()")
 
-        length = env.spec.max_episode_steps
-        if iterations is None:
-            if length is None:
-                iterations = 1
-            else:
-                iterations = length
+        iterations = iterations or (env.spec.max_episode_steps or 1)
 
         # number of samples(multiple environment steps)
         for i in range(iterations):
@@ -76,7 +71,7 @@ class TestMPEnvironments(unittest.TestCase):
         traj2 = self._run_env(env_id, seed=seed)
         for i, time_step in enumerate(zip(*traj1, *traj2)):
             obs1, rwd1, done1, obs2, rwd2, done2 = time_step
-            self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
+            self.assertTrue(np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
             self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
             self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")
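Note: replayed rollouts with the same seed can differ by tiny floating-point noise, which np.array_equal flags as a mismatch while np.allclose tolerates it. Minimal illustration:

    import numpy as np

    a = np.array([1.0, 2.0])
    b = a + 1e-12
    print(np.array_equal(a, b))  # False: bitwise comparison
    print(np.allclose(a, b))     # True: within default tolerances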
@@ -36,12 +36,7 @@ class TestStepDMCEnvironments(unittest.TestCase):
         obs = env.reset()
         self._verify_observations(obs, env.observation_space, "reset()")
 
-        length = env.spec.max_episode_steps
-        if iterations is None:
-            if length is None:
-                iterations = 1
-            else:
-                iterations = length
+        iterations = iterations or (env.spec.max_episode_steps or 1)
 
         # number of samples(multiple environment steps)
         for i in range(iterations):
@@ -35,12 +35,7 @@ class TestStepMetaWorlEnvironments(unittest.TestCase):
         obs = env.reset()
         self._verify_observations(obs, env.observation_space, "reset()")
 
-        length = env.max_path_length
-        if iterations is None:
-            if length is None:
-                iterations = 1
-            else:
-                iterations = length
+        iterations = iterations or (env.spec.max_episode_steps or 1)
 
         # number of samples(multiple environment steps)
         for i in range(iterations):