Renaming alr module and updating tests

This commit is contained in:
Fabian 2022-07-12 15:17:02 +02:00
parent 79c26681c9
commit 0339361656
127 changed files with 418 additions and 321 deletions

View File

@ -107,7 +107,7 @@ keys `DMP` and `ProMP` that store a list of available environment names.
import alr_envs import alr_envs
print("Custom MP tasks:") print("Custom MP tasks:")
print(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS) print(alr_envs.ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
print("OpenAI Gym MP tasks:") print("OpenAI Gym MP tasks:")
print(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS) print(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS)
@ -116,7 +116,7 @@ print("Deepmind Control MP tasks:")
print(alr_envs.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS) print(alr_envs.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
print("MetaWorld MP tasks:") print("MetaWorld MP tasks:")
print(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS) print(alr_envs.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS)
``` ```
### How to create a new MP task ### How to create a new MP task

View File

@ -2,13 +2,13 @@ from alr_envs import dmc, meta, open_ai
from alr_envs.utils.make_env_helpers import make, make_bb, make_rank from alr_envs.utils.make_env_helpers import make, make_bb, make_rank
# Convenience function for all MP environments # Convenience function for all MP environments
from .alr import ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS from .envs import ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS
from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS from .dmc import ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS
from .meta import ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS from .meta import ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS
from .open_ai import ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS from .open_ai import ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS
ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = { ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {
key: value + ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] + key: value + ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key] +
ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS[key] + ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS[key] +
ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS[key] ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS[key]
for key, value in ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS.items()} for key, value in ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS.items()}

View File

@ -16,7 +16,7 @@ from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPER
from .mujoco.reacher.reacher import ReacherEnv from .mujoco.reacher.reacher import ReacherEnv
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []} ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": []}
DEFAULT_BB_DICT_ProMP = { DEFAULT_BB_DICT_ProMP = {
"name": 'EnvName', "name": 'EnvName',
@ -63,7 +63,7 @@ DEFAULT_BB_DICT_DMP = {
## Simple Reacher ## Simple Reacher
register( register(
id='SimpleReacher-v0', id='SimpleReacher-v0',
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv', entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 2, "n_links": 2,
@ -72,7 +72,7 @@ register(
register( register(
id='LongSimpleReacher-v0', id='LongSimpleReacher-v0',
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv', entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -83,7 +83,7 @@ register(
register( register(
id='ViaPointReacher-v0', id='ViaPointReacher-v0',
entry_point='alr_envs.alr.classic_control:ViaPointReacherEnv', entry_point='alr_envs.envs.classic_control:ViaPointReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -95,7 +95,7 @@ register(
## Hole Reacher ## Hole Reacher
register( register(
id='HoleReacher-v0', id='HoleReacher-v0',
entry_point='alr_envs.alr.classic_control:HoleReacherEnv', entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -115,7 +115,7 @@ register(
for _dims in [5, 7]: for _dims in [5, 7]:
register( register(
id=f'Reacher{_dims}d-v0', id=f'Reacher{_dims}d-v0',
entry_point='alr_envs.alr.mujoco:ReacherEnv', entry_point='alr_envs.envs.mujoco:ReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": _dims, "n_links": _dims,
@ -124,7 +124,7 @@ for _dims in [5, 7]:
register( register(
id=f'Reacher{_dims}dSparse-v0', id=f'Reacher{_dims}dSparse-v0',
entry_point='alr_envs.alr.mujoco:ReacherEnv', entry_point='alr_envs.envs.mujoco:ReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"sparse": True, "sparse": True,
@ -134,7 +134,7 @@ for _dims in [5, 7]:
register( register(
id='HopperJumpSparse-v0', id='HopperJumpSparse-v0',
entry_point='alr_envs.alr.mujoco:HopperJumpEnv', entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
kwargs={ kwargs={
"sparse": True, "sparse": True,
@ -143,7 +143,7 @@ register(
register( register(
id='HopperJump-v0', id='HopperJump-v0',
entry_point='alr_envs.alr.mujoco:HopperJumpEnv', entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
kwargs={ kwargs={
"sparse": False, "sparse": False,
@ -155,43 +155,43 @@ register(
register( register(
id='ALRAntJump-v0', id='ALRAntJump-v0',
entry_point='alr_envs.alr.mujoco:AntJumpEnv', entry_point='alr_envs.envs.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP, max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
) )
register( register(
id='ALRHalfCheetahJump-v0', id='ALRHalfCheetahJump-v0',
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv', entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP, max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
) )
register( register(
id='HopperJumpOnBox-v0', id='HopperJumpOnBox-v0',
entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv', entry_point='alr_envs.envs.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
) )
register( register(
id='ALRHopperThrow-v0', id='ALRHopperThrow-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv', entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
) )
register( register(
id='ALRHopperThrowInBasket-v0', id='ALRHopperThrowInBasket-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv', entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
) )
register( register(
id='ALRWalker2DJump-v0', id='ALRWalker2DJump-v0',
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv', entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP, max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
) )
register( register(
id='BeerPong-v0', id='BeerPong-v0',
entry_point='alr_envs.alr.mujoco:BeerPongEnv', entry_point='alr_envs.envs.mujoco:BeerPongEnv',
max_episode_steps=300, max_episode_steps=300,
) )
@ -199,14 +199,14 @@ register(
# only one time step, i.e. we simulate until the end of the episode # only one time step, i.e. we simulate until the end of the episode
register( register(
id='BeerPongStepBased-v0', id='BeerPongStepBased-v0',
entry_point='alr_envs.alr.mujoco:BeerPongEnvStepBasedEpisodicReward', entry_point='alr_envs.envs.mujoco:BeerPongEnvStepBasedEpisodicReward',
max_episode_steps=300, max_episode_steps=300,
) )
# Beerpong with episodic reward, but fixed release time step # Beerpong with episodic reward, but fixed release time step
register( register(
id='BeerPongFixedRelease-v0', id='BeerPongFixedRelease-v0',
entry_point='alr_envs.alr.mujoco:BeerPongEnvFixedReleaseStep', entry_point='alr_envs.envs.mujoco:BeerPongEnvFixedReleaseStep',
max_episode_steps=300, max_episode_steps=300,
) )
@ -229,7 +229,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_simple_reacher_dmp kwargs=kwargs_dict_simple_reacher_dmp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_simple_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
@ -242,7 +242,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_simple_reacher_promp kwargs=kwargs_dict_simple_reacher_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# Viapoint reacher # Viapoint reacher
kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP) kwargs_dict_via_point_reacher_dmp = deepcopy(DEFAULT_BB_DICT_DMP)
@ -257,7 +257,7 @@ register(
# max_episode_steps=1, # max_episode_steps=1,
kwargs=kwargs_dict_via_point_reacher_dmp kwargs=kwargs_dict_via_point_reacher_dmp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0") ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append("ViaPointReacherDMP-v0")
kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_via_point_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper) kwargs_dict_via_point_reacher_promp['wrappers'].append(classic_control.viapoint_reacher.MPWrapper)
@ -268,7 +268,7 @@ register(
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_via_point_reacher_promp kwargs=kwargs_dict_via_point_reacher_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0") ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append("ViaPointReacherProMP-v0")
## Hole Reacher ## Hole Reacher
_versions = ["HoleReacher-v0"] _versions = ["HoleReacher-v0"]
@ -288,7 +288,7 @@ for _v in _versions:
# max_episode_steps=1, # max_episode_steps=1,
kwargs=kwargs_dict_hole_reacher_dmp kwargs=kwargs_dict_hole_reacher_dmp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_hole_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
@ -301,7 +301,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_hole_reacher_promp kwargs=kwargs_dict_hole_reacher_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
## ReacherNd ## ReacherNd
_versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"] _versions = ["Reacher5d-v0", "Reacher7d-v0", "Reacher5dSparse-v0", "Reacher7dSparse-v0"]
@ -320,7 +320,7 @@ for _v in _versions:
# max_episode_steps=1, # max_episode_steps=1,
kwargs=kwargs_dict_reacherNd_dmp kwargs=kwargs_dict_reacherNd_dmp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)
_env_id = f'{_name[0]}ProMP-{_name[1]}' _env_id = f'{_name[0]}ProMP-{_name[1]}'
kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP) kwargs_dict_alr_reacher_promp = deepcopy(DEFAULT_BB_DICT_ProMP)
@ -333,7 +333,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_alr_reacher_promp kwargs=kwargs_dict_alr_reacher_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## ########################################################################################################################
## Beerpong ProMP ## Beerpong ProMP
@ -354,7 +354,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_bp_promp kwargs=kwargs_dict_bp_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### BP with Fixed release ### BP with Fixed release
_versions = ["BeerPongStepBased-v0", "BeerPongFixedRelease-v0"] _versions = ["BeerPongStepBased-v0", "BeerPongFixedRelease-v0"]
@ -374,7 +374,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_bp_promp kwargs=kwargs_dict_bp_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
######################################################################################################################## ########################################################################################################################
## Table Tennis needs to be fixed according to Zhou's implementation ## Table Tennis needs to be fixed according to Zhou's implementation
@ -395,7 +395,7 @@ for _v in _versions:
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', # entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_ant_jump_promp # kwargs=kwargs_dict_ant_jump_promp
# ) # )
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# #
# ######################################################################################################################## # ########################################################################################################################
# #
@ -412,7 +412,7 @@ for _v in _versions:
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', # entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_halfcheetah_jump_promp # kwargs=kwargs_dict_halfcheetah_jump_promp
# ) # )
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# #
# ######################################################################################################################## # ########################################################################################################################
@ -433,7 +433,7 @@ for _v in _versions:
entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
kwargs=kwargs_dict_hopper_jump_promp kwargs=kwargs_dict_hopper_jump_promp
) )
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
# ######################################################################################################################## # ########################################################################################################################
# #
@ -451,13 +451,13 @@ for _v in _versions:
# entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper', # entry_point='alr_envs.utils.make_env_helpers:make_bb_env_helper',
# kwargs=kwargs_dict_walker2d_jump_promp # kwargs=kwargs_dict_walker2d_jump_promp
# ) # )
# ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) # ALL_ALR_MOVEMENT_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
### Deprecated, we will not provide non-random starts anymore ### Deprecated, we will not provide non-random starts anymore
""" """
register( register(
id='SimpleReacher-v1', id='SimpleReacher-v1',
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv', entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 2, "n_links": 2,
@ -467,7 +467,7 @@ register(
register( register(
id='LongSimpleReacher-v1', id='LongSimpleReacher-v1',
entry_point='alr_envs.alr.classic_control:SimpleReacherEnv', entry_point='alr_envs.envs.classic_control:SimpleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -476,7 +476,7 @@ register(
) )
register( register(
id='HoleReacher-v1', id='HoleReacher-v1',
entry_point='alr_envs.alr.classic_control:HoleReacherEnv', entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -491,7 +491,7 @@ register(
) )
register( register(
id='HoleReacher-v2', id='HoleReacher-v2',
entry_point='alr_envs.alr.classic_control:HoleReacherEnv', entry_point='alr_envs.envs.classic_control:HoleReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"n_links": 5, "n_links": 5,
@ -508,7 +508,7 @@ register(
# CtxtFree are v0, Contextual are v1 # CtxtFree are v0, Contextual are v1
register( register(
id='ALRAntJump-v0', id='ALRAntJump-v0',
entry_point='alr_envs.alr.mujoco:AntJumpEnv', entry_point='alr_envs.envs.mujoco:AntJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP, max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP, "max_episode_steps": MAX_EPISODE_STEPS_ANTJUMP,
@ -518,7 +518,7 @@ register(
# CtxtFree are v0, Contextual are v1 # CtxtFree are v0, Contextual are v1
register( register(
id='ALRHalfCheetahJump-v0', id='ALRHalfCheetahJump-v0',
entry_point='alr_envs.alr.mujoco:ALRHalfCheetahJumpEnv', entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP, max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP, "max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
@ -527,7 +527,7 @@ register(
) )
register( register(
id='ALRHopperJump-v0', id='ALRHopperJump-v0',
entry_point='alr_envs.alr.mujoco:HopperJumpEnv', entry_point='alr_envs.envs.mujoco:HopperJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMP,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP, "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMP,
@ -545,7 +545,7 @@ for i in _vs:
_env_id = f'ALRReacher{i}-v0' _env_id = f'ALRReacher{i}-v0'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.alr.mujoco:ReacherEnv', entry_point='alr_envs.envs.mujoco:ReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"steps_before_reward": 0, "steps_before_reward": 0,
@ -558,7 +558,7 @@ for i in _vs:
_env_id = f'ALRReacherSparse{i}-v0' _env_id = f'ALRReacherSparse{i}-v0'
register( register(
id=_env_id, id=_env_id,
entry_point='alr_envs.alr.mujoco:ReacherEnv', entry_point='alr_envs.envs.mujoco:ReacherEnv',
max_episode_steps=200, max_episode_steps=200,
kwargs={ kwargs={
"steps_before_reward": 200, "steps_before_reward": 200,
@ -617,7 +617,7 @@ for i in _vs:
register( register(
id='ALRHopperJumpOnBox-v0', id='ALRHopperJumpOnBox-v0',
entry_point='alr_envs.alr.mujoco:HopperJumpOnBoxEnv', entry_point='alr_envs.envs.mujoco:HopperJumpOnBoxEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX, max_episode_steps=MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX, "max_episode_steps": MAX_EPISODE_STEPS_HOPPERJUMPONBOX,
@ -626,7 +626,7 @@ for i in _vs:
) )
register( register(
id='ALRHopperThrow-v0', id='ALRHopperThrow-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperThrowEnv', entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW, "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
@ -635,7 +635,7 @@ for i in _vs:
) )
register( register(
id='ALRHopperThrowInBasket-v0', id='ALRHopperThrowInBasket-v0',
entry_point='alr_envs.alr.mujoco:ALRHopperThrowInBasketEnv', entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET, max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET, "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
@ -644,7 +644,7 @@ for i in _vs:
) )
register( register(
id='ALRWalker2DJump-v0', id='ALRWalker2DJump-v0',
entry_point='alr_envs.alr.mujoco:ALRWalker2dJumpEnv', entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP, max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
kwargs={ kwargs={
"max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP, "max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,
@ -652,13 +652,13 @@ for i in _vs:
} }
) )
register(id='TableTennis2DCtxt-v1', register(id='TableTennis2DCtxt-v1',
entry_point='alr_envs.alr.mujoco:TTEnvGym', entry_point='alr_envs.envs.mujoco:TTEnvGym',
max_episode_steps=MAX_EPISODE_STEPS, max_episode_steps=MAX_EPISODE_STEPS,
kwargs={'ctxt_dim': 2, 'fixed_goal': True}) kwargs={'ctxt_dim': 2, 'fixed_goal': True})
register( register(
id='ALRBeerPong-v0', id='ALRBeerPong-v0',
entry_point='alr_envs.alr.mujoco:ALRBeerBongEnv', entry_point='alr_envs.envs.mujoco:ALRBeerBongEnv',
max_episode_steps=300, max_episode_steps=300,
kwargs={ kwargs={
"rndm_goal": False, "rndm_goal": False,

View File

@ -7,7 +7,7 @@ from gym import spaces
from gym.core import ObsType from gym.core import ObsType
from gym.utils import seeding from gym.utils import seeding
from alr_envs.alr.classic_control.utils import intersect from alr_envs.envs.classic_control.utils import intersect
class BaseReacherEnv(gym.Env, ABC): class BaseReacherEnv(gym.Env, ABC):

View File

@ -2,7 +2,7 @@ from abc import ABC
from gym import spaces from gym import spaces
import numpy as np import numpy as np
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv from alr_envs.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
class BaseReacherDirectEnv(BaseReacherEnv, ABC): class BaseReacherDirectEnv(BaseReacherEnv, ABC):

View File

@ -2,7 +2,7 @@ from abc import ABC
from gym import spaces from gym import spaces
import numpy as np import numpy as np
from alr_envs.alr.classic_control.base_reacher.base_reacher import BaseReacherEnv from alr_envs.envs.classic_control.base_reacher.base_reacher import BaseReacherEnv
class BaseReacherTorqueEnv(BaseReacherEnv, ABC): class BaseReacherTorqueEnv(BaseReacherEnv, ABC):

View File

@ -6,7 +6,7 @@ import numpy as np
from gym.core import ObsType from gym.core import ObsType
from matplotlib import patches from matplotlib import patches
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv from alr_envs.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
class HoleReacherEnv(BaseReacherDirectEnv): class HoleReacherEnv(BaseReacherDirectEnv):
@ -41,13 +41,13 @@ class HoleReacherEnv(BaseReacherDirectEnv):
self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape) self.observation_space = gym.spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape)
if rew_fct == "simple": if rew_fct == "simple":
from alr_envs.alr.classic_control.hole_reacher.hr_simple_reward import HolereacherReward from alr_envs.envs.classic_control.hole_reacher.hr_simple_reward import HolereacherReward
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty) self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
elif rew_fct == "vel_acc": elif rew_fct == "vel_acc":
from alr_envs.alr.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward from alr_envs.envs.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty) self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty)
elif rew_fct == "unbounded": elif rew_fct == "unbounded":
from alr_envs.alr.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward from alr_envs.envs.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward
self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision) self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision)
else: else:
raise ValueError("Unknown reward function {}".format(rew_fct)) raise ValueError("Unknown reward function {}".format(rew_fct))

View File

@ -5,7 +5,7 @@ import numpy as np
from gym import spaces from gym import spaces
from gym.core import ObsType from gym.core import ObsType
from alr_envs.alr.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv from alr_envs.envs.classic_control.base_reacher.base_reacher_torque import BaseReacherTorqueEnv
class SimpleReacherEnv(BaseReacherTorqueEnv): class SimpleReacherEnv(BaseReacherTorqueEnv):

View File

@ -6,7 +6,7 @@ import numpy as np
from gym.core import ObsType from gym.core import ObsType
from gym.utils import seeding from gym.utils import seeding
from alr_envs.alr.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv from alr_envs.envs.classic_control.base_reacher.base_reacher_direct import BaseReacherDirectEnv
class ViaPointReacherEnv(BaseReacherDirectEnv): class ViaPointReacherEnv(BaseReacherDirectEnv):

View File

@ -55,7 +55,7 @@ class AntJumpEnv(AntEnv):
costs = ctrl_cost + contact_cost costs = ctrl_cost + contact_cost
done = height < 0.3 # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle done = bool(height < 0.3) # fall over -> is the 0.3 value from healthy_z_range? TODO change 0.3 to the value of healthy z angle
if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or done: if self.current_step == MAX_EPISODE_STEPS_ANTJUMP or done:
# -10 for scaling the value of the distance between the max_height and the goal height; only used when context is enabled # -10 for scaling the value of the distance between the max_height and the goal height; only used when context is enabled
@ -84,8 +84,8 @@ class AntJumpEnv(AntEnv):
options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]: options: Optional[dict] = None, ) -> Union[ObsType, Tuple[ObsType, dict]]:
self.current_step = 0 self.current_step = 0
self.max_height = 0 self.max_height = 0
self.goal = self.np_random.uniform(1.0, 2.5, # goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE
1) # goal heights from 1.0 to 2.5; can be increased, but didnt work well with CMORE self.goal = self.np_random.uniform(1.0, 2.5, 1)
return super().reset() return super().reset()
# reset_model had to be implemented in every env to make it deterministic # reset_model had to be implemented in every env to make it deterministic

View File

@ -5,7 +5,7 @@ import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import MujocoEnv from gym.envs.mujoco import MujocoEnv
from alr_envs.alr.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward from alr_envs.envs.mujoco.beerpong.deprecated.beerpong_reward_staged import BeerPongReward
class BeerPongEnv(MujocoEnv, utils.EzPickle): class BeerPongEnv(MujocoEnv, utils.EzPickle):

Some files were not shown because too many files have changed in this diff Show More