updated to new API, so tests still failing
This commit is contained in:
		
							parent
							
								
									ec2063aa0b
								
							
						
					
					
						commit
						c53924d9fc
					
				@ -55,7 +55,6 @@ class BaseReacherEnv(gym.Env):
 | 
			
		||||
        self.fig = None
 | 
			
		||||
 | 
			
		||||
        self._steps = 0
 | 
			
		||||
        self.seed()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def dt(self) -> Union[float, int]:
 | 
			
		||||
@ -72,6 +71,7 @@ class BaseReacherEnv(gym.Env):
 | 
			
		||||
    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
 | 
			
		||||
            -> Tuple[ObsType, Dict[str, Any]]:
 | 
			
		||||
        # Sample only orientation of first link, i.e. the arm is always straight.
 | 
			
		||||
        super(BaseReacherEnv, self).reset(seed=seed, options=options)
 | 
			
		||||
        try:
 | 
			
		||||
            random_start = options.get('random_start', self.random_start)
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
@ -128,10 +128,6 @@ class BaseReacherEnv(gym.Env):
 | 
			
		||||
    def _terminate(self, info) -> bool:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def seed(self, seed=None):
 | 
			
		||||
        self.np_random, seed = seeding.np_random(seed)
 | 
			
		||||
        return [seed]
 | 
			
		||||
 | 
			
		||||
    def close(self):
 | 
			
		||||
        super(BaseReacherEnv, self).close()
 | 
			
		||||
        del self.fig
 | 
			
		||||
 | 
			
		||||
@ -57,11 +57,16 @@ class HoleReacherEnv(BaseReacherDirectEnv):
 | 
			
		||||
 | 
			
		||||
    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
 | 
			
		||||
            -> Tuple[ObsType, Dict[str, Any]]:
 | 
			
		||||
 | 
			
		||||
        # initialize seed here as the random goal needs to be generated before the super reset()
 | 
			
		||||
        gym.Env.reset(self, seed=seed, options=options)
 | 
			
		||||
 | 
			
		||||
        self._generate_hole()
 | 
			
		||||
        self._set_patches()
 | 
			
		||||
        self.reward_function.reset()
 | 
			
		||||
 | 
			
		||||
        return super().reset()
 | 
			
		||||
        # do not provide seed to avoid setting it twice
 | 
			
		||||
        return super(HoleReacherEnv, self).reset(options=options)
 | 
			
		||||
 | 
			
		||||
    def _get_reward(self, action: np.ndarray) -> (float, dict):
 | 
			
		||||
        return self.reward_function.get_reward(self)
 | 
			
		||||
@ -224,6 +229,3 @@ class HoleReacherEnv(BaseReacherDirectEnv):
 | 
			
		||||
            self.fig.gca().add_patch(left_block)
 | 
			
		||||
            self.fig.gca().add_patch(right_block)
 | 
			
		||||
            self.fig.gca().add_patch(hole_floor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										11
									
								
								fancy_gym/utils/env_compatibility.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								fancy_gym/utils/env_compatibility.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,11 @@
 | 
			
		||||
import gymnasium as gym
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EnvCompatibility(gym.wrappers.EnvCompatibility):
 | 
			
		||||
    def __getattr__(self, item):
 | 
			
		||||
        """Propagate only non-existent properties to wrapped env."""
 | 
			
		||||
        if item.startswith('_'):
 | 
			
		||||
            raise AttributeError("attempted to get missing private attribute '{}'".format(item))
 | 
			
		||||
        if item in self.__dict__:
 | 
			
		||||
            return getattr(self, item)
 | 
			
		||||
        return getattr(self.env, item)
 | 
			
		||||
@ -3,12 +3,14 @@ import uuid
 | 
			
		||||
from collections.abc import MutableMapping
 | 
			
		||||
from copy import deepcopy
 | 
			
		||||
from math import ceil
 | 
			
		||||
from typing import Iterable, Type, Union
 | 
			
		||||
from typing import Iterable, Type, Union, Optional
 | 
			
		||||
 | 
			
		||||
import gymnasium as gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from gymnasium.envs.registration import register, registry
 | 
			
		||||
 | 
			
		||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from dm_control import suite, manipulation
 | 
			
		||||
    from shimmy.dm_control_compatibility import EnvType
 | 
			
		||||
@ -186,9 +188,9 @@ def make_bb(
 | 
			
		||||
 | 
			
		||||
def get_env_duration(env: gym.Env):
 | 
			
		||||
    try:
 | 
			
		||||
        # TODO Remove if this is in the compatibility class
 | 
			
		||||
        duration = env.spec.max_episode_steps * env.dt
 | 
			
		||||
    except (AttributeError, TypeError) as e:
 | 
			
		||||
        # TODO Remove if this information is in the compatibility class
 | 
			
		||||
        logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. '
 | 
			
		||||
                      f'Assuming you are using dm_control. Please make sure you have ran '
 | 
			
		||||
                      f'"pip install shimmy[dm_control]" for that.')
 | 
			
		||||
@ -300,7 +302,7 @@ def make_bb_env_helper(**kwargs):
 | 
			
		||||
#     return env
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_metaworld(env_id: str, seed: int, **kwargs):
 | 
			
		||||
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
 | 
			
		||||
    if env_id not in metaworld.ML1.ENV_NAMES:
 | 
			
		||||
        raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
 | 
			
		||||
 | 
			
		||||
@ -314,7 +316,7 @@ def make_metaworld(env_id: str, seed: int, **kwargs):
 | 
			
		||||
    max_episode_steps = _env.max_path_length
 | 
			
		||||
 | 
			
		||||
    # TODO remove this as soon as there is support for the new API
 | 
			
		||||
    _env = gym.wrappers.EnvCompatibility(_env)
 | 
			
		||||
    _env = EnvCompatibility(_env, render_mode)
 | 
			
		||||
 | 
			
		||||
    gym_id = uuid.uuid4().hex + '-v1'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -14,20 +14,20 @@ DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if
 | 
			
		||||
                  spec.id.startswith('dm_control/')
 | 
			
		||||
                  and 'compatibility-env-v0' not in spec.id
 | 
			
		||||
                  and 'lqr-lqr' not in spec.id]
 | 
			
		||||
DM_control_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 | 
			
		||||
DM_control_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
 | 
			
		||||
SEED = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
 | 
			
		||||
def test_step_dm_control_functionality(env_id: str):
 | 
			
		||||
    """Tests that suite step environments run without errors using random actions."""
 | 
			
		||||
    run_env(env_id, 1000)
 | 
			
		||||
    run_env(env_id, 5000, wrappers=[gym.wrappers.FlattenObservation])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
 | 
			
		||||
def test_step_dm_control_determinism(env_id: str):
 | 
			
		||||
    """Tests that for step environments identical seeds produce identical trajectories."""
 | 
			
		||||
    run_env_determinism(env_id, SEED, 1000)
 | 
			
		||||
    run_env_determinism(env_id, SEED, 5000, wrappers=[gym.wrappers.FlattenObservation])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
import itertools
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from typing import Callable
 | 
			
		||||
 | 
			
		||||
import fancy_gym
 | 
			
		||||
@ -10,7 +10,7 @@ from test.utils import run_env, run_env_determinism
 | 
			
		||||
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
 | 
			
		||||
              not isinstance(spec.entry_point, Callable) and
 | 
			
		||||
              "fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 | 
			
		||||
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 | 
			
		||||
CUSTOM_MP_IDS = list(chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
 | 
			
		||||
SEED = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,6 @@
 | 
			
		||||
import re
 | 
			
		||||
from itertools import chain
 | 
			
		||||
from typing import Callable
 | 
			
		||||
 | 
			
		||||
import gymnasium as gym
 | 
			
		||||
import pytest
 | 
			
		||||
@ -7,8 +9,12 @@ import fancy_gym
 | 
			
		||||
from test.utils import run_env, run_env_determinism
 | 
			
		||||
 | 
			
		||||
GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
 | 
			
		||||
           "fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 | 
			
		||||
GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 | 
			
		||||
           not isinstance(spec.entry_point, Callable) and
 | 
			
		||||
           "fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point
 | 
			
		||||
           and 'jax' not in spec.id.lower()
 | 
			
		||||
           and not re.match(r'GymV2.Environment', spec.id)
 | 
			
		||||
           ]
 | 
			
		||||
GYM_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
 | 
			
		||||
SEED = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -8,8 +8,7 @@ from test.utils import run_env, run_env_determinism
 | 
			
		||||
 | 
			
		||||
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
 | 
			
		||||
                 ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
 | 
			
		||||
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 | 
			
		||||
print(METAWORLD_MP_IDS)
 | 
			
		||||
METAWORLD_MP_IDS = list(chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
 | 
			
		||||
SEED = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,9 +1,12 @@
 | 
			
		||||
from typing import List, Type
 | 
			
		||||
 | 
			
		||||
import gymnasium as gym
 | 
			
		||||
import numpy as np
 | 
			
		||||
from fancy_gym import make
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_env(env_id, iterations=None, seed=0, render=False):
 | 
			
		||||
def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[Type[gym.Wrapper]] = [],
 | 
			
		||||
            render: bool = False):
 | 
			
		||||
    """
 | 
			
		||||
    Example for running a DMC based env in the step based setting.
 | 
			
		||||
    The env_id has to be specified as `dmc:domain_name-task_name` or
 | 
			
		||||
@ -13,12 +16,15 @@ def run_env(env_id, iterations=None, seed=0, render=False):
 | 
			
		||||
        env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name`
 | 
			
		||||
        iterations: Number of rollout steps to run
 | 
			
		||||
        seed: random seeding
 | 
			
		||||
        wrappers: List of Wrappers to apply to the environment
 | 
			
		||||
        render: Render the episode
 | 
			
		||||
 | 
			
		||||
    Returns: observations, rewards, terminations, truncations, actions
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
    env: gym.Env = make(env_id, seed=seed)
 | 
			
		||||
    for w in wrappers:
 | 
			
		||||
        env = w(env)
 | 
			
		||||
    rewards = []
 | 
			
		||||
    observations = []
 | 
			
		||||
    actions = []
 | 
			
		||||
@ -60,13 +66,13 @@ def run_env(env_id, iterations=None, seed=0, render=False):
 | 
			
		||||
    return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def run_env_determinism(env_id: str, seed: int, iterations: int = None):
 | 
			
		||||
    traj1 = run_env(env_id, iterations=iterations, seed=seed)
 | 
			
		||||
    traj2 = run_env(env_id, iterations=iterations, seed=seed)
 | 
			
		||||
def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers: List[Type[gym.Wrapper]] = []):
 | 
			
		||||
    traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
 | 
			
		||||
    traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
 | 
			
		||||
    # Iterate over two trajectories, which should have the same state and action sequence
 | 
			
		||||
    for i, time_step in enumerate(zip(*traj1, *traj2)):
 | 
			
		||||
        obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
 | 
			
		||||
        assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
 | 
			
		||||
        assert np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
 | 
			
		||||
        assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
 | 
			
		||||
        assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
 | 
			
		||||
        assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user