include callback in step
This commit is contained in:
parent
137eb726eb
commit
f33996c27a
@ -22,6 +22,3 @@ class MPWrapper(BaseMPWrapper):
|
|||||||
# self.get_body_com("target"), # only return target to make problem harder
|
# self.get_body_com("target"), # only return target to make problem harder
|
||||||
[False], # step
|
[False], # step
|
||||||
])
|
])
|
||||||
|
|
||||||
def _step_callback(self, action):
|
|
||||||
pass
|
|
@ -33,6 +33,7 @@ class BaseMPWrapper(gym.Env, ABC):
|
|||||||
duration: float,
|
duration: float,
|
||||||
policy_type: Union[str, BaseController] = None,
|
policy_type: Union[str, BaseController] = None,
|
||||||
render_mode: str = None,
|
render_mode: str = None,
|
||||||
|
verbose=2,
|
||||||
**mp_kwargs
|
**mp_kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -44,7 +45,7 @@ class BaseMPWrapper(gym.Env, ABC):
|
|||||||
self.traj_steps = int(duration / self.dt)
|
self.traj_steps = int(duration / self.dt)
|
||||||
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
|
self.post_traj_steps = self.env.spec.max_episode_steps - self.traj_steps
|
||||||
|
|
||||||
# TODO: move to constructer, use policy factory instead what Fabian already coded
|
# TODO: move to constructor, use policy factory instead what Fabian already coded
|
||||||
if isinstance(policy_type, str):
|
if isinstance(policy_type, str):
|
||||||
# pop policy kwargs here such that they are not passed to the initialize_mp method
|
# pop policy kwargs here such that they are not passed to the initialize_mp method
|
||||||
self.policy = get_policy_class(policy_type, self, **mp_kwargs.pop('policy_kwargs', {}))
|
self.policy = get_policy_class(policy_type, self, **mp_kwargs.pop('policy_kwargs', {}))
|
||||||
@ -53,6 +54,7 @@ class BaseMPWrapper(gym.Env, ABC):
|
|||||||
|
|
||||||
self.mp = mp
|
self.mp = mp
|
||||||
self.env = env
|
self.env = env
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
# rendering
|
# rendering
|
||||||
self.render_mode = render_mode
|
self.render_mode = render_mode
|
||||||
@ -114,14 +116,12 @@ class BaseMPWrapper(gym.Env, ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
def _step_callback(self, t, action):
|
||||||
def _step_callback(self, action):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def step(self, action: np.ndarray):
|
def step(self, action: np.ndarray):
|
||||||
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
""" This function generates a trajectory based on a MP and then does the usual loop over reset and step"""
|
||||||
# TODO: Think about sequencing
|
# TODO: Think about sequencing
|
||||||
# TODO: put in a callback function here which every environment can implement. Important for e.g. BP to allow the
|
|
||||||
# TODO: Reward Function rather here?
|
# TODO: Reward Function rather here?
|
||||||
# agent to learn when to release the ball
|
# agent to learn when to release the ball
|
||||||
trajectory, velocity = self.get_trajectory(action)
|
trajectory, velocity = self.get_trajectory(action)
|
||||||
@ -141,6 +141,9 @@ class BaseMPWrapper(gym.Env, ABC):
|
|||||||
|
|
||||||
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
for t, pos_vel in enumerate(zip(trajectory, velocity)):
|
||||||
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
ac = self.policy.get_action(pos_vel[0], pos_vel[1])
|
||||||
|
callback_action = self._step_callback(t, action)
|
||||||
|
if callback_action is not None:
|
||||||
|
ac = np.concatenate((callback_action, ac)) # include callback action at first pos of vector
|
||||||
actions[t, :] = np.clip(ac, self.env.action_space.low, self.env.action_space.high)
|
actions[t, :] = np.clip(ac, self.env.action_space.low, self.env.action_space.high)
|
||||||
obs, rewards[t], done, info = self.env.step(actions[t, :])
|
obs, rewards[t], done, info = self.env.step(actions[t, :])
|
||||||
observations[t, :] = obs["observation"] if isinstance(self.env.observation_space, spaces.Dict) else obs
|
observations[t, :] = obs["observation"] if isinstance(self.env.observation_space, spaces.Dict) else obs
|
||||||
@ -156,7 +159,7 @@ class BaseMPWrapper(gym.Env, ABC):
|
|||||||
break
|
break
|
||||||
infos.update({k: v[:t + 1] for k, v in infos.items()})
|
infos.update({k: v[:t + 1] for k, v in infos.items()})
|
||||||
infos['trajectory'] = trajectory
|
infos['trajectory'] = trajectory
|
||||||
# TODO: remove step information? Might be relevant for debugging -> return only in debug mode (verbose)?
|
if self.verbose == 2:
|
||||||
infos['step_actions'] = actions[:t + 1]
|
infos['step_actions'] = actions[:t + 1]
|
||||||
infos['step_observations'] = observations[:t + 1]
|
infos['step_observations'] = observations[:t + 1]
|
||||||
infos['step_rewards'] = rewards[:t + 1]
|
infos['step_rewards'] = rewards[:t + 1]
|
||||||
|
Loading…
Reference in New Issue
Block a user