updated to new API, so tests still failing
This commit is contained in:
parent
ec2063aa0b
commit
c53924d9fc
@ -55,7 +55,6 @@ class BaseReacherEnv(gym.Env):
|
|||||||
self.fig = None
|
self.fig = None
|
||||||
|
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self.seed()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dt(self) -> Union[float, int]:
|
def dt(self) -> Union[float, int]:
|
||||||
@ -72,6 +71,7 @@ class BaseReacherEnv(gym.Env):
|
|||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||||
-> Tuple[ObsType, Dict[str, Any]]:
|
-> Tuple[ObsType, Dict[str, Any]]:
|
||||||
# Sample only orientation of first link, i.e. the arm is always straight.
|
# Sample only orientation of first link, i.e. the arm is always straight.
|
||||||
|
super(BaseReacherEnv, self).reset(seed=seed, options=options)
|
||||||
try:
|
try:
|
||||||
random_start = options.get('random_start', self.random_start)
|
random_start = options.get('random_start', self.random_start)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@ -128,10 +128,6 @@ class BaseReacherEnv(gym.Env):
|
|||||||
def _terminate(self, info) -> bool:
|
def _terminate(self, info) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def seed(self, seed=None):
|
|
||||||
self.np_random, seed = seeding.np_random(seed)
|
|
||||||
return [seed]
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
super(BaseReacherEnv, self).close()
|
super(BaseReacherEnv, self).close()
|
||||||
del self.fig
|
del self.fig
|
||||||
|
@ -57,11 +57,16 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
|
|
||||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||||
-> Tuple[ObsType, Dict[str, Any]]:
|
-> Tuple[ObsType, Dict[str, Any]]:
|
||||||
|
|
||||||
|
# initialize seed here as the random goal needs to be generated before the super reset()
|
||||||
|
gym.Env.reset(self, seed=seed, options=options)
|
||||||
|
|
||||||
self._generate_hole()
|
self._generate_hole()
|
||||||
self._set_patches()
|
self._set_patches()
|
||||||
self.reward_function.reset()
|
self.reward_function.reset()
|
||||||
|
|
||||||
return super().reset()
|
# do not provide seed to avoid setting it twice
|
||||||
|
return super(HoleReacherEnv, self).reset(options=options)
|
||||||
|
|
||||||
def _get_reward(self, action: np.ndarray) -> (float, dict):
|
def _get_reward(self, action: np.ndarray) -> (float, dict):
|
||||||
return self.reward_function.get_reward(self)
|
return self.reward_function.get_reward(self)
|
||||||
@ -224,6 +229,3 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
|||||||
self.fig.gca().add_patch(left_block)
|
self.fig.gca().add_patch(left_block)
|
||||||
self.fig.gca().add_patch(right_block)
|
self.fig.gca().add_patch(right_block)
|
||||||
self.fig.gca().add_patch(hole_floor)
|
self.fig.gca().add_patch(hole_floor)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
11
fancy_gym/utils/env_compatibility.py
Normal file
11
fancy_gym/utils/env_compatibility.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
import gymnasium as gym
|
||||||
|
|
||||||
|
|
||||||
|
class EnvCompatibility(gym.wrappers.EnvCompatibility):
|
||||||
|
def __getattr__(self, item):
|
||||||
|
"""Propagate only non-existent properties to wrapped env."""
|
||||||
|
if item.startswith('_'):
|
||||||
|
raise AttributeError("attempted to get missing private attribute '{}'".format(item))
|
||||||
|
if item in self.__dict__:
|
||||||
|
return getattr(self, item)
|
||||||
|
return getattr(self.env, item)
|
@ -3,12 +3,14 @@ import uuid
|
|||||||
from collections.abc import MutableMapping
|
from collections.abc import MutableMapping
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from typing import Iterable, Type, Union
|
from typing import Iterable, Type, Union, Optional
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.registration import register, registry
|
from gymnasium.envs.registration import register, registry
|
||||||
|
|
||||||
|
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dm_control import suite, manipulation
|
from dm_control import suite, manipulation
|
||||||
from shimmy.dm_control_compatibility import EnvType
|
from shimmy.dm_control_compatibility import EnvType
|
||||||
@ -186,9 +188,9 @@ def make_bb(
|
|||||||
|
|
||||||
def get_env_duration(env: gym.Env):
|
def get_env_duration(env: gym.Env):
|
||||||
try:
|
try:
|
||||||
# TODO Remove if this is in the compatibility class
|
|
||||||
duration = env.spec.max_episode_steps * env.dt
|
duration = env.spec.max_episode_steps * env.dt
|
||||||
except (AttributeError, TypeError) as e:
|
except (AttributeError, TypeError) as e:
|
||||||
|
# TODO Remove if this information is in the compatibility class
|
||||||
logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. '
|
logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. '
|
||||||
f'Assuming you are using dm_control. Please make sure you have ran '
|
f'Assuming you are using dm_control. Please make sure you have ran '
|
||||||
f'"pip install shimmy[dm_control]" for that.')
|
f'"pip install shimmy[dm_control]" for that.')
|
||||||
@ -300,7 +302,7 @@ def make_bb_env_helper(**kwargs):
|
|||||||
# return env
|
# return env
|
||||||
|
|
||||||
|
|
||||||
def make_metaworld(env_id: str, seed: int, **kwargs):
|
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
|
||||||
if env_id not in metaworld.ML1.ENV_NAMES:
|
if env_id not in metaworld.ML1.ENV_NAMES:
|
||||||
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
||||||
|
|
||||||
@ -314,7 +316,7 @@ def make_metaworld(env_id: str, seed: int, **kwargs):
|
|||||||
max_episode_steps = _env.max_path_length
|
max_episode_steps = _env.max_path_length
|
||||||
|
|
||||||
# TODO remove this as soon as there is support for the new API
|
# TODO remove this as soon as there is support for the new API
|
||||||
_env = gym.wrappers.EnvCompatibility(_env)
|
_env = EnvCompatibility(_env, render_mode)
|
||||||
|
|
||||||
gym_id = uuid.uuid4().hex + '-v1'
|
gym_id = uuid.uuid4().hex + '-v1'
|
||||||
|
|
||||||
|
@ -14,20 +14,20 @@ DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if
|
|||||||
spec.id.startswith('dm_control/')
|
spec.id.startswith('dm_control/')
|
||||||
and 'compatibility-env-v0' not in spec.id
|
and 'compatibility-env-v0' not in spec.id
|
||||||
and 'lqr-lqr' not in spec.id]
|
and 'lqr-lqr' not in spec.id]
|
||||||
DM_control_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
DM_control_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
|
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
|
||||||
def test_step_dm_control_functionality(env_id: str):
|
def test_step_dm_control_functionality(env_id: str):
|
||||||
"""Tests that suite step environments run without errors using random actions."""
|
"""Tests that suite step environments run without errors using random actions."""
|
||||||
run_env(env_id, 1000)
|
run_env(env_id, 5000, wrappers=[gym.wrappers.FlattenObservation])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
|
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
|
||||||
def test_step_dm_control_determinism(env_id: str):
|
def test_step_dm_control_determinism(env_id: str):
|
||||||
"""Tests that for step environments identical seeds produce identical trajectories."""
|
"""Tests that for step environments identical seeds produce identical trajectories."""
|
||||||
run_env_determinism(env_id, SEED, 1000)
|
run_env_determinism(env_id, SEED, 5000, wrappers=[gym.wrappers.FlattenObservation])
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
|
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
import itertools
|
from itertools import chain
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import fancy_gym
|
import fancy_gym
|
||||||
@ -10,7 +10,7 @@ from test.utils import run_env, run_env_determinism
|
|||||||
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
|
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
|
||||||
not isinstance(spec.entry_point, Callable) and
|
not isinstance(spec.entry_point, Callable) and
|
||||||
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
||||||
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
CUSTOM_MP_IDS = list(chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
|
import re
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import pytest
|
import pytest
|
||||||
@ -7,8 +9,12 @@ import fancy_gym
|
|||||||
from test.utils import run_env, run_env_determinism
|
from test.utils import run_env, run_env_determinism
|
||||||
|
|
||||||
GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
|
GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
|
||||||
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
not isinstance(spec.entry_point, Callable) and
|
||||||
GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point
|
||||||
|
and 'jax' not in spec.id.lower()
|
||||||
|
and not re.match(r'GymV2.Environment', spec.id)
|
||||||
|
]
|
||||||
|
GYM_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,8 +8,7 @@ from test.utils import run_env, run_env_determinism
|
|||||||
|
|
||||||
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
||||||
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
||||||
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
METAWORLD_MP_IDS = list(chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||||
print(METAWORLD_MP_IDS)
|
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
from typing import List, Type
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fancy_gym import make
|
from fancy_gym import make
|
||||||
|
|
||||||
|
|
||||||
def run_env(env_id, iterations=None, seed=0, render=False):
|
def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[Type[gym.Wrapper]] = [],
|
||||||
|
render: bool = False):
|
||||||
"""
|
"""
|
||||||
Example for running a DMC based env in the step based setting.
|
Example for running a DMC based env in the step based setting.
|
||||||
The env_id has to be specified as `dmc:domain_name-task_name` or
|
The env_id has to be specified as `dmc:domain_name-task_name` or
|
||||||
@ -13,12 +16,15 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
|||||||
env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name`
|
env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name`
|
||||||
iterations: Number of rollout steps to run
|
iterations: Number of rollout steps to run
|
||||||
seed: random seeding
|
seed: random seeding
|
||||||
|
wrappers: List of Wrappers to apply to the environment
|
||||||
render: Render the episode
|
render: Render the episode
|
||||||
|
|
||||||
Returns: observations, rewards, terminations, truncations, actions
|
Returns: observations, rewards, terminations, truncations, actions
|
||||||
|
|
||||||
"""
|
"""
|
||||||
env: gym.Env = make(env_id, seed=seed)
|
env: gym.Env = make(env_id, seed=seed)
|
||||||
|
for w in wrappers:
|
||||||
|
env = w(env)
|
||||||
rewards = []
|
rewards = []
|
||||||
observations = []
|
observations = []
|
||||||
actions = []
|
actions = []
|
||||||
@ -60,13 +66,13 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
|||||||
return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)
|
return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)
|
||||||
|
|
||||||
|
|
||||||
def run_env_determinism(env_id: str, seed: int, iterations: int = None):
|
def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers: List[Type[gym.Wrapper]] = []):
|
||||||
traj1 = run_env(env_id, iterations=iterations, seed=seed)
|
traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
|
||||||
traj2 = run_env(env_id, iterations=iterations, seed=seed)
|
traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
|
||||||
# Iterate over two trajectories, which should have the same state and action sequence
|
# Iterate over two trajectories, which should have the same state and action sequence
|
||||||
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
||||||
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
|
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
|
||||||
assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
|
assert np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
|
||||||
assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
|
assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
|
||||||
assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
|
assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
|
||||||
assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
|
assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
|
||||||
|
Loading…
Reference in New Issue
Block a user