Updated to the new API; tests are still failing
This commit is contained in:
parent
ec2063aa0b
commit
c53924d9fc
@ -55,7 +55,6 @@ class BaseReacherEnv(gym.Env):
|
||||
self.fig = None
|
||||
|
||||
self._steps = 0
|
||||
self.seed()
|
||||
|
||||
@property
|
||||
def dt(self) -> Union[float, int]:
|
||||
@ -72,6 +71,7 @@ class BaseReacherEnv(gym.Env):
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||
-> Tuple[ObsType, Dict[str, Any]]:
|
||||
# Sample only orientation of first link, i.e. the arm is always straight.
|
||||
super(BaseReacherEnv, self).reset(seed=seed, options=options)
|
||||
try:
|
||||
random_start = options.get('random_start', self.random_start)
|
||||
except AttributeError:
|
||||
@ -128,10 +128,6 @@ class BaseReacherEnv(gym.Env):
|
||||
def _terminate(self, info) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def seed(self, seed=None):
|
||||
self.np_random, seed = seeding.np_random(seed)
|
||||
return [seed]
|
||||
|
||||
def close(self):
|
||||
super(BaseReacherEnv, self).close()
|
||||
del self.fig
|
||||
|
@ -57,11 +57,16 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
|
||||
-> Tuple[ObsType, Dict[str, Any]]:
|
||||
|
||||
# initialize seed here as the random goal needs to be generated before the super reset()
|
||||
gym.Env.reset(self, seed=seed, options=options)
|
||||
|
||||
self._generate_hole()
|
||||
self._set_patches()
|
||||
self.reward_function.reset()
|
||||
|
||||
return super().reset()
|
||||
# do not provide seed to avoid setting it twice
|
||||
return super(HoleReacherEnv, self).reset(options=options)
|
||||
|
||||
def _get_reward(self, action: np.ndarray) -> (float, dict):
|
||||
return self.reward_function.get_reward(self)
|
||||
@ -224,6 +229,3 @@ class HoleReacherEnv(BaseReacherDirectEnv):
|
||||
self.fig.gca().add_patch(left_block)
|
||||
self.fig.gca().add_patch(right_block)
|
||||
self.fig.gca().add_patch(hole_floor)
|
||||
|
||||
|
||||
|
||||
|
11
fancy_gym/utils/env_compatibility.py
Normal file
11
fancy_gym/utils/env_compatibility.py
Normal file
@ -0,0 +1,11 @@
|
||||
import gymnasium as gym
|
||||
|
||||
|
||||
class EnvCompatibility(gym.wrappers.EnvCompatibility):
    """Subclass of gymnasium's ``EnvCompatibility`` that forwards missing public attributes to ``self.env``."""

    def __getattr__(self, item):
        """Propagate only non-existent properties to wrapped env.

        Python invokes this hook only after regular attribute lookup (instance
        dictionary, class, and base classes) has already failed, so anything
        reaching this method is genuinely missing on the wrapper itself.

        Args:
            item: Name of the attribute that was not found on the wrapper.

        Returns:
            The attribute of the same name taken from ``self.env``.

        Raises:
            AttributeError: If ``item`` starts with an underscore, if ``item``
                is ``'env'`` (i.e. the wrapped env has not been set on this
                instance), or if ``self.env`` does not provide ``item`` either.
        """
        if item.startswith('_'):
            raise AttributeError("attempted to get missing private attribute '{}'".format(item))
        if item == 'env':
            # Getting here means `env` is not set on this instance. Delegating
            # through `self.env` would re-enter this method for 'env' again and
            # recurse until RecursionError, so fail with a clear error instead.
            raise AttributeError("attempted to get attribute 'env' before the wrapped env was set")
        # No `item in self.__dict__` check: names stored in the instance
        # dictionary are resolved by normal lookup before this hook is called.
        return getattr(self.env, item)
|
@ -3,12 +3,14 @@ import uuid
|
||||
from collections.abc import MutableMapping
|
||||
from copy import deepcopy
|
||||
from math import ceil
|
||||
from typing import Iterable, Type, Union
|
||||
from typing import Iterable, Type, Union, Optional
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium.envs.registration import register, registry
|
||||
|
||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||
|
||||
try:
|
||||
from dm_control import suite, manipulation
|
||||
from shimmy.dm_control_compatibility import EnvType
|
||||
@ -186,9 +188,9 @@ def make_bb(
|
||||
|
||||
def get_env_duration(env: gym.Env):
|
||||
try:
|
||||
# TODO Remove if this is in the compatibility class
|
||||
duration = env.spec.max_episode_steps * env.dt
|
||||
except (AttributeError, TypeError) as e:
|
||||
# TODO Remove if this information is in the compatibility class
|
||||
logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. '
|
||||
f'Assuming you are using dm_control. Please make sure you have ran '
|
||||
f'"pip install shimmy[dm_control]" for that.')
|
||||
@ -300,7 +302,7 @@ def make_bb_env_helper(**kwargs):
|
||||
# return env
|
||||
|
||||
|
||||
def make_metaworld(env_id: str, seed: int, **kwargs):
|
||||
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
|
||||
if env_id not in metaworld.ML1.ENV_NAMES:
|
||||
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
|
||||
|
||||
@ -314,7 +316,7 @@ def make_metaworld(env_id: str, seed: int, **kwargs):
|
||||
max_episode_steps = _env.max_path_length
|
||||
|
||||
# TODO remove this as soon as there is support for the new API
|
||||
_env = gym.wrappers.EnvCompatibility(_env)
|
||||
_env = EnvCompatibility(_env, render_mode)
|
||||
|
||||
gym_id = uuid.uuid4().hex + '-v1'
|
||||
|
||||
|
@ -14,20 +14,20 @@ DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if
|
||||
spec.id.startswith('dm_control/')
|
||||
and 'compatibility-env-v0' not in spec.id
|
||||
and 'lqr-lqr' not in spec.id]
|
||||
DM_control_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
DM_control_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||
SEED = 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
|
||||
def test_step_dm_control_functionality(env_id: str):
|
||||
"""Tests that suite step environments run without errors using random actions."""
|
||||
run_env(env_id, 1000)
|
||||
run_env(env_id, 5000, wrappers=[gym.wrappers.FlattenObservation])
|
||||
|
||||
|
||||
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
|
||||
def test_step_dm_control_determinism(env_id: str):
|
||||
"""Tests that for step environments identical seeds produce identical trajectories."""
|
||||
run_env_determinism(env_id, SEED, 1000)
|
||||
run_env_determinism(env_id, SEED, 5000, wrappers=[gym.wrappers.FlattenObservation])
|
||||
|
||||
|
||||
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
|
||||
|
@ -1,4 +1,4 @@
|
||||
import itertools
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import fancy_gym
|
||||
@ -10,7 +10,7 @@ from test.utils import run_env, run_env_determinism
|
||||
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
|
||||
not isinstance(spec.entry_point, Callable) and
|
||||
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
||||
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
CUSTOM_MP_IDS = list(chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
import re
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
|
||||
import gymnasium as gym
|
||||
import pytest
|
||||
@ -7,8 +9,12 @@ import fancy_gym
|
||||
from test.utils import run_env, run_env_determinism
|
||||
|
||||
GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
|
||||
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
||||
GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
not isinstance(spec.entry_point, Callable) and
|
||||
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point
|
||||
and 'jax' not in spec.id.lower()
|
||||
and not re.match(r'GymV2.Environment', spec.id)
|
||||
]
|
||||
GYM_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
@ -8,8 +8,7 @@ from test.utils import run_env, run_env_determinism
|
||||
|
||||
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
|
||||
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
|
||||
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||
print(METAWORLD_MP_IDS)
|
||||
METAWORLD_MP_IDS = list(chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
|
||||
SEED = 1
|
||||
|
||||
|
||||
|
@ -1,9 +1,12 @@
|
||||
from typing import List, Type
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from fancy_gym import make
|
||||
|
||||
|
||||
def run_env(env_id, iterations=None, seed=0, render=False):
|
||||
def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[Type[gym.Wrapper]] = [],
|
||||
render: bool = False):
|
||||
"""
|
||||
Example for running a DMC based env in the step based setting.
|
||||
The env_id has to be specified as `dmc:domain_name-task_name` or
|
||||
@ -13,12 +16,15 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
||||
env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name`
|
||||
iterations: Number of rollout steps to run
|
||||
seed: random seeding
|
||||
wrappers: List of Wrappers to apply to the environment
|
||||
render: Render the episode
|
||||
|
||||
Returns: observations, rewards, terminations, truncations, actions
|
||||
|
||||
"""
|
||||
env: gym.Env = make(env_id, seed=seed)
|
||||
for w in wrappers:
|
||||
env = w(env)
|
||||
rewards = []
|
||||
observations = []
|
||||
actions = []
|
||||
@ -60,13 +66,13 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
||||
return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)
|
||||
|
||||
|
||||
def run_env_determinism(env_id: str, seed: int, iterations: int = None):
|
||||
traj1 = run_env(env_id, iterations=iterations, seed=seed)
|
||||
traj2 = run_env(env_id, iterations=iterations, seed=seed)
|
||||
def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers: List[Type[gym.Wrapper]] = []):
|
||||
traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
|
||||
traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
|
||||
# Iterate over two trajectories, which should have the same state and action sequence
|
||||
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
||||
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
|
||||
assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
|
||||
assert np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
|
||||
assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
|
||||
assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
|
||||
assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
|
||||
|
Loading…
Reference in New Issue
Block a user