first updates to reacher and env creation
This commit is contained in:
parent
14ac5f81c7
commit
0c7ac838bf
@ -3,6 +3,7 @@ import os
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from gym import utils
|
from gym import utils
|
||||||
from gym.envs.mujoco import MujocoEnv
|
from gym.envs.mujoco import MujocoEnv
|
||||||
|
from gym.spaces import Box
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_REACHER = 200
|
MAX_EPISODE_STEPS_REACHER = 200
|
||||||
|
|
||||||
@ -12,7 +13,17 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
More general version of the gym mujoco Reacher environment
|
More general version of the gym mujoco Reacher environment
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1):
|
metadata = {
|
||||||
|
"render_modes": [
|
||||||
|
"human",
|
||||||
|
"rgb_array",
|
||||||
|
"depth_array",
|
||||||
|
],
|
||||||
|
"render_fps": 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, sparse: bool = False, n_links: int = 5, reward_weight: float = 1, ctrl_cost_weight: float = 1.,
|
||||||
|
**kwargs):
|
||||||
utils.EzPickle.__init__(**locals())
|
utils.EzPickle.__init__(**locals())
|
||||||
|
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
@ -25,10 +36,16 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
file_name = f'reacher_{n_links}links.xml'
|
file_name = f'reacher_{n_links}links.xml'
|
||||||
|
|
||||||
|
# sin, cos, velocity * n_Links + goal position (2) and goal distance (3)
|
||||||
|
shape = (self.n_links * 3 + 5,)
|
||||||
|
observation_space = Box(low=-np.inf, high=np.inf, shape=shape, dtype=np.float64)
|
||||||
|
|
||||||
MujocoEnv.__init__(self,
|
MujocoEnv.__init__(self,
|
||||||
model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
|
model_path=os.path.join(os.path.dirname(__file__), "assets", file_name),
|
||||||
frame_skip=2,
|
frame_skip=2,
|
||||||
mujoco_bindings="mujoco")
|
observation_space=observation_space,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
self._steps += 1
|
self._steps += 1
|
||||||
@ -45,10 +62,14 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
|
|
||||||
reward = reward_dist + reward_ctrl + angular_vel
|
reward = reward_dist + reward_ctrl + angular_vel
|
||||||
self.do_simulation(action, self.frame_skip)
|
self.do_simulation(action, self.frame_skip)
|
||||||
ob = self._get_obs()
|
if self.render_mode == "human":
|
||||||
done = False
|
self.render()
|
||||||
|
|
||||||
infos = dict(
|
ob = self._get_obs()
|
||||||
|
terminated = False
|
||||||
|
truncated = False
|
||||||
|
|
||||||
|
info = dict(
|
||||||
reward_dist=reward_dist,
|
reward_dist=reward_dist,
|
||||||
reward_ctrl=reward_ctrl,
|
reward_ctrl=reward_ctrl,
|
||||||
velocity=angular_vel,
|
velocity=angular_vel,
|
||||||
@ -56,7 +77,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
goal=self.goal if hasattr(self, "goal") else None
|
goal=self.goal if hasattr(self, "goal") else None
|
||||||
)
|
)
|
||||||
|
|
||||||
return ob, reward, done, infos
|
return ob, reward, terminated, truncated, info
|
||||||
|
|
||||||
def distance_reward(self):
|
def distance_reward(self):
|
||||||
vec = self.get_body_com("fingertip") - self.get_body_com("target")
|
vec = self.get_body_com("fingertip") - self.get_body_com("target")
|
||||||
@ -66,6 +87,7 @@ class ReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0
|
return -10 * np.square(self.data.qvel.flat[:self.n_links]).sum() if self.sparse else 0.0
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
|
assert self.viewer is not None
|
||||||
self.viewer.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 0
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
|
@ -9,6 +9,7 @@ from typing import Iterable, Type, Union
|
|||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gym.envs.registration import register, registry
|
from gym.envs.registration import register, registry
|
||||||
|
from gym.utils import seeding
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from dm_control import suite, manipulation
|
from dm_control import suite, manipulation
|
||||||
@ -88,7 +89,9 @@ def make(env_id: str, seed: int, **kwargs):
|
|||||||
else:
|
else:
|
||||||
env = make_gym(env_id, seed, **kwargs)
|
env = make_gym(env_id, seed, **kwargs)
|
||||||
|
|
||||||
env.seed(seed)
|
np_random, _ = seeding.np_random(seed)
|
||||||
|
env.np_random = np_random
|
||||||
|
# env.seed(seed)
|
||||||
env.action_space.seed(seed)
|
env.action_space.seed(seed)
|
||||||
env.observation_space.seed(seed)
|
env.observation_space.seed(seed)
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import pytest
|
|||||||
|
|
||||||
from test.utils import run_env, run_env_determinism
|
from test.utils import run_env, run_env_determinism
|
||||||
|
|
||||||
CUSTOM_IDS = [spec.id for spec in gym.envs.registry.all() if
|
CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
|
||||||
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
"fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
|
||||||
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
|
||||||
SEED = 1
|
SEED = 1
|
||||||
|
@ -24,6 +24,7 @@ def run_env(env_id, iterations=None, seed=0, render=False):
|
|||||||
actions = []
|
actions = []
|
||||||
dones = []
|
dones = []
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
print(obs.dtype)
|
||||||
verify_observations(obs, env.observation_space, "reset()")
|
verify_observations(obs, env.observation_space, "reset()")
|
||||||
|
|
||||||
iterations = iterations or (env.spec.max_episode_steps or 1)
|
iterations = iterations or (env.spec.max_episode_steps or 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user