fix minor bugs & merge test branch

This commit is contained in:
Hongyi Zhou 2022-10-24 22:01:56 +02:00
parent e3509f8be3
commit c457fbbfeb
2 changed files with 6 additions and 9 deletions

View File

@ -234,10 +234,7 @@ for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
register(
id='BoxPushing{}-v0'.format(reward_type),
entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING//10, # divided by frames skip
kwargs={
"frame_skip": 10
}
max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING,
)
# Here we use the same reward as in BeerPong-v0, but now consider after the release,

View File

@ -9,7 +9,7 @@ from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat
import mujoco
MAX_EPISODE_STEPS_BOX_PUSHING = 1000
MAX_EPISODE_STEPS_BOX_PUSHING = 100
BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]])
@ -60,7 +60,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
self._steps += 1
self._episode_energy += np.sum(np.square(action))
episode_end = True if self._steps >= MAX_EPISODE_STEPS_BOX_PUSHING//self.frame_skip else False
episode_end = True if self._steps >= MAX_EPISODE_STEPS_BOX_PUSHING else False
box_pos = self.data.body("box_0").xpos.copy()
box_quat = self.data.body("box_0").xquat.copy()
@ -121,8 +121,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
return self._get_obs()
def sample_context(self):
pos = np.random.uniform(low=BOX_POS_BOUND[0], high=BOX_POS_BOUND[1], size=3)
theta = np.random.uniform(low=0, high=np.pi * 2)
pos = self.np_random.uniform(low=BOX_POS_BOUND[0], high=BOX_POS_BOUND[1])
theta = self.np_random.uniform(low=0, high=np.pi * 2)
quat = rot_to_quat(theta, np.array([0, 0, 1]))
return np.concatenate([pos, quat])
@ -360,7 +360,7 @@ class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):
if __name__=="__main__":
env = BoxPushingTemporalSpatialSparse(frame_skip=10)
env.reset()
for i in range(100):
for i in range(1):
env.reset()
for _ in range(100):
env.render("human")