Updated to new API; tests are still failing

This commit is contained in:
Fabian 2023-01-17 08:27:29 +01:00
parent ec2063aa0b
commit c53924d9fc
9 changed files with 49 additions and 27 deletions

View File

@ -55,7 +55,6 @@ class BaseReacherEnv(gym.Env):
self.fig = None
self._steps = 0
self.seed()
@property
def dt(self) -> Union[float, int]:
@ -72,6 +71,7 @@ class BaseReacherEnv(gym.Env):
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]:
# Sample only orientation of first link, i.e. the arm is always straight.
super(BaseReacherEnv, self).reset(seed=seed, options=options)
try:
random_start = options.get('random_start', self.random_start)
except AttributeError:
@ -128,10 +128,6 @@ class BaseReacherEnv(gym.Env):
def _terminate(self, info) -> bool:
raise NotImplementedError
def seed(self, seed=None):
    """Re-create the environment's random generator.

    Args:
        seed: Value forwarded to ``seeding.np_random``; defaults to ``None``.

    Returns:
        A single-element list holding the seed reported back by
        ``seeding.np_random``.
    """
    generator, used_seed = seeding.np_random(seed)
    # Store the fresh generator on the instance so later sampling uses it.
    self.np_random = generator
    return [used_seed]
def close(self):
super(BaseReacherEnv, self).close()
del self.fig

View File

@ -57,11 +57,16 @@ class HoleReacherEnv(BaseReacherDirectEnv):
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) \
-> Tuple[ObsType, Dict[str, Any]]:
# initialize seed here as the random goal needs to be generated before the super reset()
gym.Env.reset(self, seed=seed, options=options)
self._generate_hole()
self._set_patches()
self.reward_function.reset()
return super().reset()
# do not provide seed to avoid setting it twice
return super(HoleReacherEnv, self).reset(options=options)
def _get_reward(self, action: np.ndarray) -> (float, dict):
return self.reward_function.get_reward(self)
@ -224,6 +229,3 @@ class HoleReacherEnv(BaseReacherDirectEnv):
self.fig.gca().add_patch(left_block)
self.fig.gca().add_patch(right_block)
self.fig.gca().add_patch(hole_floor)

View File

@ -0,0 +1,11 @@
import gymnasium as gym
class EnvCompatibility(gym.wrappers.EnvCompatibility):
    """Compatibility layer that also exposes the wrapped env's public attributes."""

    def __getattr__(self, item):
        """Propagate only non-existent properties to wrapped env.

        Python invokes ``__getattr__`` only after regular attribute lookup
        has failed, so every access made in here must avoid re-entering this
        method.

        Args:
            item: Name of the attribute that could not be resolved normally.

        Returns:
            The instance's own value if present, otherwise the attribute of
            the same name on ``self.env``.

        Raises:
            AttributeError: If ``item`` is private (starts with ``_``), if
                ``env`` itself has not been set yet, or if the wrapped env
                does not provide ``item``.
        """
        if item.startswith('_'):
            raise AttributeError(f"attempted to get missing private attribute '{item}'")
        own_attributes = self.__dict__
        if item in own_attributes:
            # Read the instance dict directly: calling getattr(self, item)
            # here could land in __getattr__ again.
            return own_attributes[item]
        if item == 'env':
            # ``env`` is missing on the instance (e.g. object not initialised
            # yet); evaluating ``self.env`` below would recurse without end.
            raise AttributeError("attempted to get attribute 'env' before the wrapped env was set")
        return getattr(self.env, item)

View File

@ -3,12 +3,14 @@ import uuid
from collections.abc import MutableMapping
from copy import deepcopy
from math import ceil
from typing import Iterable, Type, Union
from typing import Iterable, Type, Union, Optional
import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register, registry
from fancy_gym.utils.env_compatibility import EnvCompatibility
try:
from dm_control import suite, manipulation
from shimmy.dm_control_compatibility import EnvType
@ -186,9 +188,9 @@ def make_bb(
def get_env_duration(env: gym.Env):
try:
# TODO Remove if this is in the compatibility class
duration = env.spec.max_episode_steps * env.dt
except (AttributeError, TypeError) as e:
# TODO Remove if this information is in the compatibility class
logging.error(f'Attributes env.spec.max_episode_steps and env.dt are not available. '
f'Assuming you are using dm_control. Please make sure you have ran '
f'"pip install shimmy[dm_control]" for that.')
@ -300,7 +302,7 @@ def make_bb_env_helper(**kwargs):
# return env
def make_metaworld(env_id: str, seed: int, **kwargs):
def make_metaworld(env_id: str, seed: int, render_mode: Optional[str] = None, **kwargs):
if env_id not in metaworld.ML1.ENV_NAMES:
raise ValueError(f'Specified environment "{env_id}" not present in metaworld ML1.')
@ -314,7 +316,7 @@ def make_metaworld(env_id: str, seed: int, **kwargs):
max_episode_steps = _env.max_path_length
# TODO remove this as soon as there is support for the new API
_env = gym.wrappers.EnvCompatibility(_env)
_env = EnvCompatibility(_env, render_mode)
gym_id = uuid.uuid4().hex + '-v1'

View File

@ -14,20 +14,20 @@ DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if
spec.id.startswith('dm_control/')
and 'compatibility-env-v0' not in spec.id
and 'lqr-lqr' not in spec.id]
DM_control_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
DM_control_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
SEED = 1
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
def test_step_dm_control_functionality(env_id: str):
"""Tests that suite step environments run without errors using random actions."""
run_env(env_id, 1000)
run_env(env_id, 5000, wrappers=[gym.wrappers.FlattenObservation])
@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
def test_step_dm_control_determinism(env_id: str):
"""Tests that for step environments identical seeds produce identical trajectories."""
run_env_determinism(env_id, SEED, 1000)
run_env_determinism(env_id, SEED, 5000, wrappers=[gym.wrappers.FlattenObservation])
# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)

View File

@ -1,4 +1,4 @@
import itertools
from itertools import chain
from typing import Callable
import fancy_gym
@ -10,7 +10,7 @@ from test.utils import run_env, run_env_determinism
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
not isinstance(spec.entry_point, Callable) and
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
CUSTOM_MP_IDS = list(chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
SEED = 1

View File

@ -1,4 +1,6 @@
import re
from itertools import chain
from typing import Callable
import gymnasium as gym
import pytest
@ -7,8 +9,12 @@ import fancy_gym
from test.utils import run_env, run_env_determinism
GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
not isinstance(spec.entry_point, Callable) and
"fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point
and 'jax' not in spec.id.lower()
and not re.match(r'GymV2.Environment', spec.id)
]
GYM_MP_IDS = list(chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
SEED = 1

View File

@ -8,8 +8,7 @@ from test.utils import run_env, run_env_determinism
METAWORLD_IDS = [f'metaworld:{env.split("-goal-observable")[0]}' for env, _ in
ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE.items()]
METAWORLD_MP_IDS = chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
print(METAWORLD_MP_IDS)
METAWORLD_MP_IDS = list(chain(*fancy_gym.ALL_METAWORLD_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()))
SEED = 1

View File

@ -1,9 +1,12 @@
from typing import List, Type
import gymnasium as gym
import numpy as np
from fancy_gym import make
def run_env(env_id, iterations=None, seed=0, render=False):
def run_env(env_id: str, iterations: int = None, seed: int = 0, wrappers: List[Type[gym.Wrapper]] = [],
render: bool = False):
"""
Example for running a DMC based env in the step based setting.
The env_id has to be specified as `dmc:domain_name-task_name` or
@ -13,12 +16,15 @@ def run_env(env_id, iterations=None, seed=0, render=False):
env_id: Either `dmc:domain_name-task_name` or `dmc:manipulation-environment_name`
iterations: Number of rollout steps to run
seed: random seeding
wrappers: List of Wrappers to apply to the environment
render: Render the episode
Returns: observations, rewards, terminations, truncations, actions
"""
env: gym.Env = make(env_id, seed=seed)
for w in wrappers:
env = w(env)
rewards = []
observations = []
actions = []
@ -60,13 +66,13 @@ def run_env(env_id, iterations=None, seed=0, render=False):
return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)
def run_env_determinism(env_id: str, seed: int, iterations: int = None):
traj1 = run_env(env_id, iterations=iterations, seed=seed)
traj2 = run_env(env_id, iterations=iterations, seed=seed)
def run_env_determinism(env_id: str, seed: int, iterations: int = None, wrappers: List[Type[gym.Wrapper]] = []):
traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
# Iterate over two trajectories, which should have the same state and action sequence
for i, time_step in enumerate(zip(*traj1, *traj2)):
obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
assert np.allclose(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."