diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py index 9b512a8..9ac03dd 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_env.py @@ -6,6 +6,7 @@ from gymnasium.envs.mujoco import MujocoEnv from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat +from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import calculate_jerk_profile, calculate_mean_squared_jerk, calculate_dimensionless_jerk, calculate_maximum_jerk import mujoco @@ -110,6 +111,26 @@ class BoxPushingEnvBase(MujocoEnv, utils.EzPickle): return obs, reward, terminated, truncated, infos + def calculate_smoothness_metrics(self, velocity_profile, dt): + """ + Calculates the smoothness metrics for the given velocity profile. + param velocity_profile: np.array + The array containing the movement velocity profile. + param dt: float + The sampling time interval of the data. + return mean_squared_jerk: float + The mean squared jerk estimate of the given movement's smoothness. + return maximum_jerk: float + The maximum jerk estimate of the given movement's smoothness. + return dimensionless_jerk: float + The dimensionless jerk estimate of the given movement's smoothness. + """ + jerk_profile = calculate_jerk_profile(velocity_profile, dt) + mean_squared_jerk = calculate_mean_squared_jerk(jerk_profile) + maximum_jerk = calculate_maximum_jerk(jerk_profile) + dimensionless_jerk = calculate_dimensionless_jerk(jerk_profile, velocity_profile, dt) + return mean_squared_jerk, maximum_jerk, dimensionless_jerk + def reset_model(self): # rest box to initial position self.set_state(self.init_qpos_box_pushing, self.init_qvel_box_pushing) diff --git a/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py b/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py index 0b1919e..d880421 100644 --- a/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py +++ b/fancy_gym/envs/mujoco/box_pushing/box_pushing_utils.py @@ -51,3 +51,19 @@ def rot_to_quat(theta, axis): quant[0] = np.sin(theta / 2.) quant[1:] = np.cos(theta / 2.) * axis return quant + +def calculate_jerk_profile(velocity_profile, dt): + jerk = np.diff(velocity_profile, 2, 0) / pow(dt, 2) + return jerk + +def calculate_mean_squared_jerk(jerk_profile): + return np.mean(pow(jerk_profile, 2)) + +def calculate_maximum_jerk(jerk_profile): + return np.max(abs(jerk_profile)) + +def calculate_dimensionless_jerk(jerk_profile, velocity_profile, dt): + sum_squared_jerk = np.sum(pow(jerk_profile, 2), 0) + duration = len(velocity_profile) * dt + peak_velocity = np.max(abs(velocity_profile), 0) + return np.mean(sum_squared_jerk * pow(duration, 3) / pow(peak_velocity, 2)) \ No newline at end of file