added balancing to reacher

This commit is contained in:
ottofabian 2021-02-09 17:07:52 +01:00
parent c008614214
commit d026ebc427

View File

@@ -8,13 +8,13 @@ from alr_envs.utils.utils import angle_normalize
class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
def __init__(self, steps_before_reward=200, n_links=5, balance=False):
def __init__(self, steps_before_reward=200, n_links=5, balance=True):
self._steps = 0
self.steps_before_reward = steps_before_reward
self.n_links = n_links
self.balance = balance
self.balance_weight = 0.01
self.balance_weight = 0.5
self.reward_weight = 1
if steps_before_reward == 200:
@@ -29,7 +29,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
else:
raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.")
utils.EzPickle.__init__(self)
utils.EzPickle.__init__(**locals())
mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2)
def step(self, a):
@@ -45,7 +45,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
reward_ctrl = - np.square(a).sum()
if self.balance:
reward_balance = - self.balance_weight * np.abs(
reward_balance -= self.balance_weight * np.abs(
angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad"))
reward = reward_dist + reward_ctrl + angular_vel + reward_balance
@@ -58,7 +58,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
goal=self.goal if hasattr(self, "goal") else None)
def viewer_setup(self):
self.viewer.cam.trackbodyid = 0
self.viewer.cam.trackbodyid = 1
def reset_model(self):
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos