first updates to reacher and env creation

Fabian 2022-10-20 10:10:44 +02:00
parent 14ac5f81c7
commit 0c7ac838bf
4 changed files with 34 additions and 8 deletions

View File

@@ -3,6 +3,7 @@ import os
import numpy as np
from gym import utils
from gym.envs.mujoco import MujocoEnv
from gym.spaces import Box
MAX_EPISODE_STEPS_REACHER = 200
@@ -12,7 +13,17 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
More general version of the gym mujoco Reacher environment
"""
def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1):
metadata = {
"render_modes": [
"human",
"rgb_array",
"depth_array",
],
"render_fps": 50,
}
def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1.,
**kwargs):
utils.EzPickle.__init__(**locals())
self._steps = 0
@@ -25,10 +36,16 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
file_name = f'reacher_{n_links}links.xml'
# sin, cos, velocity per link (n_links * 3) + goal position (2) + goal distance (3)
shape = (self.n_links * 3 + 5,)
observation_space = Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float64)
MujocoEnv.__init__(self,
model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
frame_skip=2,
mujoco_bindings="mujoco")
observation_space=observation_space,
**kwargs
)
def step(self, action):
self._steps += 1
@@ -45,10 +62,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
reward = reward_dist + reward_ctrl + angular_vel
self.do_simulation(action, self.frame_skip)
ob = self._get_obs()
done = False
if self.render_mode == "human":
self.render()
infos = dict(
ob = self._get_obs()
terminated = False
truncated = False
info = dict(
reward_dist=reward_dist,
reward_ctrl=reward_ctrl,
velocity=angular_vel,
@@ -56,7 +77,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
goal=self.goal if hasattr(self, "goal") else None
)
return ob, reward, done, infos
return ob, reward, terminated, truncated, info
def distance_reward(self):
vec = self.get_body_com("fingertip") - self.get_body_com("target")
@@ -66,6 +87,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0
def viewer_setup(self):
assert self.viewer is not None
self.viewer.cam.trackbodyid = 0
def reset_model(self):
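
Note: the step() changes above follow the gym 0.26 API, where the single done flag is split into terminated and truncated and rendering is driven by the render_mode passed at construction. A minimal rollout sketch against the new 5-tuple (the env id here is a placeholder, not something registered by this commit):

import gym

# Hypothetical env id for illustration; any env registered against gym >= 0.26 works.
env = gym.make("Reacher5links-v0", render_mode="human")
obs, info = env.reset(seed=1)
for _ in range(200):  # MAX_EPISODE_STEPS_REACHER in the module above
    action = env.action_space.sample()
    # step() now returns five values instead of (obs, reward, done, info)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()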

View File

@@ -9,6 +9,7 @@ from typing import Iterable, Type, Union
import gym
import numpy as np
from gym.envs.registration import register, registry
from gym.utils import seeding
try:
from dm_control import suite, manipulation
@@ -88,7 +89,9 @@ def make(env_id: str, seed: int, **kwargs):
else:
env = make_gym(env_id, seed, **kwargs)
env.seed(seed)
np_random, _ = seeding.np_random(seed)
env.np_random = np_random
# env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
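
Note: Env.seed() was removed in gym 0.26; the commit instead builds the generator with gym.utils.seeding.np_random and assigns it to env.np_random. A minimal sketch of the same pattern, assuming env and seed are already defined:

from gym.utils import seeding

np_random, actual_seed = seeding.np_random(seed)  # returns (np.random.Generator, seed)
env.np_random = np_random          # replaces the removed env.seed(seed) call
env.action_space.seed(seed)        # spaces keep their own RNGs and are seeded separately
env.observation_space.seed(seed)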

View File

@@ -6,7 +6,7 @@ import pytest
from test.utils import run_env, run_env_determinism
CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
SEED = 1
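
Note: in gym 0.26 the registry became a plain dict mapping env ids to EnvSpec objects, so registry.all() no longer exists and iteration goes through .items(), as in the updated line above. A small sketch of the dict-style access, assuming gym >= 0.26:

import gym

# Iterate over all registered specs (old API: gym.envs.registry.all()).
for env_id, spec in gym.envs.registry.items():
    print(env_id, spec.entry_point)

# Single lookups are plain dict indexing now.
spec = gym.envs.registry["CartPole-v1"]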

View File

@@ -24,6 +24,7 @@ def run_env(env_id, iterations=None, seed=0, render=False):
actions = []
dones = []
obs = env.reset()
print(obs.dtype)
verify_observations(obs, env.observation_space, "reset()")
iterations = iterations or (env.spec.max_episode_steps or 1)
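
Note: this helper still unpacks reset() as a single observation; under gym 0.26 reset() returns an (obs, info) tuple and accepts the seed directly. A sketch of the adjusted call (an assumption about where this helper is headed, not part of this commit):

obs, info = env.reset(seed=seed)  # new API: (obs, info) tuple, seed passed to reset
verify_observations(obs, env.observation_space, "reset()")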