diff --git a/fancy_gym/utils/make_env_helpers.py b/fancy_gym/utils/make_env_helpers.py index 42096f3..d3642a4 100644 --- a/fancy_gym/utils/make_env_helpers.py +++ b/fancy_gym/utils/make_env_helpers.py @@ -8,6 +8,7 @@ from typing import Iterable, Type, Union, Optional import gymnasium as gym import numpy as np from gymnasium.envs.registration import register, registry +from gymnasium.wrappers import FlattenObservation from fancy_gym.utils.env_compatibility import EnvCompatibility @@ -165,6 +166,10 @@ def make_bb( env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) + # BB expects a spaces.Box to be exposed, need to convert for dict-observations + if type(env.observation_space) == gym.spaces.dict.Dict: + env = FlattenObservation(env) + traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(env.action_space.shape).item()) if black_box_kwargs.get('duration') is None: