diff --git a/mujoco_maze/swimmer.py b/mujoco_maze/swimmer.py index b93d8d1..4ac7699 100644 --- a/mujoco_maze/swimmer.py +++ b/mujoco_maze/swimmer.py @@ -11,6 +11,7 @@ import numpy as np from mujoco_maze.agent_model import AgentModel from mujoco_maze.ant import ForwardRewardFn, forward_reward_vnorm +from gym import spaces class SwimmerEnv(AgentModel): @@ -36,7 +37,8 @@ class SwimmerEnv(AgentModel): self._forward_reward_weight = forward_reward_weight self._ctrl_cost_weight = ctrl_cost_weight self._forward_reward_fn = forward_reward_fn - super().__init__(file_path, 4) + self.observation_space = spaces.Box(low=-np.inf, high=np.inf) + super().__init__(file_path, 4, self.observation_space) def _forward_reward(self, xy_pos_before: np.ndarray) -> Tuple[float, np.ndarray]: xy_pos_after = self.sim.data.qpos[:2].copy()