Better handling of envs without defined max_steps

This commit is contained in:
Dominik Moritz Roth 2023-06-18 14:23:59 +02:00
parent 60a4cf11d6
commit b032dec5fe

View File

@ -8,9 +8,10 @@ from typing import Iterable, Type, Union, Optional
import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import FlattenObservation
from gymnasium.wrappers import TimeLimit
from fancy_gym.utils.env_compatibility import EnvCompatibility
from fancy_gym.utils.wrappers import FlattenObservation
try:
from dm_control import suite, manipulation
@ -31,7 +32,7 @@ from fancy_gym.black_box.factory.controller_factory import get_controller
from fancy_gym.black_box.factory.phase_generator_factory import get_phase_generator
from fancy_gym.black_box.factory.trajectory_generator_factory import get_trajectory_generator
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
from fancy_gym.utils.wrappers import TimeAwareObservation
from fancy_gym.utils.utils import nested_update
@ -114,7 +115,7 @@ def make(env_id: str, seed: int, **kwargs):
return env
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs):
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs):
"""
Helper function for creating a wrapped gym environment using MPs.
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
@ -130,6 +131,8 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
"""
# _env = gym.make(env_id)
_env = make(env_id, seed, **kwargs)
if fallback_max_steps:
_env = ensure_finite_time(_env, fallback_max_steps)
has_black_box_wrapper = False
for w in wrappers:
# only wrap the environment if not BlackBoxWrapper, e.g. for vision
@ -144,7 +147,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
def make_bb(
env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed: int = 1,
**kwargs):
fallback_max_steps: int = None, **kwargs):
"""
This can also be used standalone for manually building a custom DMP environment.
Args:
@ -172,7 +175,7 @@ def make_bb(
# Add as first wrapper in order to alter observation
wrappers.insert(0, TimeAwareObservation)
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, fallback_max_steps=fallback_max_steps, **kwargs)
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
if type(env.observation_space) == gym.spaces.dict.Dict:
@ -209,6 +212,15 @@ def make_bb(
return bb_env
def ensure_finite_time(env: gym.Env, fallback_max_steps=500):
    """
    Make sure the given environment terminates after a finite number of steps.

    If the spec of the environment already defines a (non-zero) `max_episode_steps`, the environment is
    returned unchanged. Otherwise it is wrapped in a `TimeLimit`. The limit is taken from the
    `max_path_length` attribute of the unwrapped environment when that attribute exists and is truthy,
    and from `fallback_max_steps` otherwise.

    Args:
        env: the environment to check
        fallback_max_steps: step limit used when neither the spec nor the unwrapped environment provide one

    Returns:
        `env` itself if it already has a step limit, else `env` wrapped in a `TimeLimit`
    """
    # `spec` is None for environments that were not created through the registry;
    # treat that the same as a spec without a step limit instead of raising.
    spec = env.spec
    if spec is not None and spec.max_episode_steps:
        return env

    # A missing, None or zero `max_path_length` cannot serve as a finite limit; use the fallback then.
    max_path_length = getattr(env.unwrapped, 'max_path_length', None)
    if max_path_length:
        return TimeLimit(env, max_path_length)
    return TimeLimit(env, fallback_max_steps)
def get_env_duration(env: gym.Env):
try:
duration = env.spec.max_episode_steps * env.dt