Auto convert output spaces.Dict to Box for BB-Envs

This commit is contained in:
Dominik Moritz Roth 2023-05-27 11:39:47 +02:00
parent dabfc7cafe
commit d6df6779a1

View File

@ -8,6 +8,7 @@ from typing import Iterable, Type, Union, Optional
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
from gymnasium.envs.registration import register, registry from gymnasium.envs.registration import register, registry
from gymnasium.wrappers import FlattenObservation
from fancy_gym.utils.env_compatibility import EnvCompatibility from fancy_gym.utils.env_compatibility import EnvCompatibility
@ -165,6 +166,10 @@ def make_bb(
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs) env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
if type(env.observation_space) == gym.spaces.dict.Dict:
env = FlattenObservation(env)
traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(env.action_space.shape).item()) traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(env.action_space.shape).item())
if black_box_kwargs.get('duration') is None: if black_box_kwargs.get('duration') is None: