fixed seeding and tests
parent 0339361656
commit d64cb614fa
@@ -154,14 +154,14 @@ register(
 )

 register(
-    id='ALRAntJump-v0',
+    id='AntJump-v0',
     entry_point='alr_envs.envs.mujoco:AntJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_ANTJUMP,
 )

 register(
-    id='ALRHalfCheetahJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
+    id='HalfCheetahJump-v0',
+    entry_point='alr_envs.envs.mujoco:HalfCheetahJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
 )

@@ -173,19 +173,19 @@ register(

 register(
     id='ALRHopperThrow-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
 )

 register(
     id='ALRHopperThrowInBasket-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowInBasketEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
 )

 register(
     id='ALRWalker2DJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
+    entry_point='alr_envs.envs.mujoco:Walker2dJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
 )

@@ -518,7 +518,7 @@ register(
 # CtxtFree are v0, Contextual are v1
 register(
     id='ALRHalfCheetahJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHalfCheetahJumpEnv',
+    entry_point='alr_envs.envs.mujoco:HalfCheetahJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HALFCHEETAHJUMP,
@@ -626,7 +626,7 @@ for i in _vs:
 )
 register(
     id='ALRHopperThrow-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROW,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROW,
@@ -635,7 +635,7 @@ for i in _vs:
 )
 register(
     id='ALRHopperThrowInBasket-v0',
-    entry_point='alr_envs.envs.mujoco:ALRHopperThrowInBasketEnv',
+    entry_point='alr_envs.envs.mujoco:HopperThrowInBasketEnv',
     max_episode_steps=MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_HOPPERTHROWINBASKET,
@@ -644,7 +644,7 @@ for i in _vs:
 )
 register(
     id='ALRWalker2DJump-v0',
-    entry_point='alr_envs.envs.mujoco:ALRWalker2dJumpEnv',
+    entry_point='alr_envs.envs.mujoco:Walker2dJumpEnv',
     max_episode_steps=MAX_EPISODE_STEPS_WALKERJUMP,
     kwargs={
         "max_episode_steps": MAX_EPISODE_STEPS_WALKERJUMP,

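The step-based registrations above drop the ALR prefix from the environment IDs and point the entry points at the renamed classes. A minimal usage sketch under the new names, assuming the package is installed so that importing it runs the register() calls shown in this diff (gym's pre-0.26 API, which is what this codebase uses):

    import gym
    import alr_envs  # noqa: importing the package executes the register() calls above

    # 'ALRAntJump-v0' is now 'AntJump-v0'; 'HalfCheetahJump-v0' etc. follow the same pattern
    env = gym.make('AntJump-v0')
    obs = env.reset()
    done = False
    while not done:
        obs, reward, done, info = env.step(env.action_space.sample())
    env.close()
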
@@ -1,9 +1,9 @@
 from .beerpong.beerpong import BeerPongEnv, BeerPongEnvFixedReleaseStep, BeerPongEnvStepBasedEpisodicReward
 from .ant_jump.ant_jump import AntJumpEnv
-from .half_cheetah_jump.half_cheetah_jump import ALRHalfCheetahJumpEnv
+from .half_cheetah_jump.half_cheetah_jump import HalfCheetahJumpEnv
 from .hopper_jump.hopper_jump_on_box import HopperJumpOnBoxEnv
-from .hopper_throw.hopper_throw import ALRHopperThrowEnv
-from .hopper_throw.hopper_throw_in_basket import ALRHopperThrowInBasketEnv
+from .hopper_throw.hopper_throw import HopperThrowEnv
+from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
 from .reacher.reacher import ReacherEnv
-from .walker_2d_jump.walker_2d_jump import ALRWalker2dJumpEnv
+from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
 from .hopper_jump.hopper_jump import HopperJumpEnv

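Judging by the relative imports, this hunk is the mujoco package __init__; it now re-exports the renamed classes, so code that constructed the ALR-prefixed classes directly has to switch to the new names. A small sketch of direct construction, mirroring the __main__ blocks further down in this diff:

    from alr_envs.envs.mujoco import HalfCheetahJumpEnv, HopperThrowEnv, Walker2dJumpEnv

    env = HopperThrowEnv()
    obs = env.reset()
    for _ in range(10):
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
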
@@ -155,7 +155,7 @@ class BeerPongEnvStepBasedEpisodicReward(BeerPongEnv):
         return ob, reward, done, infos


-# class ALRBeerBongEnvStepBased(ALRBeerBongEnv):
+# class BeerBongEnvStepBased(ALRBeerBongEnv):
 #     def __init__(self, frame_skip=1, apply_gravity_comp=True, noisy=False, rndm_goal=False, cup_goal_pos=None):
 #         super().__init__(frame_skip, apply_gravity_comp, noisy, rndm_goal, cup_goal_pos)
 #         self.release_step = 62  # empirically evaluated for frame_skip=2!

@@ -8,7 +8,7 @@ import numpy as np
 MAX_EPISODE_STEPS_HALFCHEETAHJUMP = 100


-class ALRHalfCheetahJumpEnv(HalfCheetahEnv):
+class HalfCheetahJumpEnv(HalfCheetahEnv):
     """
     ctrl_cost_weight 0.1 -> 0.0
     """

@@ -7,7 +7,7 @@ import numpy as np
 MAX_EPISODE_STEPS_HOPPERTHROW = 250


-class ALRHopperThrowEnv(HopperEnv):
+class HopperThrowEnv(HopperEnv):
     """
     Initialization changes to normal Hopper:
     - healthy_reward: 1.0 -> 0.0 -> 0.1
@@ -98,7 +98,7 @@ class ALRHopperThrowEnv(HopperEnv):

 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRHopperThrowEnv()
+    env = HopperThrowEnv()
     obs = env.reset()

     for i in range(2000):

@@ -8,7 +8,7 @@ import numpy as np
 MAX_EPISODE_STEPS_HOPPERTHROWINBASKET = 250


-class ALRHopperThrowInBasketEnv(HopperEnv):
+class HopperThrowInBasketEnv(HopperEnv):
     """
     Initialization changes to normal Hopper:
     - healthy_reward: 1.0 -> 0.0
@@ -130,7 +130,7 @@ class ALRHopperThrowInBasketEnv(HopperEnv):

 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRHopperThrowInBasketEnv()
+    env = HopperThrowInBasketEnv()
     obs = env.reset()

     for i in range(2000):

@@ -12,7 +12,7 @@ MAX_EPISODE_STEPS_WALKERJUMP = 300
 # as possible, while landing at a specific target position


-class ALRWalker2dJumpEnv(Walker2dEnv):
+class Walker2dJumpEnv(Walker2dEnv):
     """
     healthy reward 1.0 -> 0.005 -> 0.0025 not from alex
     penalty 10 -> 0 not from alex
@@ -95,7 +95,7 @@ class ALRWalker2dJumpEnv(Walker2dEnv):

 if __name__ == '__main__':
     render_mode = "human"  # "human" or "partial" or "final"
-    env = ALRWalker2dJumpEnv()
+    env = Walker2dJumpEnv()
     obs = env.reset()

     for i in range(6000):

@@ -1,9 +1,10 @@
+from collections import OrderedDict
+
 import numpy as np
 from matplotlib import pyplot as plt

-from alr_envs import dmc, meta
+from alr_envs import make_bb, dmc, meta
 from alr_envs.envs import mujoco
-from alr_envs.utils.make_env_helpers import make_promp_env


 def visualize(env):
@@ -16,11 +17,12 @@ def visualize(env):
 # This might work for some environments, however, please verify either way the correct trajectory information
 # for your environment are extracted below
 SEED = 1
-# env_id = "ball_in_cup-catch"
-env_id = "ALRReacherSparse-v0"
-env_id = "button-press-v2"
+# env_id = "dmc:ball_in_cup-catch"
+# wrappers = [dmc.suite.ball_in_cup.MPWrapper]
+env_id = "Reacher5dSparse-v0"
 wrappers = [mujoco.reacher.MPWrapper]
-wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper]
+# env_id = "metaworld:button-press-v2"
+# wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper]

 mp_kwargs = {
     "num_dof": 4,
@@ -38,7 +40,7 @@ mp_kwargs = {
 # kwargs = dict(time_limit=4, episode_length=200)
 kwargs = {}

-env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
+env = make_bb(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
 env.action_space.seed(SEED)

 # Plot difference between real trajectory and target MP trajectory
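Given the commit title, the point of this hunk is that the environment builder now receives the seed while the action space is seeded separately; sampled actions draw from the action space's own RNG, so both seeds are needed for a reproducible rollout. A hedged sketch of the same pattern with a plain gym.make (the full make_bb signature is not shown in this diff), using the Reacher5dSparse-v0 ID from the script above and assuming it is registered as a step-based environment (gym's pre-0.26 seeding API):

    import gym
    import alr_envs  # noqa

    SEED = 1
    env = gym.make('Reacher5dSparse-v0')
    env.seed(SEED)               # seeds the environment's own RNG
    env.action_space.seed(SEED)  # sample() uses a separate RNG attached to the space

    obs = env.reset()
    action = env.action_space.sample()  # reproducible across runs once both seeds are set
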
@@ -59,7 +61,7 @@ img = ax.imshow(env.env.render("rgb_array"))
 fig.show()

 for t, pos_vel in enumerate(zip(pos, vel)):
-    actions = env.policy.get_action(pos_vel[0], pos_vel[1],, self.current_vel, self.current_pos
+    actions = env.policy.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos)
     actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
     _, _, _, _ = env.env.step(actions)
     if t % 15 == 0:
@@ -81,7 +83,6 @@ p2 = plt.plot(pos, c='C1', label="MP")  # , label=["MP" if i == 0 else None for
 plt.xlabel("Episode steps")
 # plt.legend()
 handles, labels = plt.gca().get_legend_handles_labels()
-from collections import OrderedDict

 by_label = OrderedDict(zip(labels, handles))
 plt.legend(by_label.values(), by_label.keys())

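This last hunk of the script only moves the OrderedDict import to the top of the file (see the first hunk of this file) instead of importing it mid-script; the surrounding lines use it to de-duplicate legend entries when several curves share a label. A self-contained sketch of that de-duplication trick:

    from collections import OrderedDict
    from matplotlib import pyplot as plt

    plt.plot([0, 1], label="MP")
    plt.plot([1, 0], label="MP")  # same label again
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))  # keeps one handle per unique label
    plt.legend(by_label.values(), by_label.keys())
    plt.show()
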
@@ -7,7 +7,7 @@ import alr_envs  # noqa
 from alr_envs.utils.make_env_helpers import make

 CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
-              "alr_envs" in spec.entry_point and not 'make_bb_env_helper' in spec.entry_point]
+              "alr_envs" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 SEED = 1

|
@ -6,7 +6,8 @@ import numpy as np
|
|||||||
import alr_envs
|
import alr_envs
|
||||||
from alr_envs import make
|
from alr_envs import make
|
||||||
|
|
||||||
METAWORLD_IDS = []
|
GYM_IDS = [spec.id for spec in gym.envs.registry.all() if
|
||||||
|
"alr_envs" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
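Together with the previous test hunk, the registry is now split into two complementary ID lists: CUSTOM_IDS for environments registered by this package and GYM_IDS for everything else, with the make_bb_env_helper entry point excluded from both. A sketch of the same filtering outside the tests, mirroring the list comprehensions above (gym's pre-0.26 registry API):

    import gym
    import alr_envs  # noqa: makes sure this package's registrations are present

    custom_ids = [spec.id for spec in gym.envs.registry.all() if
                  "alr_envs" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
    gym_ids = [spec.id for spec in gym.envs.registry.all() if
               "alr_envs" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]

    print(f"{len(custom_ids)} package IDs, {len(gym_ids)} other IDs")
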
@@ -58,7 +59,7 @@ class TestGymEnvironments(unittest.TestCase):
             if done:
                 break

-        assert done, "Done flag is not True after end of episode."
+        assert done or env.spec.max_episode_steps is None, "Done flag is not True after end of episode."
         observations.append(obs)
         env.close()
         del env
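The relaxed assertion accounts for IDs registered without max_episode_steps: gym only wraps an environment in a TimeLimit when a step limit is given, so such envs may legitimately still report done=False when the test's fixed iteration budget runs out. A small helper sketch of the check (the helper name and signature are illustrative, not from the repo):

    def rollout_terminates(env, iterations, seed=1):
        """Roll out with random actions; envs without a step limit are allowed not to finish."""
        env.seed(seed)
        env.action_space.seed(seed)
        env.reset()
        done = False
        for _ in range(iterations):
            _, _, done, _ = env.step(env.action_space.sample())
            if done:
                break
        return done or env.spec.max_episode_steps is None
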