Merge branch 'master' into dev_new_mp_api
Conflicts: alr_envs/alr/__init__.py
This commit is contained in:
		
						commit
						a26b9f463b
					
				| @ -538,16 +538,68 @@ for _v in _versions: | |||||||
|             "wrappers": [classic_control.hole_reacher.MPWrapper], |             "wrappers": [classic_control.hole_reacher.MPWrapper], | ||||||
|             "mp_kwargs": { |             "mp_kwargs": { | ||||||
|                 "num_dof": 5, |                 "num_dof": 5, | ||||||
|                 "num_basis": 5, |                 "num_basis": 3, | ||||||
|                 "duration": 2, |                 "duration": 2, | ||||||
|                 "policy_type": "velocity", |                 "policy_type": "velocity", | ||||||
|                 "weights_scale": 0.1, |                 "weights_scale": 5, | ||||||
|                 "zero_start": True |                 "zero_start": True | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     ) |     ) | ||||||
|     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) |     ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id) | ||||||
| 
 | 
 | ||||||
## ALRReacher
# Register one DMP and one ProMP environment for each ALR reacher task.
_versions = [
    "ALRReacher-v0",
    "ALRLongReacher-v0",
    "ALRReacherSparse-v0",
    "ALRLongReacherSparse-v0",
]
for _v in _versions:
    _name = _v.split("-")

    # DMP variant, e.g. "ALRReacher-v0" -> "ALRReacherDMP-v0".
    _env_id = "{}DMP-{}".format(_name[0], _name[1])
    register(
        id=_env_id,
        entry_point="alr_envs.utils.make_env_helpers:make_dmp_env_helper",
        # max_episode_steps=1,
        kwargs=dict(
            name="alr_envs:" + _v,
            wrappers=[mujoco.reacher.MPWrapper],
            mp_kwargs={
                # Task ids containing "long" get 7 DoF, all others 5.
                "num_dof": 7 if "long" in _v.lower() else 5,
                "num_basis": 2,
                "duration": 4,
                "alpha_phase": 2,
                "learn_goal": True,
                "policy_type": "motor",
                "weights_scale": 5,
                "policy_kwargs": {
                    "p_gains": 1,
                    "d_gains": 0.1,
                },
            },
        ),
    )
    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"].append(_env_id)

    # ProMP variant, e.g. "ALRReacher-v0" -> "ALRReacherProMP-v0".
    _env_id = "{}ProMP-{}".format(_name[0], _name[1])
    register(
        id=_env_id,
        entry_point="alr_envs.utils.make_env_helpers:make_promp_env_helper",
        kwargs=dict(
            name="alr_envs:" + _v,
            wrappers=[mujoco.reacher.MPWrapper],
            mp_kwargs={
                # Task ids containing "long" get 7 DoF, all others 5.
                "num_dof": 7 if "long" in _v.lower() else 5,
                "num_basis": 2,
                "duration": 4,
                "policy_type": "motor",
                "weights_scale": 5,
                "zero_start": True,
                "policy_kwargs": {
                    "p_gains": 1,
                    "d_gains": 0.1,
                },
            },
        ),
    )
    ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|  | 
 | ||||||
| # ## Beerpong | # ## Beerpong | ||||||
| # _versions = ["v0", "v1"] | # _versions = ["v0", "v1"] | ||||||
| # for _v in _versions: | # for _v in _versions: | ||||||
|  | |||||||
| @ -45,6 +45,9 @@ class HoleReacherEnv(BaseReacherDirectEnv): | |||||||
|         elif rew_fct == "vel_acc": |         elif rew_fct == "vel_acc": | ||||||
|             from alr_envs.alr.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward |             from alr_envs.alr.classic_control.hole_reacher.hr_dist_vel_acc_reward import HolereacherReward | ||||||
|             self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty) |             self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision, collision_penalty) | ||||||
|  |         elif rew_fct == "unbounded": | ||||||
|  |             from alr_envs.alr.classic_control.hole_reacher.hr_unbounded_reward import HolereacherReward | ||||||
|  |             self.reward_function = HolereacherReward(allow_self_collision, allow_wall_collision) | ||||||
|         else: |         else: | ||||||
|             raise ValueError("Unknown reward function {}".format(rew_fct)) |             raise ValueError("Unknown reward function {}".format(rew_fct)) | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -0,0 +1,60 @@ | |||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class HolereacherReward:
    """Hole-reacher reward paid out on the final step or on collision.

    Every step contributes a small penalty on the squared accelerations.
    The distance term is non-zero only on the step ``FINAL_STEP`` or on a
    step where a (non-allowed) collision is detected. In the collision-free
    case with a non-positive current end-effector y-coordinate the distance
    term is ``1 - y`` of the recorded end-effector position, i.e. it has no
    upper bound.
    """

    # Step at which the end-effector position used for the final distance
    # term is recorded.
    SNAPSHOT_STEP = 180
    # Step on which the distance term is paid out when no collision occurred.
    FINAL_STEP = 199

    def __init__(self, allow_self_collision, allow_wall_collision):
        """Create the reward.

        Args:
            allow_self_collision: If true, self collisions are not checked.
            allow_wall_collision: If true, wall collisions are not checked.
        """
        # collision
        self.allow_self_collision = allow_self_collision
        self.allow_wall_collision = allow_wall_collision
        self._is_collided = False

        # Recorded end-effector position; ``None`` until a snapshot is taken
        # in the current episode.
        self.end_eff_pos = None

        # Weights for (distance term, acceleration cost).
        self.reward_factors = np.array((1, -5e-6))

    def reset(self):
        """Clear the per-episode state (collision flag and recorded position)."""
        self._is_collided = False
        # Do not carry the previous episode's snapshot into the next one.
        self.end_eff_pos = None

    def get_reward(self, env):
        """Compute the reward for the current step of ``env``.

        Args:
            env: Environment providing ``_steps``, ``end_effector``, ``_goal``,
                ``current_pos``, ``_acc`` and, unless the respective collision
                is allowed, ``_check_self_collision()`` /
                ``check_wall_collision()``.

        Returns:
            Tuple ``(reward, info)`` where ``info`` holds the keys
            ``is_success``, ``is_collided``, ``end_effector`` and ``joints``.
        """
        dist_reward = 0
        success = False

        self_collision = False
        wall_collision = False

        if not self.allow_self_collision:
            self_collision = env._check_self_collision()

        if not self.allow_wall_collision:
            wall_collision = env.check_wall_collision()

        self._is_collided = self_collision or wall_collision

        if env._steps == self.SNAPSHOT_STEP or self._is_collided:
            self.end_eff_pos = np.copy(env.end_effector)

        if env._steps == self.FINAL_STEP or self._is_collided:
            # return reward only in last time step
            # Episode also terminates when colliding, hence return reward
            if self.end_eff_pos is None:
                # The snapshot step was never evaluated in this episode;
                # fall back to the current position instead of failing.
                self.end_eff_pos = np.copy(env.end_effector)

            dist = np.linalg.norm(self.end_eff_pos - env._goal)

            if self._is_collided:
                dist_reward = 0.25 * np.exp(- dist)
            else:
                # NOTE(review): the branch is chosen from the *current*
                # end-effector, the value from the *recorded* one - confirm
                # that this mix is intended.
                if env.end_effector[1] > 0:
                    dist_reward = np.exp(- dist)
                else:
                    dist_reward = 1 - self.end_eff_pos[1]

            success = not self._is_collided

        info = {"is_success": success,
                "is_collided": self._is_collided,
                "end_effector": np.copy(env.end_effector),
                "joints": np.copy(env.current_pos)}

        acc_cost = np.sum(env._acc ** 2)

        reward_features = np.array((dist_reward, acc_cost))
        reward = np.dot(reward_features, self.reward_factors)

        return reward, info
| @ -0,0 +1 @@ | |||||||
|  | from .mp_wrapper import MPWrapper | ||||||
| @ -1,54 +1,57 @@ | |||||||
| <mujoco model="reacher"> | <mujoco model="reacher"> | ||||||
| 	<compiler angle="radian" inertiafromgeom="true"/> |    <compiler angle="radian" inertiafromgeom="true"/> | ||||||
| 	<default> |    <default> | ||||||
| 		<joint armature="1" damping="1" limited="true"/> |       <joint armature="1" damping="1" limited="true"/> | ||||||
| 		<geom contype="0" friction="1 0.1 0.1" rgba="0.7 0.7 0 1"/> |       <geom contype="0" friction="1 0.1 0.1" rgba="0.7 0.7 0 1"/> | ||||||
| 	</default> |    </default> | ||||||
| 	<option gravity="0 0 -9.81" integrator="RK4" timestep="0.01"/> |    <option gravity="0 0 -9.81" integrator="RK4" timestep="0.01"/> | ||||||
| 	<worldbody> |    <worldbody> | ||||||
| 		<!-- Arena --> |       <!-- Arena --> | ||||||
| 		<geom conaffinity="0" contype="0" name="ground" pos="0 0 0" rgba="0.9 0.9 0.9 1" size="1 1 10" type="plane"/> |       <geom conaffinity="0" contype="0" name="ground" pos="0 0 0" rgba="0.9 0.9 0.9 1" size="1 1 10" type="plane"/> | ||||||
| 		<geom conaffinity="0" fromto="-.6 -.6 .01 .6 -.6 .01" name="sideS" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> |       <geom conaffinity="0" fromto="-.6 -.6 .01 .6 -.6 .01" name="sideS" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> | ||||||
| 		<geom conaffinity="0" fromto=" .6 -.6 .01 .6  .6 .01" name="sideE" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> |       <geom conaffinity="0" fromto=" .6 -.6 .01 .6  .6 .01" name="sideE" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> | ||||||
| 		<geom conaffinity="0" fromto="-.6  .6 .01 .6  .6 .01" name="sideN" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> |       <geom conaffinity="0" fromto="-.6  .6 .01 .6  .6 .01" name="sideN" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> | ||||||
| 		<geom conaffinity="0" fromto="-.6 -.6 .01 -.6 .6 .01" name="sideW" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> |       <geom conaffinity="0" fromto="-.6 -.6 .01 -.6 .6 .01" name="sideW" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/> | ||||||
| 		<!-- Arm --> |       <!-- Arm --> | ||||||
| 		<geom conaffinity="0" contype="0" fromto="0 0 0 0 0 0.02" name="root" rgba="0.9 0.4 0.6 1" size=".011" type="cylinder"/> |       <geom conaffinity="0" contype="0" fromto="0 0 0 0 0 0.02" name="root" rgba="0.9 0.4 0.6 1" size=".011" type="cylinder"/> | ||||||
| 		<body name="body0" pos="0 0 .01"> |       <body name="body0" pos="0 0 .01"> | ||||||
| 			<geom fromto="0 0 0 0.1 0 0" name="link0" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> |          <geom fromto="0 0 0 0.1 0 0" name="link0" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> | ||||||
| 			<joint axis="0 0 1" limited="false" name="joint0" pos="0 0 0" type="hinge"/> |          <joint axis="0 0 1" limited="false" name="joint0" pos="0 0 0" type="hinge"/> | ||||||
| 			<body name="body1" pos="0.1 0 0"> |          <body name="body1" pos="0.1 0 0"> | ||||||
| 				<joint axis="0 0 1" limited="false" name="joint1" pos="0 0 0" type="hinge"/> |             <joint axis="0 0 1" limited="false" name="joint1" pos="0 0 0" type="hinge"/> | ||||||
| 				<geom fromto="0 0 0 0.1 0 0" name="link1" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> |             <geom fromto="0 0 0 0.1 0 0" name="link1" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> | ||||||
| 				<body name="body2" pos="0.1 0 0"> |             <body name="body2" pos="0.1 0 0"> | ||||||
| 					<joint axis="0 0 1" limited="false" name="joint2" pos="0 0 0" type="hinge"/> |                <joint axis="0 0 1" limited="false" name="joint2" pos="0 0 0" type="hinge"/> | ||||||
| 					<geom fromto="0 0 0 0.1 0 0" name="link2" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> |                <geom fromto="0 0 0 0.1 0 0" name="link2" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> | ||||||
| 					<body name="body3" pos="0.1 0 0"> |                <body name="body3" pos="0.1 0 0"> | ||||||
| 						<joint axis="0 0 1" limited="false" name="joint3" pos="0 0 0" type="hinge"/> |                   <joint axis="0 0 1" limited="false" name="joint3" pos="0 0 0" type="hinge"/> | ||||||
| 						<geom fromto="0 0 0 0.1 0 0" name="link3" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> |                   <geom fromto="0 0 0 0.1 0 0" name="link3" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> | ||||||
| 						<body name="body4" pos="0.1 0 0"> |                   <body name="body4" pos="0.1 0 0"> | ||||||
| 							<joint axis="0 0 1" limited="true" name="joint4" pos="0 0 0" range="-3.0 3.0" type="hinge"/> |                      <joint axis="0 0 1" limited="true" name="joint4" pos="0 0 0" range="-3.0 3.0" type="hinge"/> | ||||||
| 							<geom fromto="0 0 0 0.1 0 0" name="link4" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> |                      <geom fromto="0 0 0 0.1 0 0" name="link4" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/> | ||||||
| 							<body name="fingertip" pos="0.11 0 0"> |                      <body name="fingertip" pos="0.11 0 0"> | ||||||
| 								<geom contype="0" name="fingertip" pos="0 0 0" rgba="0.0 0.8 0.6 1" size=".01" type="sphere"/> |                         <geom contype="0" name="fingertip" pos="0 0 0" rgba="0.0 0.8 0.6 1" size=".01" type="sphere"/> | ||||||
| 							</body> |                      </body> | ||||||
| 						</body> |                   </body> | ||||||
| 					</body> |                </body> | ||||||
| 				</body> |             </body> | ||||||
| 			</body> |          </body> | ||||||
| 		</body> |       </body> | ||||||
| 		<!-- Target --> |       <!-- Target --> | ||||||
| 		<body name="target" pos=".1 -.1 .01"> |       <body name="target" pos=".1 -.1 .01"> | ||||||
| 			<joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.27 .27" ref=".1" stiffness="0" type="slide"/> | <!--         <joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.27 .27" ref=".1" stiffness="0" type="slide"/>--> | ||||||
| 			<joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.27 .27" ref="-.1" stiffness="0" type="slide"/> | <!--         <joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.27 .27" ref="-.1" stiffness="0" type="slide"/>--> | ||||||
| 			<geom conaffinity="0" contype="0" name="target" pos="0 0 0" rgba="0.9 0.2 0.2 1" size=".009" type="sphere"/> |             <joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.7 .7" ref=".1" stiffness="0" type="slide"/> | ||||||
| 		</body> |          <joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.7 .7" ref="-.1" stiffness="0" type="slide"/> | ||||||
| 	</worldbody> |          <geom conaffinity="0" contype="0" name="target" pos="0 0 0" rgba="0.9 0.2 0.2 1" size=".009" type="sphere"/> | ||||||
| 	<actuator> |       </body> | ||||||
| 		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint0"/> |         <site name="context_space" pos="0 0.0 0.0" euler="0 0 0" size="0.5 0.5 0.01" rgba="0 1 0 0.1" type="box"/> | ||||||
| 		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint1"/> |    </worldbody> | ||||||
| 		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint2"/> |    <actuator> | ||||||
| 		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint3"/> |       <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint0"/> | ||||||
| 		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint4"/> |       <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint1"/> | ||||||
| 	</actuator> |       <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint2"/> | ||||||
|  |       <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint3"/> | ||||||
|  |       <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint4"/> | ||||||
|  |    </actuator> | ||||||
| </mujoco> | </mujoco> | ||||||
							
								
								
									
										43
									
								
								alr_envs/alr/mujoco/reacher/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								alr_envs/alr/mujoco/reacher/mp_wrapper.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,43 @@ | |||||||
|  | from typing import Union | ||||||
|  | 
 | ||||||
|  | import numpy as np | ||||||
|  | from mp_env_api import MPEnvWrapper | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
class MPWrapper(MPEnvWrapper):
    """Motion-primitive wrapper for the reacher environment."""

    @property
    def active_obs(self):
        """Return the mask selecting which observation entries are active.

        Only the two goal-position entries are enabled. Switching the other
        flags (all but the step entry) to ``True`` would enable the full
        observation instead.
        """
        link_count = self.n_links
        segments = (
            [False] * link_count,  # cos
            [False] * link_count,  # sin
            [True] * 2,  # goal position
            [False] * link_count,  # angular velocity
            [False] * 3,  # goal distance
            # self.get_body_com("target"),  # only return target to make problem harder
            [False],  # step
        )
        return np.concatenate(segments)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray]:
        """Return the first ``n_links`` entries of the simulator's ``qvel``."""
        flat_velocities = self.sim.data.qvel.flat
        return flat_velocities[:self.n_links]

    @property
    def current_pos(self) -> Union[float, int, np.ndarray]:
        """Return the first ``n_links`` entries of the simulator's ``qpos``."""
        flat_positions = self.sim.data.qpos.flat
        return flat_positions[:self.n_links]

    @property
    def dt(self) -> Union[float, int]:
        """Return the ``dt`` of the wrapped environment."""
        return self.env.dt
| @ -2,70 +2,98 @@ import numpy as np | |||||||
| from matplotlib import pyplot as plt | from matplotlib import pyplot as plt | ||||||
| 
 | 
 | ||||||
| from alr_envs import dmc, meta | from alr_envs import dmc, meta | ||||||
|  | from alr_envs.alr import mujoco | ||||||
| from alr_envs.utils.make_env_helpers import make_promp_env | from alr_envs.utils.make_env_helpers import make_promp_env | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
def visualize(env):
    """Plot the basis functions of ``env.mp`` over ``env.t`` and show the figure."""
    time_grid = env.t
    basis_values = env.mp.basis_generator.basis(time_grid)
    plt.plot(time_grid, basis_values)
    plt.show()
|  | 
 | ||||||
|  | 
 | ||||||
| # This might work for some environments, however, please verify either way the correct trajectory information | # This might work for some environments, however, please verify either way the correct trajectory information | ||||||
| # for your environment are extracted below | # for your environment are extracted below | ||||||
| SEED = 10 | SEED = 1 | ||||||
| env_id = "ball_in_cup-catch" | # env_id = "ball_in_cup-catch" | ||||||
| wrappers = [dmc.ball_in_cup.MPWrapper] | env_id = "ALRReacherSparse-v0" | ||||||
|  | env_id = "button-press-v2" | ||||||
|  | wrappers = [mujoco.reacher.MPWrapper] | ||||||
|  | wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper] | ||||||
| 
 | 
 | ||||||
| mp_kwargs = { | mp_kwargs = { | ||||||
|     "num_dof": 2, |     "num_dof": 4, | ||||||
|     "num_basis": 10, |     "num_basis": 5, | ||||||
|     "duration": 2, |     "duration": 6.25, | ||||||
|     "width": 0.025, |     "policy_type": "metaworld", | ||||||
|     "policy_type": "motor", |     "weights_scale": 10, | ||||||
|     "weights_scale": 1, |  | ||||||
|     "zero_start": True, |     "zero_start": True, | ||||||
|     "policy_kwargs": { |     # "policy_kwargs": { | ||||||
|         "p_gains": 1, |     #     "p_gains": 1, | ||||||
|         "d_gains": 1 |     #     "d_gains": 0.1 | ||||||
|     } |     # } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| kwargs = dict(time_limit=2, episode_length=100) | # kwargs = dict(time_limit=4, episode_length=200) | ||||||
|  | kwargs = {} | ||||||
| 
 | 
 | ||||||
| env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, | env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs) | ||||||
|                      **kwargs) | env.action_space.seed(SEED) | ||||||
| 
 | 
 | ||||||
| # Plot difference between real trajectory and target MP trajectory | # Plot difference between real trajectory and target MP trajectory | ||||||
| env.reset() | env.reset() | ||||||
| pos, vel = env.mp_rollout(env.action_space.sample()) | w = env.action_space.sample()  # N(0,1) | ||||||
|  | visualize(env) | ||||||
|  | pos, vel = env.mp_rollout(w) | ||||||
| 
 | 
 | ||||||
| base_shape = env.full_action_space.shape | base_shape = env.full_action_space.shape | ||||||
| actual_pos = np.zeros((len(pos), *base_shape)) | actual_pos = np.zeros((len(pos), *base_shape)) | ||||||
| actual_vel = np.zeros((len(pos), *base_shape)) | actual_vel = np.zeros((len(pos), *base_shape)) | ||||||
| act = np.zeros((len(pos), *base_shape)) | act = np.zeros((len(pos), *base_shape)) | ||||||
| 
 | 
 | ||||||
|  | plt.ion() | ||||||
|  | fig = plt.figure() | ||||||
|  | ax = fig.add_subplot(1, 1, 1) | ||||||
|  | img = ax.imshow(env.env.render("rgb_array")) | ||||||
|  | fig.show() | ||||||
|  | 
 | ||||||
| for t, pos_vel in enumerate(zip(pos, vel)): | for t, pos_vel in enumerate(zip(pos, vel)): | ||||||
|     actions = env.policy.get_action(pos_vel[0], pos_vel[1],, self.current_vel, self.current_pos |     actions = env.policy.get_action(pos_vel[0], pos_vel[1],, self.current_vel, self.current_pos | ||||||
|     actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high) |     actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high) | ||||||
|     _, _, _, _ = env.env.step(actions) |     _, _, _, _ = env.env.step(actions) | ||||||
|  |     if t % 15 == 0: | ||||||
|  |         img.set_data(env.env.render("rgb_array")) | ||||||
|  |         fig.canvas.draw() | ||||||
|  |         fig.canvas.flush_events() | ||||||
|     act[t, :] = actions |     act[t, :] = actions | ||||||
|     # TODO verify for your environment |     # TODO verify for your environment | ||||||
|     actual_pos[t, :] = env.current_pos |     actual_pos[t, :] = env.current_pos | ||||||
|     actual_vel[t, :] = env.current_vel |     actual_vel[t, :] = 0  # env.current_vel | ||||||
| 
 | 
 | ||||||
| plt.figure(figsize=(15, 5)) | plt.figure(figsize=(15, 5)) | ||||||
| 
 | 
 | ||||||
| plt.subplot(131) | plt.subplot(131) | ||||||
| plt.title("Position") | plt.title("Position") | ||||||
| plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))]) | p1 = plt.plot(actual_pos, c='C0', label="true") | ||||||
| # plt.plot(actual_pos_ball, label="true pos ball") | # plt.plot(actual_pos_ball, label="true pos ball") | ||||||
| plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))]) | p2 = plt.plot(pos, c='C1', label="MP")  # , label=["MP" if i == 0 else None for i in range(np.prod(base_shape))]) | ||||||
| plt.xlabel("Episode steps") | plt.xlabel("Episode steps") | ||||||
| plt.legend() | # plt.legend() | ||||||
|  | handles, labels = plt.gca().get_legend_handles_labels() | ||||||
|  | from collections import OrderedDict | ||||||
|  | 
 | ||||||
|  | by_label = OrderedDict(zip(labels, handles)) | ||||||
|  | plt.legend(by_label.values(), by_label.keys()) | ||||||
| 
 | 
 | ||||||
| plt.subplot(132) | plt.subplot(132) | ||||||
| plt.title("Velocity") | plt.title("Velocity") | ||||||
| plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))]) | plt.plot(actual_vel, c='C0', label="true") | ||||||
| plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))]) | plt.plot(vel, c='C1', label="MP") | ||||||
| plt.xlabel("Episode steps") | plt.xlabel("Episode steps") | ||||||
| plt.legend() |  | ||||||
| 
 | 
 | ||||||
| plt.subplot(133) | plt.subplot(133) | ||||||
| plt.title("Actions") | plt.title(f"Actions {np.std(act, axis=0)}") | ||||||
| plt.plot(act, c="C0"),  # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))]) | plt.plot(act, c="C0"),  # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))]) | ||||||
| plt.xlabel("Episode steps") | plt.xlabel("Episode steps") | ||||||
| # plt.legend() | # plt.legend() | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user