fix minor bugs & merge test branch
This commit is contained in:
parent
e3509f8be3
commit
c457fbbfeb
@ -234,10 +234,7 @@ for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
|
|||||||
register(
|
register(
|
||||||
id='BoxPushing{}-v0'.format(reward_type),
|
id='BoxPushing{}-v0'.format(reward_type),
|
||||||
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
|
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
|
||||||
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING//10, # divided by frames skip
|
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
|
||||||
kwargs={
|
|
||||||
"frame_skip": 10
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Here we use the same reward as in BeerPong-v0, but now consider after the release,
|
# Here we use the same reward as in BeerPong-v0, but now consider after the release,
|
||||||
|
@ -9,7 +9,7 @@ from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
|
|||||||
|
|
||||||
import mujoco
|
import mujoco
|
||||||
|
|
||||||
MAX_EPISODE_STEPS_BOX_PUSHING = 1000
|
MAX_EPISODE_STEPS_BOX_PUSHING = 100
|
||||||
|
|
||||||
BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]])
|
BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]])
|
||||||
|
|
||||||
@ -60,7 +60,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
self._steps += 1
|
self._steps += 1
|
||||||
self._episode_energy += np.sum(np.square(action))
|
self._episode_energy += np.sum(np.square(action))
|
||||||
|
|
||||||
episode_end = True if self._steps >= MAX_EPISODE_STEPS_BOX_PUSHING//self.frame_skip else False
|
episode_end = True if self._steps >= MAX_EPISODE_STEPS_BOX_PUSHING else False
|
||||||
|
|
||||||
box_pos = self.data.body("box_0").xpos.copy()
|
box_pos = self.data.body("box_0").xpos.copy()
|
||||||
box_quat = self.data.body("box_0").xquat.copy()
|
box_quat = self.data.body("box_0").xquat.copy()
|
||||||
@ -121,8 +121,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
|
|||||||
return self._get_obs()
|
return self._get_obs()
|
||||||
|
|
||||||
def sample_context(self):
|
def sample_context(self):
|
||||||
pos = np.random.uniform(low=BOX_POS_BOUND[0], high=BOX_POS_BOUND[1], size=3)
|
pos = self.np_random.uniform(low=BOX_POS_BOUND[0], high=BOX_POS_BOUND[1])
|
||||||
theta = np.random.uniform(low=0, high=np.pi * 2)
|
theta = self.np_random.uniform(low=0, high=np.pi * 2)
|
||||||
quat = rot_to_quat(theta, np.array([0, 0, 1]))
|
quat = rot_to_quat(theta, np.array([0, 0, 1]))
|
||||||
return np.concatenate([pos, quat])
|
return np.concatenate([pos, quat])
|
||||||
|
|
||||||
@ -360,7 +360,7 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
|
|||||||
if __name__=="__main__":
|
if __name__=="__main__":
|
||||||
env = BoxPushingTemporalSpatialSparse(frame_skip=10)
|
env = BoxPushingTemporalSpatialSparse(frame_skip=10)
|
||||||
env.reset()
|
env.reset()
|
||||||
for i in range(100):
|
for i in range(1):
|
||||||
env.reset()
|
env.reset()
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
env.render("human")
|
env.render("human")
|
||||||
|
Loading…
Reference in New Issue
Block a user