Fix: BP was not returning new infos (smoothness metrics)

This commit is contained in:
Dominik Moritz Roth 2024-02-02 16:51:18 +01:00
parent c9ea8cb167
commit 642bf8761f
2 changed files with 8 additions and 1 deletions

View File

@ -1,7 +1,7 @@
from .ant_jump.ant_jump import AntJumpEnv
from .beerpong.beerpong import BeerPongEnv, BeerPongEnvStepBasedEpisodicReward
from .half_cheetah_jump.half_cheetah_jump import HalfCheetahJumpEnv
from .hopper_jump.hopper_jump import HopperJumpEnv
from .hopper_jump.hopper_jump import HopperJumpEnv, HopperJumpMarkovRew
from .hopper_jump.hopper_jump_on_box import HopperJumpOnBoxEnv
from .hopper_throw.hopper_throw import HopperThrowEnv
from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv

View File

@ -50,6 +50,7 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
self._desired_rod_quat = desired_rod_quat
self._episode_energy = 0.
self.velocity_profile = []
self.observation_space = spaces.Box(
low=-np.inf, high=np.inf, shape=(28,), dtype=np.float64
@ -68,6 +69,8 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
unstable_simulation = False
self.velocity_profile.append(self.data.qvel[:7].copy())
try:
self.do_simulation(resultant_action, self.frame_skip)
except Exception as e:
@ -97,11 +100,15 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
obs = self._get_obs()
box_goal_pos_dist = 0. if not episode_end else np.linalg.norm(box_pos - target_pos)
box_goal_quat_dist = 0. if not episode_end else rotation_distance(box_quat, target_quat)
mean_squared_jerk, maximum_jerk, dimensionless_jerk = (0.0,0.0,0.0) if not episode_end else self.calculate_smoothness_metrics(np.array(self.velocity_profile), self.dt)
infos = {
'episode_end': episode_end,
'box_goal_pos_dist': box_goal_pos_dist,
'box_goal_rot_dist': box_goal_quat_dist,
'episode_energy': 0. if not episode_end else self._episode_energy,
'mean_squared_jerk': mean_squared_jerk,
'maximum_jerk': maximum_jerk,
'dimensionless_jerk': dimensionless_jerk,
'is_success': True if episode_end and box_goal_pos_dist < 0.05 and box_goal_quat_dist < 0.5 else False,
'num_steps': self._steps
}