modified according to comments

parent 1fd4a1e848
commit bdd51ba61f
@@ -16,7 +16,8 @@ from .mujoco.hopper_throw.hopper_throw import MAX_EPISODE_STEPS_HOPPERTHROW
from .mujoco.hopper_throw.hopper_throw_in_basket import MAX_EPISODE_STEPS_HOPPERTHROWINBASKET
from .mujoco.reacher.reacher import ReacherEnv, MAX_EPISODE_STEPS_REACHER
from .mujoco.walker_2d_jump.walker_2d_jump import MAX_EPISODE_STEPS_WALKERJUMP
from .mujoco.box_pushing.box_pushing_env import BoxPushingEnv, MAX_EPISODE_STEPS_BOX_PUSHING
from .mujoco.box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, \
                                                BoxPushingTemporalSpatialSparse, MAX_EPISODE_STEPS_BOX_PUSHING

ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS = {"DMP": [], "ProMP": [], "ProDMP": []}

@@ -232,10 +233,9 @@ register(
for reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"]:
    register(
        id='BoxPushing{}-v0'.format(reward_type),
        entry_point='fancy_gym.envs.mujoco:BoxPushingEnv',
        entry_point='fancy_gym.envs.mujoco:BoxPushing{}'.format(reward_type),
        max_episode_steps=MAX_EPISODE_STEPS_BOX_PUSHING//10,  # divided by frames skip
        kwargs={
            "reward_type": reward_type,
            "frame_skip": 10
        }
    )

@@ -7,4 +7,4 @@ from .hopper_throw.hopper_throw import HopperThrowEnv
from .hopper_throw.hopper_throw_in_basket import HopperThrowInBasketEnv
from .reacher.reacher import ReacherEnv
from .walker_2d_jump.walker_2d_jump import Walker2dJumpEnv
from .box_pushing.box_pushing_env import BoxPushingEnv
from .box_pushing.box_pushing_env import BoxPushingDense, BoxPushingTemporalSparse, BoxPushingTemporalSpatialSparse

@@ -5,7 +5,7 @@ from gym import utils, spaces
from gym.envs.mujoco import MujocoEnv
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rot_to_quat, get_quaternion_error, rotation_distance
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import q_max, q_min, q_dot_max, q_torque_max
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import BoxPushingReward
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import desired_rod_quat

import mujoco

@@ -13,7 +13,7 @@ MAX_EPISODE_STEPS_BOX_PUSHING = 1000

BOX_POS_BOUND = np.array([[0.3, -0.45, -0.01], [0.6, 0.45, -0.01]])

class BoxPushingEnv(MujocoEnv, utils.EzPickle):
class BoxPushingEnvBase(MujocoEnv, utils.EzPickle):
    """
    franka box pushing environment
    action space:
@@ -26,14 +26,18 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):
    3. time-spatial-depend sparse reward
    """

    def __init__(self, reward_type: str = "Dense", frame_skip: int = 10):
    def __init__(self, frame_skip: int = 10):
        utils.EzPickle.__init__(**locals())
        self._steps = 0
        self.init_qpos_box_pushing = np.array([0., 0., 0., -1.5, 0., 1.5, 0., 0., 0., 0.6, 0.45, 0.0, 1., 0., 0., 0.])
        self.init_qvel_box_pushing = np.zeros(15)
        self.frame_skip = frame_skip
        assert reward_type in ["Dense", "TemporalSparse", "TemporalSpatialSparse"], "unrecognized reward type"
        self.reward = BoxPushingReward(reward_type, q_max, q_min, q_dot_max)

        self._q_max = q_max
        self._q_min = q_min
        self._q_dot_max = q_dot_max
        self._desired_rod_quat = desired_rod_quat

        self._episode_energy = 0.
        MujocoEnv.__init__(self,
                           model_path=os.path.join(os.path.dirname(__file__), "assets", "box_pushing.xml"),
@@ -68,7 +72,7 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):
        qvel = self.data.qvel[:7].copy()

        if not unstable_simulation:
            reward = self.reward.get_reward(episode_end, box_pos, box_quat, target_pos, target_quat,
            reward = self._get_reward(episode_end, box_pos, box_quat, target_pos, target_quat,
                                      rod_tip_pos, rod_quat, qpos, qvel, action)
        else:
            reward = -50
@@ -122,6 +126,10 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):
        quat = rot_to_quat(theta, np.array([0, 0, 1]))
        return np.concatenate([pos, quat])

    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
                    rod_tip_pos, rod_quat, qpos, qvel, action):
        raise NotImplementedError

    def _get_obs(self):
        obs = np.concatenate([
            self.data.qpos[:7].copy(),  # joint position
@@ -136,6 +144,22 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):
        ])
        return obs

    def _joint_limit_violate_penalty(self, qpos, qvel, enable_pos_limit=False, enable_vel_limit=False):
        penalty = 0.
        p_coeff = 1.
        v_coeff = 1.
        # q_limit
        if enable_pos_limit:
            higher_error = qpos - self._q_max
            lower_error = self._q_min - qpos
            penalty -= p_coeff * (abs(np.sum(higher_error[qpos > self._q_max])) +
                                  abs(np.sum(lower_error[qpos < self._q_min])))
        # q_dot_limit
        if enable_vel_limit:
            q_dot_error = abs(qvel) - abs(self._q_dot_max)
            penalty -= v_coeff * abs(np.sum(q_dot_error[q_dot_error > 0.]))
        return penalty

    def get_body_jacp(self, name):
        id = mujoco.mj_name2id(self.model, 1, name)
        jacp = np.zeros((3, self.model.nv))
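A tiny numerical sketch (not from this commit, with made-up limits and joint values) of the masked sums used by _joint_limit_violate_penalty above:

import numpy as np

q_max = np.array([1.0, 1.0]); q_min = np.array([-1.0, -1.0]); q_dot_max = np.array([2.0, 2.0])
qpos = np.array([1.2, 0.0])   # first joint 0.2 above its position limit
qvel = np.array([0.0, -2.5])  # second joint 0.5 over its velocity limit

higher_error = qpos - q_max
lower_error = q_min - qpos
pos_penalty = -(abs(np.sum(higher_error[qpos > q_max])) + abs(np.sum(lower_error[qpos < q_min])))
q_dot_error = abs(qvel) - abs(q_dot_max)
vel_penalty = -abs(np.sum(q_dot_error[q_dot_error > 0.]))
print(pos_penalty, vel_penalty)  # approximately -0.2 -0.5
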
@@ -252,8 +276,89 @@ class BoxPushingEnv(MujocoEnv, utils.EzPickle):

        return q

class BoxPushingDense(BoxPushingEnvBase):
    def __init__(self, frame_skip: int = 10):
        super(BoxPushingDense, self).__init__(frame_skip=frame_skip)
    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
                    rod_tip_pos, rod_quat, qpos, qvel, action):
        joint_penalty = self._joint_limit_violate_penalty(qpos,
                                                          qvel,
                                                          enable_pos_limit=True,
                                                          enable_vel_limit=True)
        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
        box_goal_pos_dist_reward = -3.5 * np.linalg.norm(box_pos - target_pos)
        box_goal_rot_dist_reward = -rotation_distance(box_quat, target_quat) / np.pi
        energy_cost = -0.0005 * np.sum(np.square(action))

        reward = joint_penalty + tcp_box_dist_reward + \
                 box_goal_pos_dist_reward + box_goal_rot_dist_reward + energy_cost

        rod_inclined_angle = rotation_distance(rod_quat, self._desired_rod_quat)
        if rod_inclined_angle > np.pi / 4:
            reward -= rod_inclined_angle / (np.pi)

        return reward

class BoxPushingTemporalSparse(BoxPushingEnvBase):
    def __init__(self, frame_skip: int = 10):
        super(BoxPushingTemporalSparse, self).__init__(frame_skip=frame_skip)

    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
                    rod_tip_pos, rod_quat, qpos, qvel, action):
        reward = 0.
        joint_penalty = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
        energy_cost = -0.0005 * np.sum(np.square(action))
        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
        reward += joint_penalty + tcp_box_dist_reward + energy_cost
        rod_inclined_angle = rotation_distance(rod_quat, desired_rod_quat)

        if rod_inclined_angle > np.pi / 4:
            reward -= rod_inclined_angle / (np.pi)

        if not episode_end:
            return reward

        box_goal_dist = np.linalg.norm(box_pos - target_pos)

        box_goal_pos_dist_reward = -3.5 * box_goal_dist * 100
        box_goal_rot_dist_reward = -rotation_distance(box_quat, target_quat) / np.pi * 100

        reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward

        return reward

class BoxPushingTemporalSpatialSparse(BoxPushingEnvBase):

    def __init__(self, frame_skip: int = 10):
        super(BoxPushingTemporalSpatialSparse, self).__init__(frame_skip=frame_skip)

    def _get_reward(self, episode_end, box_pos, box_quat, target_pos, target_quat,
                    rod_tip_pos, rod_quat, qpos, qvel, action):
        reward = 0.
        joint_penalty = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
        energy_cost = -0.0005 * np.sum(np.square(action))
        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
        reward += joint_penalty + tcp_box_dist_reward + energy_cost
        rod_inclined_angle = rotation_distance(rod_quat, desired_rod_quat)

        if rod_inclined_angle > np.pi / 4:
            reward -= rod_inclined_angle / (np.pi)

        if not episode_end:
            return reward

        box_goal_dist = np.linalg.norm(box_pos - target_pos)

        if box_goal_dist < 0.1:
            reward += 300
            box_goal_pos_dist_reward = np.clip(- 3.5 * box_goal_dist * 100 * 3, -100, 0)
            box_goal_rot_dist_reward = np.clip(- rotation_distance(box_quat, target_quat)/np.pi * 100 * 1.5, -100, 0)
            reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward

        return reward

if __name__=="__main__":
    env = BoxPushingEnv(reward_type="dense", frame_skip=10)
    env = BoxPushingTemporalSpatialSparse(frame_skip=10)
    env.reset()
    for i in range(100):
        env.reset()

@@ -10,184 +10,44 @@ q_torque_max = np.array([90., 90., 90., 90., 12., 12., 12.])
#
desired_rod_quat = np.array([0.0, 1.0, 0.0, 0.0])

def skew(x):
    """
    Returns the skew-symmetric matrix of x
    param x: 3x1 vector
    """
    return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]])

def get_quaternion_error(curr_quat, des_quat):
    """
    Calculates the difference between the current quaternion and the desired quaternion.
    See Siciliano textbook page 140 Eq 3.91

    :param curr_quat: current quaternion
    :param des_quat: desired quaternion
    :return: difference between current quaternion and desired quaternion
    param curr_quat: current quaternion
    param des_quat: desired quaternion
    return: difference between current quaternion and desired quaternion
    """
    quatError = np.zeros((3, ))

    quatError[0] = (curr_quat[0] * des_quat[1] - des_quat[0] * curr_quat[1] -
                    curr_quat[3] * des_quat[2] + curr_quat[2] * des_quat[3])

    quatError[1] = (curr_quat[0] * des_quat[2] - des_quat[0] * curr_quat[2] +
                    curr_quat[3] * des_quat[1] - curr_quat[1] * des_quat[3])

    quatError[2] = (curr_quat[0] * des_quat[3] - des_quat[0] * curr_quat[3] -
                    curr_quat[2] * des_quat[1] + curr_quat[1] * des_quat[2])

    return quatError

    return curr_quat[0] * des_quat[1:] - des_quat[0] * curr_quat[1:] - skew(des_quat[1:]) @ curr_quat[1:]

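A quick sanity check (not from this commit, with made-up quaternions in [w, x, y, z] order) that the vectorized skew-based expression matches the previous component-wise formula:

import numpy as np
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import get_quaternion_error

curr_quat = np.array([0.9, 0.1, -0.3, 0.2])
des_quat = np.array([0.7, -0.2, 0.4, 0.5])
unrolled = np.array([
    curr_quat[0] * des_quat[1] - des_quat[0] * curr_quat[1] - curr_quat[3] * des_quat[2] + curr_quat[2] * des_quat[3],
    curr_quat[0] * des_quat[2] - des_quat[0] * curr_quat[2] + curr_quat[3] * des_quat[1] - curr_quat[1] * des_quat[3],
    curr_quat[0] * des_quat[3] - des_quat[0] * curr_quat[3] - curr_quat[2] * des_quat[1] + curr_quat[1] * des_quat[2],
])
print(np.allclose(get_quaternion_error(curr_quat, des_quat), unrolled))  # True
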
def rotation_distance(p: np.array, q: np.array):
    """
    p: quaternion
    q: quaternion
    Calculates the rotation angular between two quaternions
    param p: quaternion
    param q: quaternion
    theta: rotation angle between p and q (rad)
    """
    assert p.shape == q.shape, "p and q should be quaternion"
    product = p[0] * q[0] + p[1] * q[1] + p[2] * q[2] + p[3] * q[3]
    theta = 2 * np.arccos(abs(product))
    theta = 2 * np.arccos(abs(p @ q))
    return theta


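A small numerical check of rotation_distance (not from this commit), assuming [w, x, y, z] quaternion order: a 90 degree rotation about z relative to the identity quaternion should give pi/2.

import numpy as np
from fancy_gym.envs.mujoco.box_pushing.box_pushing_utils import rotation_distance

identity = np.array([1., 0., 0., 0.])
rot_z_90 = np.array([np.cos(np.pi / 4), 0., 0., np.sin(np.pi / 4)])  # 90 deg about the z axis
print(rotation_distance(identity, rot_z_90))  # ~1.5708 (= pi / 2)
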
def rot_to_quat(theta, axis):
    """
    Converts rotation angle along an axis to quaternion
    param theta: rotation angle (rad)
    param axis: rotation axis
    return: quaternion
    """
    quant = np.zeros(4)
    quant[0] = np.sin(theta / 2.)
    quant[1] = np.cos(theta / 2.) * axis[0]
    quant[2] = np.cos(theta / 2.) * axis[1]
    quant[3] = np.cos(theta / 2.) * axis[2]
    quant[1:] = np.cos(theta / 2.) * axis
    return quant



class RewardBase():
    def __init__(self, q_max, q_min, q_dot_max):
        self._reward = 0.
        self._done = False
        self._q_max = q_max
        self._q_min = q_min
        self._q_dot_max = q_dot_max

    def get_reward(self, episodic_end, box_pos, box_quat, target_pos, target_quat,
                   rod_tip_pos, rod_quat, qpos, qvel, action):
        raise NotImplementedError

    def _joint_limit_violate_penalty(self,
                                    qpos,
                                    qvel,
                                    enable_pos_limit=False,
                                    enable_vel_limit=False):
        penalty = 0.
        p_coeff = 1.
        v_coeff = 1.
        # q_limit
        if enable_pos_limit:
            higher_indice = np.where(qpos > self._q_max)
            lower_indice = np.where(qpos < self._q_min)
            higher_error = qpos - self._q_max
            lower_error = self._q_min - qpos
            penalty -= p_coeff * (abs(np.sum(higher_error[higher_indice])) +
                                  abs(np.sum(lower_error[lower_indice])))
        # q_dot_limit
        if enable_vel_limit:
            q_dot_error = abs(qvel) - abs(self._q_dot_max)
            q_dot_violate_idx = np.where(q_dot_error > 0.)
            penalty -= v_coeff * abs(np.sum(q_dot_error[q_dot_violate_idx]))
        return penalty




class DenseReward(RewardBase):
    def __init__(self, q_max, q_min, q_dot_max):
        super(DenseReward, self).__init__(q_max, q_min, q_dot_max)

    def get_reward(self, episodic_end, box_pos, box_quat, target_pos, target_quat,
                   rod_tip_pos, rod_quat, qpos, qvel, action):
        joint_penalty = self._joint_limit_violate_penalty(qpos,
                                                          qvel,
                                                          enable_pos_limit=True,
                                                          enable_vel_limit=True)
        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
        box_goal_pos_dist_reward = -3.5 * np.linalg.norm(box_pos - target_pos)
        box_goal_rot_dist_reward = -rotation_distance(box_quat, target_quat) / np.pi
        energy_cost = -0.0005 * np.sum(np.square(action))

        reward = joint_penalty + tcp_box_dist_reward + \
                 box_goal_pos_dist_reward + box_goal_rot_dist_reward + energy_cost

        rod_inclined_angle = rotation_distance(rod_quat, desired_rod_quat)
        if rod_inclined_angle > np.pi / 4:
            reward -= rod_inclined_angle / (np.pi)

        return reward




class TemporalSparseReward(RewardBase):
    def __init__(self, q_max, q_min, q_dot_max):
        super(TemporalSparseReward, self).__init__(q_max, q_min, q_dot_max)

    def get_reward(self, episodic_end, box_pos, box_quat, target_pos, target_quat,
                   rod_tip_pos, rod_quat, qpos, qvel, action):
        reward = 0.
        joint_penalty = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
        energy_cost = -0.0005 * np.sum(np.square(action))
        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
        reward += joint_penalty + tcp_box_dist_reward + energy_cost
        rod_inclined_angle = rotation_distance(rod_quat, desired_rod_quat)

        if rod_inclined_angle > np.pi / 4:
            reward -= rod_inclined_angle / (np.pi)

        if not episodic_end:
            return reward

        box_goal_dist = np.linalg.norm(box_pos - target_pos)

        box_goal_pos_dist_reward = -3.5 * box_goal_dist * 100
        box_goal_rot_dist_reward = -rotation_distance(box_quat, target_quat) / np.pi * 100

        reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward

        return reward




class TemporalSpatialSparseReward(RewardBase):
    def __init__(self, q_max, q_min, q_dot_max):
        super(TemporalSpatialSparseReward, self).__init__(q_max, q_min, q_dot_max)

    def get_reward(self, episodic_end, box_pos, box_quat, target_pos, target_quat,
                   rod_tip_pos, rod_quat, qpos, qvel, action):
        reward = 0.
        joint_penalty = self._joint_limit_violate_penalty(qpos, qvel, enable_pos_limit=True, enable_vel_limit=True)
        energy_cost = -0.0005 * np.sum(np.square(action))
        tcp_box_dist_reward = -2 * np.clip(np.linalg.norm(box_pos - rod_tip_pos), 0.05, 100)
        reward += joint_penalty + tcp_box_dist_reward + energy_cost
        rod_inclined_angle = rotation_distance(rod_quat, desired_rod_quat)

        if rod_inclined_angle > np.pi / 4:
            reward -= rod_inclined_angle / (np.pi)

        if not episodic_end:
            return reward

        box_goal_dist = np.linalg.norm(box_pos - target_pos)

        if box_goal_dist < 0.1:
            reward += 300
            box_goal_pos_dist_reward = np.clip(- 3.5 * box_goal_dist * 100 * 3, -100, 0)
            box_goal_rot_dist_reward = np.clip(- rotation_distance(box_quat, target_quat)/np.pi * 100 * 1.5, -100, 0)
            reward += box_goal_pos_dist_reward + box_goal_rot_dist_reward

        return reward



def BoxPushingReward(reward_type, q_max, q_min, q_dot_max):
    if reward_type == 'Dense':
        return DenseReward(q_max, q_min, q_dot_max)
    elif reward_type == 'TemporalSparse':
        return TemporalSparseReward(q_max, q_min, q_dot_max)
    elif reward_type == 'TemporalSpatialSparse':
        return TemporalSpatialSparseReward(q_max, q_min, q_dot_max)
    else:
        raise NotImplementedError
@@ -157,17 +157,17 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
if __name__ == '__main__':
    render = True
    # DMP
    # example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)
    example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render)

    # ProMP
    # example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
    # example_mp("BoxPushingDenseProMP-v0", seed=10, iterations=50, render=render)
    example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render)
    example_mp("BoxPushingTemporalSparseProMP-v0", seed=10, iterations=1, render=render)

    # ProDMP
    example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=50, render=render)
    example_mp("BoxPushingDenseProDMP-v0", seed=10, iterations=1, render=render)

    # Altered basis functions
    # obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)
    obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=1, render=render)

    # Custom MP
    # example_fully_custom_mp(seed=10, iterations=1, render=render)
    example_fully_custom_mp(seed=10, iterations=1, render=render)