Merge branch 'master' into readthedocs
This commit is contained in:
commit
b53917fcce
26
.github/workflows/ensure-release-tagged.yaml
vendored
Normal file
26
.github/workflows/ensure-release-tagged.yaml
vendored
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
name: Ensure Tagged Commits on Release
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- release
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check_tag:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Check out code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Check if base commit of PR is tagged
|
||||||
|
run: |
|
||||||
|
BASE_COMMIT=$(jq -r .pull_request.base.sha < "$GITHUB_EVENT_PATH")
|
||||||
|
TAG=$(git tag --contains $BASE_COMMIT)
|
||||||
|
if [ -z "$TAG" ]; then
|
||||||
|
echo "Base commit of PR is not tagged. PRs onto release must be tagged with the version number."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Base commit of PR is tagged. Check passed."
|
||||||
|
|
16
.github/workflows/publish-to-pypi.yml
vendored
16
.github/workflows/publish-to-pypi.yml
vendored
@ -2,29 +2,27 @@ name: Publish Python package to PyPI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
branches:
|
||||||
- '*'
|
- release
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
publish:
|
publish:
|
||||||
name: Publish to PyPI
|
name: Publish to PyPI
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
if: false && startsWith(github.ref, 'refs/tags/') # Only run on tagged commits
|
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Check out code
|
- name: Check out code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # This fetches all history for all branches and tags
|
fetch-depth: 0 # This fetches all history for all branches and tags
|
||||||
|
|
||||||
- name: Verify tag is on master branch
|
- name: Check if commit is tagged
|
||||||
run: |
|
run: |
|
||||||
TAG_IS_ON_MASTER=$(git branch -r --contains ${{ github.ref }} | grep 'origin/master')
|
TAG=$(git tag --contains HEAD)
|
||||||
if [ -z "$TAG_IS_ON_MASTER" ]; then
|
if [ -z "$TAG" ]; then
|
||||||
echo "Tag is not on the master branch. Cancelling the workflow."
|
echo "Commit is not tagged. Failing the workflow."
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
echo "Tag is on the master branch. Proceeding with the workflow."
|
echo "Commit is tagged. Proceeding with the workflow."
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
|
52
.github/workflows/publish-to-test-pypi.yml
vendored
52
.github/workflows/publish-to-test-pypi.yml
vendored
@ -1,52 +0,0 @@
|
|||||||
name: Publish Python package to TestPyPI
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- '*'
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
publish:
|
|
||||||
name: Publish to TestPyPI
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: startsWith(github.ref, 'refs/tags/') # Only run on tagged commits
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Check out code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0 # This fetches all history for all branches and tags
|
|
||||||
|
|
||||||
- name: Verify tag is on master branch
|
|
||||||
run: |
|
|
||||||
TAG_IS_ON_MASTER=$(git branch -r --contains ${{ github.ref }} | grep 'origin/master')
|
|
||||||
if [ -z "$TAG_IS_ON_MASTER" ]; then
|
|
||||||
echo "Tag is not on the master branch. Cancelling the workflow."
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
echo "Tag is on the master branch. Proceeding with the workflow."
|
|
||||||
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v4
|
|
||||||
with:
|
|
||||||
python-version: "3.x"
|
|
||||||
|
|
||||||
- name: Install pypa/build/setuptools/twine
|
|
||||||
run: >-
|
|
||||||
python3 -m
|
|
||||||
pip install
|
|
||||||
build setuptools twine
|
|
||||||
--user
|
|
||||||
|
|
||||||
- name: Prevent fallback onto setup.py
|
|
||||||
run: rm setup.py
|
|
||||||
|
|
||||||
- name: Build a binary wheel and a source tarball
|
|
||||||
run: python3 -m build
|
|
||||||
|
|
||||||
- name: Publish to TestPyPI
|
|
||||||
env:
|
|
||||||
TWINE_USERNAME: __token__
|
|
||||||
TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }}
|
|
||||||
run: twine upload --repository-url https://test.pypi.org/legacy/ dist/*
|
|
||||||
|
|
@ -0,0 +1 @@
|
|||||||
|
# TODO
|
@ -6,21 +6,21 @@ def example_run_replanning_env(env_name="fancy_ProDMP/BoxPushingDenseReplan-v0",
|
|||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
env.reset(seed=seed)
|
env.reset(seed=seed)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
done = False
|
while True:
|
||||||
while done is False:
|
|
||||||
ac = env.action_space.sample()
|
ac = env.action_space.sample()
|
||||||
obs, reward, terminated, truncated, info = env.step(ac)
|
obs, reward, terminated, truncated, info = env.step(ac)
|
||||||
if render:
|
if render:
|
||||||
env.render(mode="human")
|
env.render(mode="human")
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
env.reset()
|
env.reset()
|
||||||
|
break
|
||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
|
|
||||||
|
|
||||||
def example_custom_replanning_envs(seed=0, iteration=100, render=True):
|
def example_custom_replanning_envs(seed=0, iteration=100, render=True):
|
||||||
# id for a step-based environment
|
# id for a step-based environment
|
||||||
base_env_id = "BoxPushingDense-v0"
|
base_env_id = "fancy/BoxPushingDense-v0"
|
||||||
|
|
||||||
wrappers = [fancy_gym.envs.mujoco.box_pushing.mp_wrapper.MPWrapper]
|
wrappers = [fancy_gym.envs.mujoco.box_pushing.mp_wrapper.MPWrapper]
|
||||||
|
|
||||||
@ -38,7 +38,8 @@ def example_custom_replanning_envs(seed=0, iteration=100, render=True):
|
|||||||
'replanning_schedule': lambda pos, vel, obs, action, t: t % 25 == 0,
|
'replanning_schedule': lambda pos, vel, obs, action, t: t % 25 == 0,
|
||||||
'condition_on_desired': True}
|
'condition_on_desired': True}
|
||||||
|
|
||||||
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs=black_box_kwargs,
|
base_env = gym.make(base_env_id)
|
||||||
|
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs=black_box_kwargs,
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
@ -56,10 +57,12 @@ def example_custom_replanning_envs(seed=0, iteration=100, render=True):
|
|||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
|
|
||||||
|
def main(render=False):
|
||||||
if __name__ == "__main__":
|
|
||||||
# run a registered replanning environment
|
# run a registered replanning environment
|
||||||
example_run_replanning_env(env_name="fancy_ProDMP/BoxPushingDenseReplan-v0", seed=1, iterations=1, render=False)
|
example_run_replanning_env(env_name="fancy_ProDMP/BoxPushingDenseReplan-v0", seed=1, iterations=1, render=render)
|
||||||
|
|
||||||
# run a custom replanning environment
|
# run a custom replanning environment
|
||||||
example_custom_replanning_envs(seed=0, iteration=8, render=True)
|
example_custom_replanning_envs(seed=0, iteration=8, render=render)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -84,7 +84,8 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
# basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||||
# 'num_basis': 5
|
# 'num_basis': 5
|
||||||
# }
|
# }
|
||||||
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={},
|
base_env = gym.make(base_env_id)
|
||||||
|
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs={},
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
@ -114,21 +115,13 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
|
|||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
|
|
||||||
|
def main(render = True):
|
||||||
if __name__ == '__main__':
|
|
||||||
# Disclaimer: DMC environments require the seed to be specified in the beginning.
|
|
||||||
# Adjusting it afterwards with env.seed() is not recommended as it does not affect the underlying physics.
|
|
||||||
|
|
||||||
# For rendering DMC
|
|
||||||
# export MUJOCO_GL="osmesa"
|
|
||||||
render = True
|
|
||||||
|
|
||||||
# # Standard DMC Suite tasks
|
# # Standard DMC Suite tasks
|
||||||
example_dmc("dm_control/fish-swim", seed=10, iterations=1000, render=render)
|
example_dmc("dm_control/fish-swim", seed=10, iterations=1000, render=render)
|
||||||
#
|
#
|
||||||
# # Manipulation tasks
|
# # Manipulation tasks
|
||||||
# # Disclaimer: The vision versions are currently not integrated and yield an error
|
# # Disclaimer: The vision versions are currently not integrated and yield an error
|
||||||
example_dmc("dm_control/manipulation-reach_site_features", seed=10, iterations=250, render=render)
|
example_dmc("dm_control/reach_site_features", seed=10, iterations=250, render=render)
|
||||||
#
|
#
|
||||||
# # Gym + DMC hybrid task provided in the MP framework
|
# # Gym + DMC hybrid task provided in the MP framework
|
||||||
example_dmc("dm_control_ProMP/ball_in_cup-catch-v0", seed=10, iterations=1, render=render)
|
example_dmc("dm_control_ProMP/ball_in_cup-catch-v0", seed=10, iterations=1, render=render)
|
||||||
@ -136,3 +129,20 @@ if __name__ == '__main__':
|
|||||||
# Custom DMC task # Different seed, because the episode is longer for this example and the name+seed combo is
|
# Custom DMC task # Different seed, because the episode is longer for this example and the name+seed combo is
|
||||||
# already registered above
|
# already registered above
|
||||||
example_custom_dmc_and_mp(seed=11, iterations=1, render=render)
|
example_custom_dmc_and_mp(seed=11, iterations=1, render=render)
|
||||||
|
|
||||||
|
# # Standard DMC Suite tasks
|
||||||
|
example_dmc("dm_control/fish-swim", seed=10, iterations=1000, render=render)
|
||||||
|
#
|
||||||
|
# # Manipulation tasks
|
||||||
|
# # Disclaimer: The vision versions are currently not integrated and yield an error
|
||||||
|
example_dmc("dm_control/reach_site_features", seed=10, iterations=250, render=render)
|
||||||
|
#
|
||||||
|
# # Gym + DMC hybrid task provided in the MP framework
|
||||||
|
example_dmc("dm_control_ProMP/ball_in_cup-catch-v0", seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
|
# Custom DMC task # Different seed, because the episode is longer for this example and the name+seed combo is
|
||||||
|
# already registered above
|
||||||
|
example_custom_dmc_and_mp(seed=11, iterations=1, render=render)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -85,10 +85,7 @@ def example_async(env_id="fancy/HoleReacher-v0", n_cpu=4, seed=int('533D', 16),
|
|||||||
# do not return values above threshold
|
# do not return values above threshold
|
||||||
return *map(lambda v: np.stack(v)[:n_samples], buffer.values()),
|
return *map(lambda v: np.stack(v)[:n_samples], buffer.values()),
|
||||||
|
|
||||||
|
def main(render = True):
|
||||||
if __name__ == '__main__':
|
|
||||||
render = True
|
|
||||||
|
|
||||||
# Basic gym task
|
# Basic gym task
|
||||||
example_general("Pendulum-v1", seed=10, iterations=200, render=render)
|
example_general("Pendulum-v1", seed=10, iterations=200, render=render)
|
||||||
|
|
||||||
@ -100,3 +97,6 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# Vectorized multiprocessing environments
|
# Vectorized multiprocessing environments
|
||||||
# example_async(env_id="HoleReacher-v0", n_cpu=2, seed=int('533D', 16), n_samples=2 * 200)
|
# example_async(env_id="HoleReacher-v0", n_cpu=2, seed=int('533D', 16), n_samples=2 * 200)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -35,7 +35,7 @@ def example_meta(env_id="fish-swim", seed=1, iterations=1000, render=True):
|
|||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
print(env_id, rewards)
|
print(env_id, rewards)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset()
|
obs = env.reset(seed=seed+i+1)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
@ -81,7 +81,8 @@ def example_custom_meta_and_mp(seed=1, iterations=1, render=True):
|
|||||||
basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
basis_generator_kwargs = {'basis_generator_type': 'rbf',
|
||||||
'num_basis': 5
|
'num_basis': 5
|
||||||
}
|
}
|
||||||
env = fancy_gym.make_bb(env_id=base_env_id, wrappers=wrappers, black_box_kwargs={},
|
base_env = gym.make(base_env_id)
|
||||||
|
env = fancy_gym.make_bb(env=base_env, wrappers=wrappers, black_box_kwargs={},
|
||||||
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
traj_gen_kwargs=trajectory_generator_kwargs, controller_kwargs=controller_kwargs,
|
||||||
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
phase_kwargs=phase_generator_kwargs, basis_kwargs=basis_generator_kwargs,
|
||||||
seed=seed)
|
seed=seed)
|
||||||
@ -92,14 +93,10 @@ def example_custom_meta_and_mp(seed=1, iterations=1, render=True):
|
|||||||
# It is also possible to change them mode multiple times when
|
# It is also possible to change them mode multiple times when
|
||||||
# e.g. only every nth trajectory should be displayed.
|
# e.g. only every nth trajectory should be displayed.
|
||||||
if render:
|
if render:
|
||||||
raise ValueError("Metaworld render interface bug does not allow to render() fixes its interface. "
|
env.render(mode="human")
|
||||||
"A temporary workaround is to alter their code in MujocoEnv render() from "
|
|
||||||
"`if not offscreen` to `if not offscreen or offscreen == 'human'`.")
|
|
||||||
# TODO: Remove this, when Metaworld fixes its interface.
|
|
||||||
# env.render(mode="human")
|
|
||||||
|
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset()
|
obs = env.reset(seed=seed)
|
||||||
|
|
||||||
# number of samples/full trajectories (multiple environment steps)
|
# number of samples/full trajectories (multiple environment steps)
|
||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
@ -110,25 +107,23 @@ def example_custom_meta_and_mp(seed=1, iterations=1, render=True):
|
|||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
print(base_env_id, rewards)
|
print(base_env_id, rewards)
|
||||||
rewards = 0
|
rewards = 0
|
||||||
obs = env.reset()
|
obs = env.reset(seed=seed+i+1)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
del env
|
del env
|
||||||
|
|
||||||
|
def main(render = False):
|
||||||
if __name__ == '__main__':
|
|
||||||
# Disclaimer: MetaWorld environments require the seed to be specified in the beginning.
|
|
||||||
# Adjusting it afterwards with env.seed() is not recommended as it may not affect the underlying behavior.
|
|
||||||
|
|
||||||
# For rendering it might be necessary to specify your OpenGL installation
|
# For rendering it might be necessary to specify your OpenGL installation
|
||||||
# export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so
|
# export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so
|
||||||
render = False
|
|
||||||
|
|
||||||
# # Standard Meta world tasks
|
# # Standard Meta world tasks
|
||||||
example_meta("metaworld/button-press-v2", seed=10, iterations=500, render=render)
|
example_meta("metaworld/button-press-v2", seed=10, iterations=500, render=render)
|
||||||
|
|
||||||
# # MP + MetaWorld hybrid task provided in the our framework
|
# # MP + MetaWorld hybrid task provided in the our framework
|
||||||
example_meta("metaworld_ProMP/ButtonPress-v2", seed=10, iterations=1, render=render)
|
example_meta("metaworld_ProMP/button-press-v2", seed=10, iterations=1, render=render)
|
||||||
#
|
#
|
||||||
# # Custom MetaWorld task
|
# # Custom MetaWorld task
|
||||||
example_custom_meta_and_mp(seed=10, iterations=1, render=render)
|
example_custom_meta_and_mp(seed=10, iterations=1, render=render)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
@ -26,6 +26,8 @@ def example_mp(env_name="fancy_ProMP/HoleReacher-v0", seed=1, iterations=1, rend
|
|||||||
for i in range(iterations):
|
for i in range(iterations):
|
||||||
|
|
||||||
if render and i % 1 == 0:
|
if render and i % 1 == 0:
|
||||||
|
# This renders the full MP trajectory
|
||||||
|
# It is only required to call render() once in the beginning, which renders every consecutive trajectory.
|
||||||
env.render()
|
env.render()
|
||||||
|
|
||||||
# Now the action space is not the raw action but the parametrization of the trajectory generator,
|
# Now the action space is not the raw action but the parametrization of the trajectory generator,
|
||||||
@ -248,8 +250,7 @@ def example_fully_custom_mp_alternative(seed=1, iterations=1, render=True):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(render=False):
|
||||||
render = False
|
|
||||||
# DMP
|
# DMP
|
||||||
example_mp("fancy_DMP/HoleReacher-v0", seed=10, iterations=5, render=render)
|
example_mp("fancy_DMP/HoleReacher-v0", seed=10, iterations=5, render=render)
|
||||||
|
|
||||||
|
@ -31,6 +31,8 @@ def example_mp(env_name, seed=1, render=True):
|
|||||||
print(returns)
|
print(returns)
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
def main(render=True):
|
||||||
|
example_mp("gym_ProMP/Reacher-v2", render=render)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
example_mp("gym_ProMP/Reacher-v2")
|
main()
|
@ -51,28 +51,43 @@ class FixMetaworldIgnoresSeedOnResetWrapper(gym.Wrapper, gym.utils.RecordConstru
|
|||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
print('[!] You just called .reset on a Metaworld env and supplied a seed. Metaworld curretly does not correctly implement seeding. Do not rely on deterministic behavior.')
|
|
||||||
if 'seed' in kwargs:
|
if 'seed' in kwargs:
|
||||||
|
print('[!] You just called .reset on a Metaworld env and supplied a seed. Metaworld curretly does not correctly implement seeding. Do not rely on deterministic behavior.')
|
||||||
self.env.seed(kwargs['seed'])
|
self.env.seed(kwargs['seed'])
|
||||||
return self.env.reset(**kwargs)
|
return self.env.reset(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class FixMetaworldRenderOnStep(gym.Wrapper, gym.utils.RecordConstructorArgs):
|
||||||
|
def __init__(self, env: gym.Env):
|
||||||
|
gym.utils.RecordConstructorArgs.__init__(self)
|
||||||
|
gym.Wrapper.__init__(self, env)
|
||||||
|
self.render_active = False
|
||||||
|
|
||||||
|
def render(self, *args, **kwargs):
|
||||||
|
self.render_active = True
|
||||||
|
return self.env.render(*args, **kwargs)
|
||||||
|
|
||||||
|
def step(self, *args, **kwargs):
|
||||||
|
ret = self.env.step(*args, **kwargs)
|
||||||
|
if self.render_active:
|
||||||
|
self.env.render()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
def make_metaworld(underlying_id: str, seed: int = 1, render_mode: Optional[str] = None, **kwargs):
|
||||||
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
if underlying_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{underlying_id}" not present in metaworld ML1.')
|
||||||
|
|
||||||
env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, **kwargs)
|
env = metaworld.envs.ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE[underlying_id + "-goal-observable"](seed=seed, render_mode=render_mode, **kwargs)
|
||||||
|
|
||||||
# setting this avoids generating the same initialization after each reset
|
# setting this avoids generating the same initialization after each reset
|
||||||
env._freeze_rand_vec = False
|
env._freeze_rand_vec = False
|
||||||
# New argument to use global seeding
|
# New argument to use global seeding
|
||||||
env.seeded_rand_vec = True
|
env.seeded_rand_vec = True
|
||||||
|
|
||||||
# TODO remove, when this has been fixed upstream
|
|
||||||
env = FixMetaworldHasIncorrectObsSpaceWrapper(env)
|
env = FixMetaworldHasIncorrectObsSpaceWrapper(env)
|
||||||
# TODO remove, when this has been fixed upstream
|
|
||||||
# env = FixMetaworldIncorrectResetPathLengthWrapper(env)
|
# env = FixMetaworldIncorrectResetPathLengthWrapper(env)
|
||||||
# TODO remove, when this has been fixed upstream
|
env = FixMetaworldRenderOnStep(env)
|
||||||
env = FixMetaworldIgnoresSeedOnResetWrapper(env)
|
env = FixMetaworldIgnoresSeedOnResetWrapper(env)
|
||||||
return env
|
return env
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ dmc = ["shimmy[dm-control]", "Shimmy==1.0.0"]
|
|||||||
box2d = ["gymnasium[box2d]>=0.26.0"]
|
box2d = ["gymnasium[box2d]>=0.26.0"]
|
||||||
mujoco-legacy = ["mujoco-py>=2.1,<2.2", "cython<3"]
|
mujoco-legacy = ["mujoco-py>=2.1,<2.2", "cython<3"]
|
||||||
jax = ["jax>=0.4.0", "jaxlib>=0.4.0"]
|
jax = ["jax>=0.4.0", "jaxlib>=0.4.0"]
|
||||||
|
mushroom-rl = ["mushroom-rl"]
|
||||||
|
|
||||||
all = [
|
all = [
|
||||||
# include all the optional dependencies
|
# include all the optional dependencies
|
||||||
@ -61,7 +62,8 @@ all = [
|
|||||||
"mujoco-py>=2.1,<2.2",
|
"mujoco-py>=2.1,<2.2",
|
||||||
"cython<3",
|
"cython<3",
|
||||||
"jax>=0.4.0",
|
"jax>=0.4.0",
|
||||||
"jaxlib>=0.4.0"
|
"jaxlib>=0.4.0",
|
||||||
|
"mushroom-rl",
|
||||||
]
|
]
|
||||||
|
|
||||||
testing = [
|
testing = [
|
||||||
@ -75,5 +77,6 @@ testing = [
|
|||||||
"mujoco-py>=2.1,<2.2",
|
"mujoco-py>=2.1,<2.2",
|
||||||
"cython<3",
|
"cython<3",
|
||||||
"jax>=0.4.0",
|
"jax>=0.4.0",
|
||||||
"jaxlib>=0.4.0"
|
"jaxlib>=0.4.0",
|
||||||
|
"mushroom-rl",
|
||||||
]
|
]
|
||||||
|
13
test/test_examples.py
Normal file
13
test/test_examples.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from fancy_gym.examples.example_replanning_envs import main as replanning_envs_main
|
||||||
|
from fancy_gym.examples.examples_dmc import main as dmc_main
|
||||||
|
from fancy_gym.examples.examples_general import main as general_main
|
||||||
|
from fancy_gym.examples.examples_metaworld import main as metaworld_main
|
||||||
|
from fancy_gym.examples.examples_movement_primitives import main as mp_main
|
||||||
|
from fancy_gym.examples.examples_open_ai import main as open_ai_main
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('entry', [replanning_envs_main, dmc_main, general_main, metaworld_main, mp_main, open_ai_main])
|
||||||
|
@pytest.mark.parametrize('render', [False])
|
||||||
|
def test_run_example(entry, render):
|
||||||
|
entry(render=render)
|
Loading…
Reference in New Issue
Block a user