first updates to reacher and env creation

This commit is contained in:
Fabian 2022-10-20 10:10:44 +02:00
parent 14ac5f81c7
commit 0c7ac838bf
4 changed files with 34 additions and 8 deletions

View File

@ -3,6 +3,7 @@ import os
import numpy as np import numpy as np
from gym import utils from gym import utils
from gym.envs.mujoco import MujocoEnv from gym.envs.mujoco import MujocoEnv
from gym.spaces import Box
MAX_EPISODE_STEPS_REACHER = 200 MAX_EPISODE_STEPS_REACHER = 200
@ -12,7 +13,17 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
More general version of the gym mujoco Reacher environment More general version of the gym mujoco Reacher environment
""" """
def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1): metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
"render_fps": 50,
}
def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1.,
**kwargs):
utils.EzPickle.__init__(**locals()) utils.EzPickle.__init__(**locals())
self._steps = 0 self._steps = 0
@ -25,10 +36,16 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
file_name = f'reacher_{n_links}links.xml' file_name = f'reacher_{n_links}links.xml'
# sin, cos, velocity * n_Links + goal position (2) and goal distance (3)
shape = (self.n_links * 3 + 5,)
observation_space = Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float64)
MujocoEnv.__init__(self, MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", file_name), model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
frame_skip=2, frame_skip=2,
mujoco_bindings="mujoco") observation_space=observation_space,
**kwargs
)
def step(self, action): def step(self, action):
self._steps += 1 self._steps += 1
@ -45,10 +62,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
reward = reward_dist + reward_ctrl + angular_vel reward = reward_dist + reward_ctrl + angular_vel
self.do_simulation(action, self.frame_skip) self.do_simulation(action, self.frame_skip)
ob = self._get_obs() if self.render_mode == "human":
done = False self.render()
infos = dict( ob = self._get_obs()
terminated = False
truncated = False
info = dict(
reward_dist=reward_dist, reward_dist=reward_dist,
reward_ctrl=reward_ctrl, reward_ctrl=reward_ctrl,
velocity=angular_vel, velocity=angular_vel,
@ -56,7 +77,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
goal=self.goal if hasattr(self, "goal") else None goal=self.goal if hasattr(self, "goal") else None
) )
return ob, reward, done, infos return ob, reward, terminated, truncated, info
def distance_reward(self): def distance_reward(self):
vec = self.get_body_com("fingertip") - self.get_body_com("target") vec = self.get_body_com("fingertip") - self.get_body_com("target")
@ -66,6 +87,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0 return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0
def viewer_setup(self): def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0 self.viewer.cam.trackbodyid = 0
def reset_model(self): def reset_model(self):

View File

@ -9,6 +9,7 @@ from typing import Iterable, Type, Union
import gym import gym
import numpy as np import numpy as np
from gym.envs.registration import register, registry from gym.envs.registration import register, registry
from gym.utils import seeding
try: try:
from dm_control import suite, manipulation from dm_control import suite, manipulation
@ -88,7 +89,9 @@ def make(env_id: str, seed: int, **kwargs):
else: else:
env = make_gym(env_id, seed, **kwargs) env = make_gym(env_id, seed, **kwargs)
env.seed(seed) np_random, _ = seeding.np_random(seed)
env.np_random = np_random
# env.seed(seed)
env.action_space.seed(seed) env.action_space.seed(seed)
env.observation_space.seed(seed) env.observation_space.seed(seed)

View File

@ -6,7 +6,7 @@ import pytest
from test.utils import run_env, run_env_determinism from test.utils import run_env, run_env_determinism
CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point] "fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values()) CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
SEED = 1 SEED = 1

View File

@ -24,6 +24,7 @@ def run_env(env_id, iterations=None, seed=0, render=False):
actions = [] actions = []
dones = [] dones = []
obs = env.reset() obs = env.reset()
print(obs.dtype)
verify_observations(obs, env.observation_space, "reset()") verify_observations(obs, env.observation_space, "reset()")
iterations = iterations or (env.spec.max_episode_steps or 1) iterations = iterations or (env.spec.max_episode_steps or 1)