diff --git a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py index f0e0a3e..18305fd 100644 --- a/fancy_gym/envs/classic_control/base_reacher/base_reacher.py +++ b/fancy_gym/envs/classic_control/base_reacher/base_reacher.py @@ -55,7 +55,6 @@ class BaseReacherEnv(gym.Env): self.fig = None self._steps = 0 - self.seed() @property def dt(self) -> Union[float, int]: @@ -72,6 +71,7 @@ class BaseReacherEnv(gym.Env): def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ -> Tuple[ObsType, Dict[str, Any]]: # Sample only orientation of first link, i.e. the arm is always straight. + super(BaseReacherEnv, self).reset(seed=seed, options=options) try: random_start = options.get('random_start', self.random_start) except AttributeError: @@ -128,10 +128,6 @@ class BaseReacherEnv(gym.Env): def _terminate(self, info) -> bool: raise NotImplementedError - def seed(self, seed=None): - self.np_random, seed = seeding.np_random(seed) - return [seed] - def close(self): super(BaseReacherEnv, self).close() del self.fig diff --git a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py index c3e5020..0ed03f2 100644 --- a/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py +++ b/fancy_gym/envs/classic_control/hole_reacher/hole_reacher.py @@ -57,11 +57,16 @@ class HoleReacherEnv(BaseReacherDirectEnv): def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \ -> Tuple[ObsType, Dict[str, Any]]: + + # initialize seed here as the random goal needs to be generated before the super reset() + gym.Env.reset(self, seed=seed, options=options) + self._generate_hole() self._set_patches() self.reward_function.reset() - return super().reset() + # do not provide seed to avoid setting it twice + return super(HoleReacherEnv, self).reset(options=options) def _get_reward(self, action: np.ndarray) -> (float, dict): return self.reward_function.get_reward(self) @@ -224,6 +229,3 @@ class HoleReacherEnv(BaseReacherDirectEnv): self.fig.gca().add_patch(left_block) self.fig.gca().add_patch(right_block) self.fig.gca().add_patch(hole_floor) - - - diff --git a/fancy_gym/utils/env_compatibility.py b/fancy_gym/utils/env_compatibility.py new file mode 100644 index 0000000..a278451 --- /dev/null +++ b/fancy_gym/utils/env_compatibility.py @@ -0,0 +1,11 @@ +import gymnasium as gym + + +class EnvCompatibility(gym.wrappers.EnvCompatibility): + def __getattr__(self, item): + """Propagate only non-existent properties to wrapped env.""" + if item.startswith('_'): + raise AttributeError("attempted to get missing private attribute '{}'".format(item)) + if item in self.__dict__: + return getattr(self, item) + return getattr(self.env, item) diff --git a/fancy_gym/utils/make_env_helpers.py b/fancy_gym/utils/make_env_helpers.py index 50aa38f..eb7b49c 100644 --- a/fancy_gym/utils/make_env_helpers.py +++ b/fancy_gym/utils/make_env_helpers.py @@ -3,12 +3,14 @@ import uuid from collections.abc import MutableMapping from copy import deepcopy from math import ceil -from typing import Iterable, Type, Union +from typing import Iterable, Type, Union, Optional import gymnasium as gym import numpy as np from gymnasium.envs.registration import register, registry +from fancy_gym.utils.env_compatibility import EnvCompatibility + try: from dm_control import suite, manipulation from shimmy.dm_control_compatibility import EnvType @@ -186,9 +188,9 @@ def make_bb( def get_env_duration(env: gym.Env): try: - # TODO Remove if this is in the compatibility class duration = env.spec.max_episode_steps * env.dt except (AttributeError, TypeError) as e: + # TODO Remove if this information is in the compatibility class logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. ' f'Assuming you are using dm_control. Please make sure you have ran ' f'"pip install shimmy[dm_control]" for that.') @@ -300,7 +302,7 @@ def make_bb_env_helper(**kwargs): # return env -def make_metaworld(env_id: str, seed: int, **kwargs): +def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs): if env_id not in metaworld.ML1.ENV_NAMES: raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.') @@ -314,7 +316,7 @@ def make_metaworld(env_id: str, seed: int, **kwargs): max_episode_steps = _env.max_path_length # TODO remove this as soon as there is support for the new API - _env = gym.wrappers.EnvCompatibility(_env) + _env = EnvCompatibility(_env, render_mode) gym_id = uuid.uuid4().hex + '-v1' diff --git a/test/test_dmc_envs.py b/test/test_dmc_envs.py index 53119af..266a12f 100644 --- a/test/test_dmc_envs.py +++ b/test/test_dmc_envs.py @@ -14,20 +14,20 @@ DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if spec.id.startswith('dm_control/') and 'compatibility-env-v0' not in spec.id and 'lqr-lqr' not in spec.id] -DM_control_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) +DM_control_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())) SEED = 1 @pytest.mark.parametrize('env_id', DM_CONTROL_IDS) def test_step_dm_control_functionality(env_id: str): """Tests that suite step environments run without errors using random actions.""" - run_env(env_id, 1000) + run_env(env_id, 5000, wrappers=[gym.wrappers.FlattenObservation]) @pytest.mark.parametrize('env_id', DM_CONTROL_IDS) def test_step_dm_control_determinism(env_id: str): """Tests that for step environments identical seeds produce identical trajectories.""" - run_env_determinism(env_id, SEED, 1000) + run_env_determinism(env_id, SEED, 5000, wrappers=[gym.wrappers.FlattenObservation]) # @pytest.mark.parametrize('env_id', MANIPULATION_IDS) diff --git a/test/test_fancy_envs.py b/test/test_fancy_envs.py index 02208ce..898cc08 100644 --- a/test/test_fancy_envs.py +++ b/test/test_fancy_envs.py @@ -1,4 +1,4 @@ -import itertools +from itertools import chain from typing import Callable import fancy_gym @@ -10,7 +10,7 @@ from test.utils import run_env, run_env_determinism CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if not isinstance(spec.entry_point, Callable) and "fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point] -CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) +CUSTOM_MP_IDS = list(chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())) SEED = 1 diff --git a/test/test_gym_envs.py b/test/test_gym_envs.py index 20b089d..76b5c85 100644 --- a/test/test_gym_envs.py +++ b/test/test_gym_envs.py @@ -1,4 +1,6 @@ +import re from itertools import chain +from typing import Callable import gymnasium as gym import pytest @@ -7,8 +9,12 @@ import fancy_gym from test.utils import run_env, run_env_determinism GYM_IDS = [spec.id for spec in gym.envs.registry.values() if - "fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point] -GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) + not isinstance(spec.entry_point, Callable) and + "fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point + and 'jax' not in spec.id.lower() + and not re.match(r'GymV2.Environment', spec.id) + ] +GYM_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())) SEED = 1 diff --git a/test/test_metaworld_envs.py b/test/test_metaworld_envs.py index 768958d..55de621 100644 --- a/test/test_metaworld_envs.py +++ b/test/test_metaworld_envs.py @@ -8,8 +8,7 @@ from test.utils import run_env, run_env_determinism METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()] -METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) -print(METAWORLD_MP_IDS) +METAWORLD_MP_IDS = list(chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())) SEED = 1 diff --git a/test/utils.py b/test/utils.py index 56f739f..51f0c37 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,9 +1,12 @@ +from typing import List, Type + import gymnasium as gym import numpy as np from fancy_gym import make -def run_env(env_id, iterations=None, seed=0, render=False): +def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[Type[gym.Wrapper]] = [], + render: bool = False): """ Example for running a DMC based env in the step based setting. The env_id has to be specified as `dmc:domain_name-task_name` or @@ -13,12 +16,15 @@ def run_env(env_id, iterations=None, seed=0, render=False): env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name` iterations: Number of rollout steps to run seed: random seeding + wrappers: List of Wrappers to apply to the environment render: Render the episode Returns: observations, rewards, terminations, truncations, actions """ env: gym.Env = make(env_id, seed=seed) + for w in wrappers: + env = w(env) rewards = [] observations = [] actions = [] @@ -60,13 +66,13 @@ def run_env(env_id, iterations=None, seed=0, render=False): return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions) -def run_env_determinism(env_id: str, seed: int, iterations: int = None): - traj1 = run_env(env_id, iterations=iterations, seed=seed) - traj2 = run_env(env_id, iterations=iterations, seed=seed) +def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers: List[Type[gym.Wrapper]] = []): + traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers) + traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers) # Iterate over two trajectories, which should have the same state and action sequence for i, time_step in enumerate(zip(*traj1, *traj2)): obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step - assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match." + assert np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match." assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match." assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match." assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."