diff --git a/tests/test_envs.py b/tests/test_envs.py index 209ab24..9a5400f 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,9 +7,13 @@ import pytest def test_ant_maze(maze_id): env = gym.make("Ant{}-v0".format(maze_id)) assert env.reset().shape == (30,) + s, _, _, _ = env.step(env.action_space.sample()) + assert s.shape == (30,) @pytest.mark.parametrize("maze_id", mujoco_maze.MAZE_IDS) def test_point_maze(maze_id): env = gym.make("Point{}-v0".format(maze_id)) assert env.reset().shape == (7,) + s, _, _, _ = env.step(env.action_space.sample()) + assert s.shape == (7,)