From 0c7ac838bf9bf5ac3295a8002c5b8b381bf6592c Mon Sep 17 00:00:00 2001
From: Fabian
Date: Thu, 20 Oct 2022 10:10:44 +0200
Subject: [PATCH] first updates to reacher and env creation

---
 fancy_gym/envs/mujoco/reacher/reacher.py | 34 +++++++++++++++++++-----
 fancy_gym/utils/make_env_helpers.py      |  5 +++-
 test/test_fancy_envs.py                  |  2 +-
 test/utils.py                            |  1 +
 4 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/fancy_gym/envs/mujoco/reacher/reacher.py b/fancy_gym/envs/mujoco/reacher/reacher.py
index ccd0073..c3c870b 100644
--- a/fancy_gym/envs/mujoco/reacher/reacher.py
+++ b/fancy_gym/envs/mujoco/reacher/reacher.py
@@ -3,6 +3,7 @@ import os
 import numpy as np
 from gym import utils
 from gym.envs.mujoco import MujocoEnv
+from gym.spaces import Box
 
 MAX_EPISODE_STEPS_REACHER = 200
 
@@ -12,7 +13,17 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
     More general version of the gym mujoco Reacher environment
     """
 
-    def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1):
+    metadata = {
+        "render_modes": [
+            "human",
+            "rgb_array",
+            "depth_array",
+        ],
+        "render_fps": 50,
+    }
+
+    def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1.,
+                 **kwargs):
         utils.EzPickle.__init__(**locals())
 
         self._steps = 0
@@ -25,10 +36,16 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
 
         file_name = f'reacher_{n_links}links.xml'
 
+        # sin, cos, velocity * n_links + goal position (2) and goal distance (3)
+        shape = (self.n_links * 3 + 5,)
+        observation_space = Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float64)
+
         MujocoEnv.__init__(self,
                            model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
                            frame_skip=2,
-                           mujoco_bindings="mujoco")
+                           observation_space=observation_space,
+                           **kwargs
+                           )
 
     def step(self, action):
         self._steps += 1
@@ -45,10 +62,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
 
         reward = reward_dist + reward_ctrl + angular_vel
         self.do_simulation(action, self.frame_skip)
-        ob = self._get_obs()
-        done = False
+        if self.render_mode == "human":
+            self.render()
 
-        infos = dict(
+        ob = self._get_obs()
+        terminated = False
+        truncated = False
+
+        info = dict(
             reward_dist=reward_dist,
             reward_ctrl=reward_ctrl,
             velocity=angular_vel,
@@ -56,7 +77,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
             goal=self.goal if hasattr(self, "goal") else None
         )
 
-        return ob, reward, done, infos
+        return ob, reward, terminated, truncated, info
 
     def distance_reward(self):
         vec = self.get_body_com("fingertip") - self.get_body_com("target")
@@ -66,6 +87,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
         return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0
 
     def viewer_setup(self):
+        assert self.viewer is not None
         self.viewer.cam.trackbodyid = 0
 
     def reset_model(self):
diff --git a/fancy_gym/utils/make_env_helpers.py b/fancy_gym/utils/make_env_helpers.py
index 5221423..68bb66d 100644
--- a/fancy_gym/utils/make_env_helpers.py
+++ b/fancy_gym/utils/make_env_helpers.py
@@ -9,6 +9,7 @@ from typing import Iterable, Type, Union
 import gym
 import numpy as np
 from gym.envs.registration import register, registry
+from gym.utils import seeding
 
 try:
     from dm_control import suite, manipulation
@@ -88,7 +89,9 @@ def make(env_id: str, seed: int, **kwargs):
     else:
         env = make_gym(env_id, seed, **kwargs)
 
-    env.seed(seed)
+    np_random, _ = seeding.np_random(seed)
+    env.np_random = np_random
+    # env.seed(seed)
     env.action_space.seed(seed)
     env.observation_space.seed(seed)
 
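For context, a minimal usage sketch of the updated API, assuming gym >= 0.26 and a registered TimeLimit-wrapped reacher; the env id "Reacher5d-v0" is only illustrative and not taken from this patch:

    import fancy_gym

    # make() now attaches a seeded np_random generator instead of calling the
    # removed env.seed(); action and observation spaces are still seeded inside make().
    env = fancy_gym.make("Reacher5d-v0", seed=1)
    obs = env.reset()  # depending on the gym version, reset() may also return an info dict

    for _ in range(200):
        action = env.action_space.sample()
        # step() now follows the terminated/truncated split of the gym >= 0.26 API
        obs, reward, terminated, truncated, info = env.step(action)
        if terminated or truncated:
            obs = env.reset()
    env.close()

With the default n_links=5, the Box defined above has shape (20,), i.e. 3 * 5 + 5 as described in the inline comment.
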
diff --git a/test/test_fancy_envs.py b/test/test_fancy_envs.py
index 9acd696..7b7d5ca 100644
--- a/test/test_fancy_envs.py
+++ b/test/test_fancy_envs.py
@@ -6,7 +6,7 @@ import pytest
 
 from test.utils import run_env, run_env_determinism
 
-CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
+CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
               "fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 SEED = 1
diff --git a/test/utils.py b/test/utils.py
index 7ed8d61..88b56bc 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -24,6 +24,7 @@ def run_env(env_id, iterations=None, seed=0, render=False):
     actions = []
     dones = []
    obs = env.reset()
+    print(obs.dtype)
     verify_observations(obs, env.observation_space, "reset()")
 
     iterations = iterations or (env.spec.max_episode_steps or 1)
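The test change above follows the dict-based registry of gym >= 0.26. A standalone sketch of the same lookup; the isinstance guard is an extra precaution here (not in the patch), since entry_point may also be a callable rather than a string:

    import gym
    import fancy_gym  # noqa: F401  # importing registers the custom environments

    # gym >= 0.26 exposes the registry as a plain dict mapping env id -> EnvSpec,
    # so the old registry.all() call is replaced by ordinary dict iteration.
    custom_ids = [env_id for env_id, spec in gym.envs.registry.items()
                  if isinstance(spec.entry_point, str) and "fancy_gym" in spec.entry_point]
    print(sorted(custom_ids))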