Auto convert output spaces.Dict to Box for BB-Envs
This commit is contained in:
parent
dabfc7cafe
commit
d6df6779a1
@ -8,6 +8,7 @@ from typing import Iterable, Type, Union, Optional
|
|||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.envs.registration import register, registry
|
from gymnasium.envs.registration import register, registry
|
||||||
|
from gymnasium.wrappers import FlattenObservation
|
||||||
|
|
||||||
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
from fancy_gym.utils.env_compatibility import EnvCompatibility
|
||||||
|
|
||||||
@ -165,6 +166,10 @@ def make_bb(
|
|||||||
|
|
||||||
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
env = _make_wrapped_env(env_id=env_id, wrappers=wrappers, seed=seed, **kwargs)
|
||||||
|
|
||||||
|
# BB expects a spaces.Box to be exposed, need to convert for dict-observations
|
||||||
|
if type(env.observation_space) == gym.spaces.dict.Dict:
|
||||||
|
env = FlattenObservation(env)
|
||||||
|
|
||||||
traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(env.action_space.shape).item())
|
traj_gen_kwargs['action_dim'] = traj_gen_kwargs.get('action_dim', np.prod(env.action_space.shape).item())
|
||||||
|
|
||||||
if black_box_kwargs.get('duration') is None:
|
if black_box_kwargs.get('duration') is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user