Better handling of envs without defined max_steps
This commit is contained in:
parent
60a4cf11d6
commit
b032dec5fe
@ -8,9 +8,10 @@ from typing import Iterable, Type, Union, Optional
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.registration import register, registry
|
from gymnasium.envs.registration import register, registry
|
||||||
from gymnasium.wrappers import FlattenObservation
|
from gymnasium.wrappers import TimeLimit
|
||||||
|
|
||||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||||
|
from fancy_gym.utils.wrappers import FlattenObservation
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dm_control import suite, manipulation
|
from dm_control import suite, manipulation
|
||||||
@ -31,7 +32,7 @@ from fancy_gym.black_box.factory.controller_factory import get_controller
|
|||||||
from fancy_gym.black_box.factory.phase_generator_factory import get_phase_generator
|
from fancy_gym.black_box.factory.phase_generator_factory import get_phase_generator
|
||||||
from fancy_gym.black_box.factory.trajectory_generator_factory import get_trajectory_generator
|
from fancy_gym.black_box.factory.trajectory_generator_factory import get_trajectory_generator
|
||||||
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
from fancy_gym.black_box.raw_interface_wrapper import RawInterfaceWrapper
|
||||||
from fancy_gym.utils.time_aware_observation import TimeAwareObservation
|
from fancy_gym.utils.wrappers import TimeAwareObservation
|
||||||
from fancy_gym.utils.utils import nested_update
|
from fancy_gym.utils.utils import nested_update
|
||||||
|
|
||||||
|
|
||||||
@ -114,7 +115,7 @@ def make(env_id: str, seed: int, **kwargs):
|
|||||||
return env
|
return env
|
||||||
|
|
||||||
|
|
||||||
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, **kwargs):
|
def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1, fallback_max_steps=None, **kwargs):
|
||||||
"""
|
"""
|
||||||
Helper function for creating a wrapped gym environment using MPs.
|
Helper function for creating a wrapped gym environment using MPs.
|
||||||
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
|
It adds all provided wrappers to the specified environment and verifies at least one RawInterfaceWrapper is
|
||||||
@ -130,6 +131,8 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
|
|||||||
"""
|
"""
|
||||||
# _env = gym.make(env_id)
|
# _env = gym.make(env_id)
|
||||||
_env = make(env_id, seed, **kwargs)
|
_env = make(env_id, seed, **kwargs)
|
||||||
|
if fallback_max_steps:
|
||||||
|
_env = ensure_finite_time(_env, fallback_max_steps)
|
||||||
has_black_box_wrapper = False
|
has_black_box_wrapper = False
|
||||||
for w in wrappers:
|
for w in wrappers:
|
||||||
# only wrap the environment if not BlackBoxWrapper, e.g. for vision
|
# only wrap the environment if not BlackBoxWrapper, e.g. for vision
|
||||||
@ -144,7 +147,7 @@ def _make_wrapped_env(env_id: str, wrappers: Iterable[Type[gym.Wrapper]], seed=1
|
|||||||
def make_bb(
|
def make_bb(
|
||||||
env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
env_id: str, wrappers: Iterable, black_box_kwargs: MutableMapping, traj_gen_kwargs: MutableMapping,
|
||||||
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed: int = 1,
|
controller_kwargs: MutableMapping, phase_kwargs: MutableMapping, basis_kwargs: MutableMapping, seed: int = 1,
|
||||||
**kwargs):
|
fallback_max_steps: int = None, **kwargs):
|
||||||
"""
|
"""
|
||||||
This can also be used standalone for manually building a custom DMP environment.
|
This can also be used standalone for manually building a custom DMP environment.
|
||||||
Args:
|
Args:
|
||||||
@ -172,7 +175,7 @@ def make_bb(
|
|||||||
# Add as first wrapper in order to alter observation
|
# Add as first wrapper in order to alter observation
|
||||||
wrappers.insert(0, TimeAwareObservation)
|
wrappers.insert(0, TimeAwareObservation)
|
||||||
|
|
||||||
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, fallback_max_steps=fallback_max_steps, **kwargs)
|
||||||
|
|
||||||
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
|
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
|
||||||
if type(env.observation_space) == gym.spaces.dict.Dict:
|
if type(env.observation_space) == gym.spaces.dict.Dict:
|
||||||
@ -209,6 +212,15 @@ def make_bb(
|
|||||||
return bb_env
|
return bb_env
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_finite_time(env: "gym.Env", fallback_max_steps=500) -> "gym.Env":
    """
    Make sure the given environment ends an episode after a finite number of steps.

    If the spec of the environment already defines `max_episode_steps`, the environment is returned unchanged.
    Otherwise it is wrapped in a `TimeLimit`, using the `max_path_length` attribute of the unwrapped environment
    when it is set and `fallback_max_steps` otherwise.

    Args:
        env: environment to check
        fallback_max_steps: step limit applied when neither the spec nor the unwrapped environment provides one

    Returns:
        `env` itself if its spec defines a step limit, otherwise `env` wrapped in a `TimeLimit`.
    """
    # The spec may be missing (None) for environments that were instantiated directly,
    # so it must not be dereferenced unconditionally.
    spec = getattr(env, 'spec', None)
    if getattr(spec, 'max_episode_steps', None):
        return env

    # Prefer the limit exposed by the underlying environment; use the fallback when the
    # attribute is absent or unset (None / 0) instead of handing an invalid limit to TimeLimit.
    max_steps = getattr(env.unwrapped, 'max_path_length', None) or fallback_max_steps
    return TimeLimit(env, max_steps)
|
||||||
|
|
||||||
|
|
||||||
def get_env_duration(env: gym.Env):
|
def get_env_duration(env: gym.Env):
|
||||||
try:
|
try:
|
||||||
duration = env.spec.max_episode_steps * env.dt
|
duration = env.spec.max_episode_steps * env.dt
|
||||||
|
Loading…
Reference in New Issue
Block a user