From f33996c27afcb5592b6ce325c58c9f7373864c1b Mon Sep 17 00:00:00 2001
From: Onur
Date: Mon, 2 May 2022 15:06:21 +0200
Subject: [PATCH] include callback in step

---
 alr_envs/alr/mujoco/reacher/new_mp_wrapper.py |  5 +----
 mp_wrapper.py                                 | 21 +++++++++++--------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py
index 37499e6..5475366 100644
--- a/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py
+++ b/alr_envs/alr/mujoco/reacher/new_mp_wrapper.py
@@ -21,7 +21,4 @@ class MPWrapper(BaseMPWrapper):
             [False] * 3,  # goal distance
             # self.get_body_com("target"),  # only return target to make problem harder
             [False],  # step
-        ])
-
-    def _step_callback(self, action):
-        pass
\ No newline at end of file
+        ])
\ No newline at end of file

diff --git a/mp_wrapper.py b/mp_wrapper.py
index 82b23bd..3defe07 100644
--- a/mp_wrapper.py
+++ b/mp_wrapper.py
@@ -33,6 +33,7 @@ class BaseMPWrapper(gym.Env, ABC):
                  duration: float,
                  policy_type: Union[str, BaseController] = None,
                  render_mode: str = None,
+                 verbose=2,
                  **mp_kwargs
                  ):
         super().__init__()
@@ -44,7 +45,7 @@ class BaseMPWrapper(gym.Env, ABC):
         self.traj_steps = int(duration / self.dt)
         self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
 
-        # TODO: move to constructer, use policy factory instead what Fabian already coded
+        # TODO: move to constructor, use policy factory instead what Fabian already coded
         if isinstance(policy_type, str):
             # pop policy kwargs here such that they are not passed to the initialize_mp method
             self.policy = get_policy_class(policy_type, self, **mp_kwargs.pop('policy_kwargs', {}))
@@ -53,6 +54,7 @@ class BaseMPWrapper(gym.Env, ABC):
 
         self.mp = mp
         self.env = env
+        self.verbose = verbose
 
         # rendering
         self.render_mode = render_mode
@@ -114,14 +116,12 @@ class BaseMPWrapper(gym.Env, ABC):
         """
         raise NotImplementedError()
 
-    @abstractmethod
-    def _step_callback(self, action):
+    def _step_callback(self, t, action):
         pass
 
     def step(self, action: np.ndarray):
         """ This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
         # TODO: Think about sequencing
-        # TODO: put in a callback function here which every environment can implement. Important for e.g. BP to allow the
         # TODO: Reward Function rather here?
         # agent to learn when to release the ball
         trajectory, velocity = self.get_trajectory(action)
@@ -141,6 +141,9 @@ class BaseMPWrapper(gym.Env, ABC):
 
         for t, pos_vel in enumerate(zip(trajectory, velocity)):
             ac = self.policy.get_action(pos_vel[0], pos_vel[1])
+            callback_action = self._step_callback(t, action)
+            if callback_action is not None:
+                ac = np.concatenate((callback_action, ac))  # include callback action at first pos of vector
             actions[t, :] = np.clip(ac, self.env.action_space.low, self.env.action_space.high)
             obs, rewards[t], done, info = self.env.step(actions[t, :])
             observations[t, :] = obs["observation"] if isinstance(self.env.observation_space, spaces.Dict) else obs
@@ -156,11 +159,11 @@ class BaseMPWrapper(gym.Env, ABC):
                 break
 
         infos.update({k: v[:t + 1] for k, v in infos.items()})
         infos['trajectory'] = trajectory
-        # TODO: remove step information? Might be relevant for debugging -> return only in debug mode (verbose)?
-        infos['step_actions'] = actions[:t + 1]
-        infos['step_observations'] = observations[:t + 1]
-        infos['step_rewards'] = rewards[:t + 1]
-        infos['trajectory_length'] = t + 1
+        if self.verbose == 2:
+            infos['step_actions'] = actions[:t + 1]
+            infos['step_observations'] = observations[:t + 1]
+            infos['step_rewards'] = rewards[:t + 1]
+            infos['trajectory_length'] = t + 1
         done = True
 
         return self.get_observation_from_step(observations[t]), trajectory_return, done, infos
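
For context, the new _step_callback(t, action) hook lets a concrete environment wrapper inject an extra action component at every control step: whatever it returns is prepended to the policy action before clipping (see the np.concatenate call in the loop above), while returning None keeps the original behaviour. Below is a minimal, hypothetical sketch of such an override for a ball-throwing style task where the agent learns when to release the ball; BallReleaseMPWrapper and release_step are made-up names and not part of this patch.

    import numpy as np

    from mp_wrapper import BaseMPWrapper


    class BallReleaseMPWrapper(BaseMPWrapper):
        """Hypothetical example: add one action dimension that flips to 1 at a release step."""

        def __init__(self, *args, release_step: int = 50, **kwargs):
            super().__init__(*args, **kwargs)
            self.release_step = release_step  # assumed parameter, not defined in the patch

        def _step_callback(self, t, action):
            # Returning an array makes BaseMPWrapper.step prepend it to the policy action;
            # returning None (the base-class default) leaves the action untouched.
            return np.array([1.0 if t >= self.release_step else 0.0])

Since BaseMPWrapper is abstract, a real subclass would also have to implement its remaining abstract methods; the sketch only illustrates the callback.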