From 7f512068c96c3e075b720e9a3a193028d9c5e211 Mon Sep 17 00:00:00 2001 From: Maximilian Huettenrauch Date: Mon, 17 May 2021 09:32:51 +0200 Subject: [PATCH] context wip --- .../episodic_simple_reacher.py | 2 +- alr_envs/classic_control/hole_reacher.py | 11 +++++++++-- alr_envs/classic_control/simple_reacher.py | 19 +++++++++++-------- alr_envs/utils/mp_env_async_sampler.py | 9 ++++++--- alr_envs/utils/wrapper/dmp_wrapper.py | 2 +- alr_envs/utils/wrapper/mp_wrapper.py | 3 ++- example.py | 4 ++-- 7 files changed, 32 insertions(+), 18 deletions(-) diff --git a/alr_envs/classic_control/episodic_simple_reacher.py b/alr_envs/classic_control/episodic_simple_reacher.py index b02efe8..f6d828e 100644 --- a/alr_envs/classic_control/episodic_simple_reacher.py +++ b/alr_envs/classic_control/episodic_simple_reacher.py @@ -26,7 +26,7 @@ class EpisodicSimpleReacherEnv(SimpleReacherEnv): self.observation_space = spaces.Box(low=-state_bound, high=state_bound, shape=state_bound.shape) @property - def start_pos(self): + def init_qpos(self): return self._start_pos # @property diff --git a/alr_envs/classic_control/hole_reacher.py b/alr_envs/classic_control/hole_reacher.py index 3b382f9..3ddc360 100644 --- a/alr_envs/classic_control/hole_reacher.py +++ b/alr_envs/classic_control/hole_reacher.py @@ -62,6 +62,10 @@ class HoleReacher(gym.Env): self.patches = [rect_1, rect_2, rect_3] + @property + def init_qpos(self): + return self.start_pos + @property def end_effector(self): return self._joints[self.n_links].T @@ -82,7 +86,7 @@ class HoleReacher(gym.Env): """ a single step with an action in joint velocity space """ - vel = action # + 0.01 * np.random.randn(self.num_links) + vel = action # + 0.05 * np.random.randn(self.n_links) acc = (vel - self._angle_velocity) / self.dt self._angle_velocity = vel self._joint_angles = self._joint_angles + self.dt * self._angle_velocity @@ -237,7 +241,10 @@ class HoleReacher(gym.Env): if self._steps == 1: # fig, ax = plt.subplots() # Add the patch to the Axes - [plt.gca().add_patch(rect) for rect in self.patches] + try: + [plt.gca().add_patch(rect) for rect in self.patches] + except RuntimeError: + pass # plt.pause(0.01) if self._steps % 20 == 0 or self._steps in [1, 199] or self._is_collided: diff --git a/alr_envs/classic_control/simple_reacher.py b/alr_envs/classic_control/simple_reacher.py index 7ca4ead..3f4a2a9 100644 --- a/alr_envs/classic_control/simple_reacher.py +++ b/alr_envs/classic_control/simple_reacher.py @@ -22,7 +22,7 @@ class SimpleReacherEnv(gym.Env): super().__init__() self.link_lengths = np.ones(n_links) self.n_links = n_links - self.dt = 0.1 + self.dt = 0.01 self.random_start = random_start @@ -56,10 +56,13 @@ class SimpleReacherEnv(gym.Env): def step(self, action: np.ndarray): # action = self._add_action_noise(action) - action = np.clip(action, -self.max_torque, self.max_torque) + # action = np.clip(action, -self.max_torque, self.max_torque) + vel = action - self._angle_velocity = self._angle_velocity + self.dt * action - self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity) + # self._angle_velocity = self._angle_velocity + self.dt * action + # self._joint_angle = angle_normalize(self._joint_angle + self.dt * self._angle_velocity) + self._angle_velocity = vel + self._joint_angle = self._joint_angle + self.dt * self._angle_velocity self._update_joints() self._steps += 1 @@ -111,7 +114,7 @@ class SimpleReacherEnv(gym.Env): # reward_dist = np.exp(-0.1 * diff ** 2).mean() # reward_dist = - (diff ** 2).mean() - reward_ctrl = (action ** 2).sum() + reward_ctrl = 1e-5 * (action ** 2).sum() reward = reward_dist - reward_ctrl return reward, dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl) @@ -139,7 +142,7 @@ class SimpleReacherEnv(gym.Env): # Sample uniformly in circle with radius R around center of reacher. R = np.sum(self.link_lengths) r = R * np.sqrt(self.np_random.uniform()) - theta = self.np_random.uniform() * 2 * np.pi + theta = np.pi/2 + 0.001 * np.random.randn() # self.np_random.uniform() * 2 * np.pi return center + r * np.stack([np.cos(theta), np.sin(theta)]) def seed(self, seed=None): @@ -170,8 +173,8 @@ class SimpleReacherEnv(gym.Env): plt.xlim([-lim, lim]) plt.ylim([-lim, lim]) # plt.draw() - # plt.pause(1e-4) pushes window to foreground, which is annoying. - self.fig.canvas.flush_events() + plt.pause(1e-4) # pushes window to foreground, which is annoying. + # self.fig.canvas.flush_events() def close(self): del self.fig diff --git a/alr_envs/utils/mp_env_async_sampler.py b/alr_envs/utils/mp_env_async_sampler.py index 59cf594..2fb3645 100644 --- a/alr_envs/utils/mp_env_async_sampler.py +++ b/alr_envs/utils/mp_env_async_sampler.py @@ -56,6 +56,7 @@ class AlrMpEnvSampler: vals = defaultdict(list) for p in split_params: + self.env.reset() obs, reward, done, info = self.env.step(p) vals['obs'].append(obs) vals['reward'].append(reward) @@ -82,8 +83,9 @@ class AlrContextualMpEnvSampler: vals = defaultdict(list) for i in range(repeat): new_contexts = self.env.reset() - - new_samples = dist.sample(new_contexts) + vals['new_contexts'].append(new_contexts) + new_samples, new_contexts = dist.sample(new_contexts) + vals['new_samples'].append(new_samples) obs, reward, done, info = self.env.step(new_samples) vals['obs'].append(obs) @@ -92,7 +94,8 @@ class AlrContextualMpEnvSampler: vals['info'].append(info) # do not return values above threshold - return np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples],\ + return np.vstack(vals['new_samples'])[:n_samples], np.vstack(vals['new_contexts'])[:n_samples], \ + np.vstack(vals['obs'])[:n_samples], np.hstack(vals['reward'])[:n_samples], \ _flatten_list(vals['done'])[:n_samples], _flatten_list(vals['info'])[:n_samples] diff --git a/alr_envs/utils/wrapper/dmp_wrapper.py b/alr_envs/utils/wrapper/dmp_wrapper.py index 2a198db..36a8c92 100644 --- a/alr_envs/utils/wrapper/dmp_wrapper.py +++ b/alr_envs/utils/wrapper/dmp_wrapper.py @@ -98,7 +98,7 @@ class DmpWrapper(MPWrapper): def mp_rollout(self, action): # if self.mp.start_pos is None: - self.mp.dmp_start_pos = self.env.init_qpos # start_pos + self.mp.dmp_start_pos = self.env.init_qpos.reshape((1, self.num_dof)) # start_pos goal_pos, weight_matrix = self.goal_and_weights(action) self.mp.set_weights(weight_matrix, goal_pos) return self.mp.reference_trajectory(self.t) diff --git a/alr_envs/utils/wrapper/mp_wrapper.py b/alr_envs/utils/wrapper/mp_wrapper.py index adeba55..43f127c 100644 --- a/alr_envs/utils/wrapper/mp_wrapper.py +++ b/alr_envs/utils/wrapper/mp_wrapper.py @@ -22,7 +22,7 @@ class MPWrapper(gym.Wrapper, ABC): ): super().__init__(env) - # self.num_dof = num_dof + self.num_dof = num_dof # self.num_basis = num_basis # self.duration = duration # seconds @@ -50,6 +50,7 @@ class MPWrapper(gym.Wrapper, ABC): # for p, c in zip(params, contexts): for p in params: # self.configure(c) + # context = self.reset() ob, reward, done, info = self.step(p) obs.append(ob) rewards.append(reward) diff --git a/example.py b/example.py index 2d32ad8..1718e46 100644 --- a/example.py +++ b/example.py @@ -83,6 +83,6 @@ if __name__ == '__main__': # example_mujoco() # example_dmp() # example_async() - # env = gym.make("alr_envs:HoleReacherDMP-v0", context=0.1) - env = gym.make("alr_envs:SimpleReacherDMP-v1") + env = gym.make("alr_envs:HoleReacherDMP-v0") + # env = gym.make("alr_envs:SimpleReacherDMP-v1") print() \ No newline at end of file