first updates to reacher and env creation
This commit is contained in:
parent
14ac5f81c7
commit
0c7ac838bf
@@ -3,6 +3,7 @@ import os
 import numpy as np
 from gym import utils
 from gym.envs.mujoco import MujocoEnv
+from gym.spaces import Box

 MAX_EPISODE_STEPS_REACHER = 200

@@ -12,7 +13,17 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
     More general version of the gym mujoco Reacher environment
     """

-    def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1):
+    metadata = {
+        "render_modes": [
+            "human",
+            "rgb_array",
+            "depth_array",
+        ],
+        "render_fps": 50,
+    }
+
+    def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1.,
+                 **kwargs):
         utils.EzPickle.__init__(**locals())

         self._steps = 0
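The class-level metadata block follows the convention newer gym releases expect: supported render modes and a nominal frame rate are declared on the class, and a render_mode requested at construction time is validated against metadata["render_modes"]. A minimal sketch of the same pattern on a bare custom Env (class name and space shapes are illustrative, not taken from this commit):

import gym
import numpy as np
from gym.spaces import Box

class MinimalEnv(gym.Env):
    # Advertise the supported render modes and nominal FPS, mirroring the
    # metadata block added to ReacherEnv above.
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 50}

    def __init__(self, render_mode: str = None):
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        self.observation_space = Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64)
        self.action_space = Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)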
@@ -25,10 +36,16 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):

         file_name = f'reacher_{n_links}links.xml'

+        # sin, cos, velocity * n_Links + goal position (2) and goal distance (3)
+        shape = (self.n_links * 3 + 5,)
+        observation_space = Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float64)
+
         MujocoEnv.__init__(self,
                            model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
                            frame_skip=2,
-                           mujoco_bindings="mujoco")
+                           observation_space=observation_space,
+                           **kwargs
+                           )

     def step(self, action):
         self._steps += 1
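In the gym version targeted here, MujocoEnv.__init__ takes the observation space as an explicit argument instead of inferring it, which is why the Box is built before the call. The shape follows the comment: cos, sin and joint velocity per link (3 * n_links) plus the 2D goal position and the 3D fingertip-to-goal vector, i.e. n_links * 3 + 5. A hypothetical helper showing an observation layout consistent with that shape (the actual _get_obs is not part of this diff and may differ):

import numpy as np

def reacher_observation(env):
    # Assumed layout matching the "n_links * 3 + 5" shape above.
    theta = env.data.qpos.flat[:env.n_links]
    return np.concatenate([
        np.cos(theta),                      # n_links
        np.sin(theta),                      # n_links
        env.data.qvel.flat[:env.n_links],   # n_links joint velocities
        env.goal,                           # 2D goal position
        env.get_body_com("fingertip") - env.get_body_com("target"),  # 3D distance vector
    ])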
@@ -45,10 +62,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):

         reward = reward_dist + reward_ctrl + angular_vel
         self.do_simulation(action, self.frame_skip)
-        ob = self._get_obs()
-        done = False
+        if self.render_mode == "human":
+            self.render()

-        infos = dict(
+        ob = self._get_obs()
+        terminated = False
+        truncated = False
+
+        info = dict(
             reward_dist=reward_dist,
             reward_ctrl=reward_ctrl,
             velocity=angular_vel,
@@ -56,7 +77,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
             goal=self.goal if hasattr(self, "goal") else None
         )

-        return ob, reward, done, infos
+        return ob, reward, terminated, truncated, info

     def distance_reward(self):
         vec = self.get_body_com("fingertip") - self.get_body_com("target")
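step() now returns the gym >= 0.26 five-tuple, splitting the old done flag into terminated (task-level termination) and truncated (e.g. time-limit cut-off). Callers written against the old four-tuple API need a small adjustment; a sketch of a rollout loop against the new signature (the environment id and render mode are illustrative):

import gym
import fancy_gym  # importing registers the fancy_gym environments

env = gym.make("Reacher5d-v0", render_mode="human")  # id is illustrative
obs, info = env.reset(seed=1)
done = False
while not done:
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated  # old-style done is the OR of the two flags
env.close()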
@@ -66,6 +87,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
         return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0

     def viewer_setup(self):
+        assert self.viewer is not None
         self.viewer.cam.trackbodyid = 0

     def reset_model(self):

@@ -9,6 +9,7 @@ from typing import Iterable, Type, Union
 import gym
 import numpy as np
 from gym.envs.registration import register, registry
+from gym.utils import seeding

 try:
     from dm_control import suite, manipulation
@@ -88,7 +89,9 @@ def make(env_id: str, seed: int, **kwargs):
     else:
         env = make_gym(env_id, seed, **kwargs)

-    env.seed(seed)
+    np_random, _ = seeding.np_random(seed)
+    env.np_random = np_random
+    # env.seed(seed)
     env.action_space.seed(seed)
     env.observation_space.seed(seed)

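env.seed() was removed from the gym API, so the helper now builds an RNG via gym.utils.seeding and attaches it to the environment, while the action and observation spaces are still seeded directly. A sketch of the equivalent behaviour, assuming a generic, already-constructed env:

from gym.utils import seeding

def seed_env(env, seed: int):
    # Replacement for the removed env.seed(seed): give the env its own RNG
    # and seed both spaces so sampling stays reproducible.
    np_random, _ = seeding.np_random(seed)
    env.np_random = np_random
    env.action_space.seed(seed)
    env.observation_space.seed(seed)

# Alternatively, the seed can be supplied on reset:
# obs, info = env.reset(seed=seed)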
@@ -6,7 +6,7 @@ import pytest

 from test.utils import run_env, run_env_determinism

-CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
+CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
               "fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 SEED = 1
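gym.envs.registry changed from a registry object with an .all() method to a plain dict mapping environment id to EnvSpec, hence the switch to .items(). A sketch of the same filtering under the new API (the isinstance guard is an addition here, since entry_point may also be a callable):

import gym
import fancy_gym  # noqa: F401  (importing registers the fancy_gym ids)

fancy_ids = [
    env_id
    for env_id, spec in gym.envs.registry.items()
    if isinstance(spec.entry_point, str) and "fancy_gym" in spec.entry_point
]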
@@ -24,6 +24,7 @@ def run_env(env_id, iterations=None, seed=0, render=False):
     actions = []
     dones = []
     obs = env.reset()
+    print(obs.dtype)
     verify_observations(obs, env.observation_space, "reset()")

     iterations = iterations or (env.spec.max_episode_steps or 1)
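The added print(obs.dtype) is a debugging aid for dtype mismatches (e.g. a float32 Box space against a float64 observation), which make containment checks fail. An assumed sketch of what verify_observations checks (the real helper in test/utils.py may differ):

import numpy as np

def verify_observations(obs, observation_space, step_name: str):
    # The observation must lie in the declared space: shape, bounds and dtype all match.
    obs = np.asarray(obs)
    assert observation_space.contains(obs), \
        f"Observation with dtype {obs.dtype} not contained in {observation_space} after {step_name}"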