fixed seeding and tests

Fabian 2022-07-12 15:43:46 +02:00
parent 0339361656
commit d64cb614fa
10 changed files with 36 additions and 34 deletions

View File

@@ -154,14 +154,14 @@ register(
 )
 register(
-    id='ALRAntJump-v0',
+    id='AntJump-v0',
     entry_point='alr_envs.envs.mujoco:AntJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
 )
 register(
-    id='ALRHalfCheetahJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
+    id='HalfCheetahJump-v0',
+    entry_point='alr_envs.envs.mujoco:HalfCheetahJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
 )
@@ -173,19 +173,19 @@ register(
 register(
     id='ALRHopperThrow-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
 )
 register(
     id='ALRHopperThrowInBasket-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowInBasketEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
 )
 register(
     id='ALRWalker2DJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
+    entry_point='alr_envs.envs.mujoco:Walker2dJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
 )
@@ -518,7 +518,7 @@ register(
 # CtxtFree are v0, Contextual are v1
 register(
     id='ALRHalfCheetahJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
+    entry_point='alr_envs.envs.mujoco:HalfCheetahJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
@@ -626,7 +626,7 @@ for i in _vs:
 )
 register(
     id='ALRHopperThrow-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
@@ -635,7 +635,7 @@ for i in _vs:
 )
 register(
     id='ALRHopperThrowInBasket-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowInBasketEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
@@ -644,7 +644,7 @@ for i in _vs:
 )
 register(
     id='ALRWalker2DJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
+    entry_point='alr_envs.envs.mujoco:Walker2dJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,

View File

@@ -1,9 +1,9 @@
 from .beerpong.beerpong import BeerPongEnv, BeerPongEnvFixedReleaseStep, BeerPongEnvStepBasedEpisodicReward
 from .ant_jump.ant_jump import AntJumpEnv
-from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv
+from .half_cheetah_jump.half_cheetah_jump import HalfCheetahJumpEnv
 from .hopper_jump.hopper_jump_on_box import HopperJumpOnBoxEnv
-from .hopper_throw.hopper_throw import ALRHopperThrowEnv
-from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv
+from .hopper_throw.hopper_throw import HopperThrowEnv
+from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
 from .reacher.reacher import ReacherEnv
-from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv
+from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
 from .hopper_jump.hopper_jump import HopperJumpEnv
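
The renamed classes stay importable from `alr_envs.envs.mujoco`. A direct-instantiation sketch (assuming the package layout above; note this bypasses the registry, so the `max_episode_steps` from `register()` is not applied):

    from alr_envs.envs.mujoco import HalfCheetahJumpEnv

    env = HalfCheetahJumpEnv()  # plain gym.Env subclass, no TimeLimit wrapper
    obs = env.reset()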

View File

@@ -155,7 +155,7 @@ class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
         return ob, reward, done, infos
 
-# class ALRBeerBongEnvStepBased(ALRBeerBongEnv):
+# class BeerBongEnvStepBased(ALRBeerBongEnv):
 #     def __init__(self, frame_skip=1, apply_gravity_comp=True, noisy=False, rndm_goal=False, cup_goal_pos=None):
 #         super().__init__(frame_skip, apply_gravity_comp, noisy, rndm_goal, cup_goal_pos)
 #         self.release_step = 62  # empirically evaluated for frame_skip=2!

View File

@@ -8,7 +8,7 @@ import numpy as np
 
 MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100
 
-class ALRHalfCheetahJumpEnv(HalfCheetahEnv):
+class HalfCheetahJumpEnv(HalfCheetahEnv):
     """
     ctrl_cost_weight 0.1 -> 0.0
     """

View File

@@ -7,7 +7,7 @@ import numpy as np
 
 MAX_EPISODE_STEPS_HOPPERTHROW = 250
 
-class ALRHopperThrowEnv(HopperEnv):
+class HopperThrowEnv(HopperEnv):
     """
     Initialization changes to normal Hopper:
     - healthy_reward: 1.0 -> 0.0 -> 0.1
@@ -98,7 +98,7 @@ class ALRHopperThrowEnv(HopperEnv):
 
 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRHopperThrowEnv()
+    env = HopperThrowEnv()
     obs = env.reset()
     for i in range(2000):

View File

@@ -8,7 +8,7 @@ import numpy as np
 
 MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250
 
-class ALRHopperThrowInBasketEnv(HopperEnv):
+class HopperThrowInBasketEnv(HopperEnv):
     """
     Initialization changes to normal Hopper:
     - healthy_reward: 1.0 -> 0.0
@@ -130,7 +130,7 @@ class ALRHopperThrowInBasketEnv(HopperEnv):
 
 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRHopperThrowInBasketEnv()
+    env = HopperThrowInBasketEnv()
     obs = env.reset()
     for i in range(2000):

View File

@@ -12,7 +12,7 @@ MAX_EPISODE_STEPS_WALKERJUMP = 300
 # as possible, while landing at a specific target position
-class ALRWalker2dJumpEnv(Walker2dEnv):
+class Walker2dJumpEnv(Walker2dEnv):
     """
     healthy reward 1.0 -> 0.005 -> 0.0025 not from alex
     penalty 10 -> 0 not from alex
@@ -95,7 +95,7 @@ class ALRWalker2dJumpEnv(Walker2dEnv):
 
 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRWalker2dJumpEnv()
+    env = Walker2dJumpEnv()
     obs = env.reset()
     for i in range(6000):

View File

@@ -1,9 +1,10 @@
+from collections import OrderedDict
+
 import numpy as np
 from matplotlib import pyplot as plt
 
-from alr_envs import dmc, meta
+from alr_envs import make_bb, dmc, meta
 from alr_envs.envs import mujoco
-from alr_envs.utils.make_env_helpers import make_promp_env
 
 def visualize(env):
@@ -16,11 +17,12 @@ def visualize(env):
 # This might work for some environments, however, please verify either way the correct trajectory information
 # for your environment are extracted below
 SEED = 1
-# env_id = "ball_in_cup-catch"
-env_id = "ALRReacherSparse-v0"
-env_id = "button-press-v2"
+# env_id = "dmc:ball_in_cup-catch"
+# wrappers = [dmc.suite.ball_in_cup.MPWrapper]
+env_id = "Reacher5dSparse-v0"
+wrappers = [mujoco.reacher.MPWrapper]
-wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper]
+# env_id = "metaworld:button-press-v2"
+# wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper]
 
 mp_kwargs = {
     "num_dof": 4,
@@ -38,7 +40,7 @@ mp_kwargs = {
 # kwargs = dict(time_limit=4, episode_length=200)
 kwargs = {}
 
-env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
+env = make_bb(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
 env.action_space.seed(SEED)
@@ -59,7 +61,7 @@ img = ax.imshow(env.env.render("rgb_array"))
 fig.show()
 
 for t, pos_vel in enumerate(zip(pos, vel)):
-    actions = env.policy.get_action(pos_vel[0], pos_vel[1],, self.current_vel, self.current_pos
+    actions = env.policy.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos)
     actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
     _, _, _, _ = env.env.step(actions)
     if t % 15 == 0:
@@ -81,7 +83,6 @@ p2 = plt.plot(pos, c='C1', label="MP")  # , label=["MP" if i == 0 else None for
 plt.xlabel("Episode steps")
 # plt.legend()
 handles, labels = plt.gca().get_legend_handles_labels()
-from collections import OrderedDict
 by_label = OrderedDict(zip(labels, handles))
 plt.legend(by_label.values(), by_label.keys())
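
The `get_action` fix above replaces a line that did not even parse (doubled comma, unbalanced parenthesis, stray `self` at module level) with the wrapped environment's own state. These MP environments typically track the desired trajectory with a PD controller; a sketch of that pattern (the signature and gains are assumptions for illustration, not this repo's exact API):

    import numpy as np

    def pd_action(des_pos, des_vel, cur_pos, cur_vel, p_gains, d_gains):
        # proportional term on the position error, derivative term on the velocity error
        return p_gains * (des_pos - cur_pos) + d_gains * (des_vel - cur_vel)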

View File

@@ -7,7 +7,7 @@ import alr_envs  # noqa
 from alr_envs.utils.make_env_helpers import make
 
 CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
-              "alr_envs" in spec.entry_point and not 'make_bb_env_helper' in spec.entry_point]
+              "alr_envs" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 
 SEED = 1

View File

@@ -6,7 +6,8 @@ import numpy as np
 import alr_envs
 from alr_envs import make
 
+METAWORLD_IDS = []
 GYM_IDS = [spec.id for spec in gym.envs.registry.all() if
            "alr_envs" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 
 SEED = 1
@@ -58,7 +59,7 @@ class TestGymEnvironments(unittest.TestCase):
                 if done:
                     break
 
-            assert done, "Done flag is not True after end of episode."
+            assert done or env.spec.max_episode_steps is None, "Done flag is not True after end of episode."
             observations.append(obs)
         env.close()
         del env
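
The relaxed assertion reflects how the registry works: `gym.make` only wraps an environment in `gym.wrappers.TimeLimit` when its spec defines `max_episode_steps`, so specs without a limit may legitimately never report `done=True`. The guard in isolation (a sketch, not the repo's helper):

    import gym

    def episode_must_terminate(env: gym.Env) -> bool:
        # Only TimeLimit-wrapped envs are guaranteed to set done=True eventually;
        # gym.make applies that wrapper only when the spec sets max_episode_steps.
        return env.spec.max_episode_steps is not None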