Fix gym warning

This commit is contained in:
kngwyu 2020-06-30 16:33:07 +09:00
parent f23b39067a
commit bbbe0f38e3
3 changed files with 12 additions and 7 deletions

View File

@ -268,7 +268,7 @@ class MazeEnv(gym.Env):
def _get_obs_space(self) -> gym.spaces.Box: def _get_obs_space(self) -> gym.spaces.Box:
shape = self._get_obs().shape shape = self._get_obs().shape
high = np.inf * np.ones(shape) high = np.inf * np.ones(shape, dtype=np.float32)
low = -high low = -high
# Set velocity limits # Set velocity limits
wrapped_obs_space = self.wrapped_env.observation_space wrapped_obs_space = self.wrapped_env.observation_space

View File

@ -1,15 +1,20 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Type from typing import Dict, List, NamedTuple, Type
import numpy as np import numpy as np
from mujoco_maze.maze_env_utils import MazeCell from mujoco_maze.maze_env_utils import MazeCell
Rgb = Tuple[float, float, float]
RED = (0.7, 0.1, 0.1) class Rgb(NamedTuple):
GREEN = (0.1, 0.7, 0.1) red: float
BLUE = (0.1, 0.1, 0.7) green: float
blue: float
RED = Rgb(0.7, 0.1, 0.1)
GREEN = Rgb(0.1, 0.7, 0.1)
BLUE = Rgb(0.1, 0.1, 0.7)
class MazeGoal: class MazeGoal:

View File

@ -31,7 +31,7 @@ class PointEnv(AgentModel):
def __init__(self, file_path: Optional[str] = None): def __init__(self, file_path: Optional[str] = None):
super().__init__(file_path, 1) super().__init__(file_path, 1)
high = np.inf * np.ones(6) high = np.inf * np.ones(6, dtype=np.float32)
high[3:] = self.VELOCITY_LIMITS high[3:] = self.VELOCITY_LIMITS
high[self.ORI_IND] = np.pi high[self.ORI_IND] = np.pi
low = -high low = -high