diff --git a/alr_envs/mujoco/alr_reacher.py b/alr_envs/mujoco/alr_reacher.py index 5b13203..58e273f 100644 --- a/alr_envs/mujoco/alr_reacher.py +++ b/alr_envs/mujoco/alr_reacher.py @@ -8,7 +8,9 @@ from alr_envs.utils.utils import angle_normalize class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): - def __init__(self, steps_before_reward=200, n_links=5, balance=True): + def __init__(self, steps_before_reward=200, n_links=5, balance=True, file_name=None): + utils.EzPickle.__init__(**locals()) + self._steps = 0 self.steps_before_reward = steps_before_reward self.n_links = n_links @@ -29,7 +31,6 @@ class ALRReacherEnv(mujoco_env.MujocoEnv, utils.EzPickle): else: raise ValueError(f"Invalid number of links {n_links}, only 5 or 7 allowed.") - utils.EzPickle.__init__(steps_before_reward=steps_before_reward, n_links=n_links, balance=balance) mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(__file__), "assets", file_name), 2) def step(self, a): diff --git a/alr_envs/utils/utils.py b/alr_envs/utils/utils.py index 0bca03e..2c15bcb 100644 --- a/alr_envs/utils/utils.py +++ b/alr_envs/utils/utils.py @@ -11,10 +11,11 @@ def angle_normalize(x, type="deg"): Returns: """ + + if type not in ["deg", "rad"]: raise ValueError(f"Invalid type {type}. Choose one of 'deg' or 'rad'.") + if type == "deg": - return ((x + np.pi) % (2 * np.pi)) - np.pi - elif type == "rad": - two_pi = 2 * np.pi - return x - two_pi * np.floor((x + np.pi) / two_pi) - else: - raise ValueError(f"Invalid type {type}. Choose on of 'deg' or 'rad'.") + x = np.deg2rad(x) # x * pi / 180 + + two_pi = 2 * np.pi + return x - two_pi * np.floor((x + np.pi) / two_pi)