added balancing to reacher
This commit is contained in:
parent
c008614214
commit
d026ebc427
@ -8,13 +8,13 @@ from alr_envs.utils.utils import angle_normalize
|
|||||||
|
|
||||||
|
|
||||||
class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
||||||
def __init__(self, steps_before_reward=200, n_links=5, balance=False):
|
def __init__(self, steps_before_reward=200, n_links=5, balance=True):
|
||||||
self._steps = 0
|
self._steps = 0
|
||||||
self.steps_before_reward = steps_before_reward
|
self.steps_before_reward = steps_before_reward
|
||||||
self.n_links = n_links
|
self.n_links = n_links
|
||||||
|
|
||||||
self.balance = balance
|
self.balance = balance
|
||||||
self.balance_weight = 0.01
|
self.balance_weight = 0.5
|
||||||
|
|
||||||
self.reward_weight = 1
|
self.reward_weight = 1
|
||||||
if steps_before_reward == 200:
|
if steps_before_reward == 200:
|
||||||
@ -29,7 +29,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.")
|
raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.")
|
||||||
|
|
||||||
utils.EzPickle.__init__(self)
|
utils.EzPickle.__init__(**locals())
|
||||||
mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2)
|
mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2)
|
||||||
|
|
||||||
def step(self, a):
|
def step(self, a):
|
||||||
@ -45,7 +45,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
reward_ctrl = - np.square(a).sum()
|
reward_ctrl = - np.square(a).sum()
|
||||||
|
|
||||||
if self.balance:
|
if self.balance:
|
||||||
reward_balance = - self.balance_weight * np.abs(
|
reward_balance -= self.balance_weight * np.abs(
|
||||||
angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad"))
|
angle_normalize(np.sum(self.sim.data.qpos.flat[:self.n_links]), type="rad"))
|
||||||
|
|
||||||
reward = reward_dist + reward_ctrl + angular_vel + reward_balance
|
reward = reward_dist + reward_ctrl + angular_vel + reward_balance
|
||||||
@ -58,7 +58,7 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle):
|
|||||||
goal=self.goal if hasattr(self, "goal") else None)
|
goal=self.goal if hasattr(self, "goal") else None)
|
||||||
|
|
||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
self.viewer.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 1
|
||||||
|
|
||||||
def reset_model(self):
|
def reset_model(self):
|
||||||
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
|
qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
|
||||||
|
Loading…
Reference in New Issue
Block a user