reacher adjustments
This commit is contained in:
		
							parent
							
								
									c763f89d60
								
							
						
					
					
						commit
						d313795cec
					
				@ -357,7 +357,7 @@ for _v in _versions:
 | 
			
		||||
                "num_basis": 5,
 | 
			
		||||
                "duration": 2,
 | 
			
		||||
                "policy_type": "velocity",
 | 
			
		||||
                "weights_scale": 0.1,
 | 
			
		||||
                "weights_scale": 5,
 | 
			
		||||
                "zero_start": True
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
@ -378,12 +378,12 @@ for _v in _versions:
 | 
			
		||||
            "wrappers": [mujoco.reacher.MPWrapper],
 | 
			
		||||
            "mp_kwargs": {
 | 
			
		||||
                "num_dof": 5 if "long" not in _v.lower() else 7,
 | 
			
		||||
                "num_basis": 5,
 | 
			
		||||
                "num_basis": 2,
 | 
			
		||||
                "duration": 4,
 | 
			
		||||
                "alpha_phase": 2,
 | 
			
		||||
                "learn_goal": True,
 | 
			
		||||
                "policy_type": "motor",
 | 
			
		||||
                "weights_scale": 1,
 | 
			
		||||
                "weights_scale": 5,
 | 
			
		||||
                "policy_kwargs": {
 | 
			
		||||
                    "p_gains": 1,
 | 
			
		||||
                    "d_gains": 0.1
 | 
			
		||||
@ -402,10 +402,10 @@ for _v in _versions:
 | 
			
		||||
            "wrappers": [mujoco.reacher.MPWrapper],
 | 
			
		||||
            "mp_kwargs": {
 | 
			
		||||
                "num_dof": 5 if "long" not in _v.lower() else 7,
 | 
			
		||||
                "num_basis": 5,
 | 
			
		||||
                "num_basis": 1,
 | 
			
		||||
                "duration": 4,
 | 
			
		||||
                "policy_type": "motor",
 | 
			
		||||
                "weights_scale": 1,
 | 
			
		||||
                "weights_scale": 5,
 | 
			
		||||
                "zero_start": True,
 | 
			
		||||
                "policy_kwargs": {
 | 
			
		||||
                    "p_gains": 1,
 | 
			
		||||
 | 
			
		||||
@ -44,9 +44,9 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
            reward_dist -= self.reward_weight * np.linalg.norm(vec)
 | 
			
		||||
            if self.steps_before_reward > 0:
 | 
			
		||||
                # avoid giving this penalty for normal step based case
 | 
			
		||||
                angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
 | 
			
		||||
                # angular_vel -= np.square(self.sim.data.qvel.flat[:self.n_links]).sum()
 | 
			
		||||
        reward_ctrl = - np.square(a).sum()
 | 
			
		||||
                # angular_vel -= 10 * np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
 | 
			
		||||
                angular_vel -= 10 * np.square(self.sim.data.qvel.flat[:self.n_links]).sum()
 | 
			
		||||
        reward_ctrl = - 10 * np.square(a).sum()
 | 
			
		||||
 | 
			
		||||
        if self.balance:
 | 
			
		||||
            reward_balance -= self.balance_weight * np.abs(
 | 
			
		||||
@ -64,6 +64,35 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
    def viewer_setup(self):
 | 
			
		||||
        self.viewer.cam.trackbodyid = 0
 | 
			
		||||
 | 
			
		||||
    # def reset_model(self):
 | 
			
		||||
    #     qpos = self.init_qpos
 | 
			
		||||
    #     if not hasattr(self, "goal"):
 | 
			
		||||
    #         self.goal = np.array([-0.25, 0.25])
 | 
			
		||||
    #         # self.goal = self.init_qpos.copy()[:2] + 0.05
 | 
			
		||||
    #     qpos[-2:] = self.goal
 | 
			
		||||
    #     qvel = self.init_qvel
 | 
			
		||||
    #     qvel[-2:] = 0
 | 
			
		||||
    #     self.set_state(qpos, qvel)
 | 
			
		||||
    #     self._steps = 0
 | 
			
		||||
    #
 | 
			
		||||
    #     return self._get_obs()
 | 
			
		||||
 | 
			
		||||
    def reset_model(self):
 | 
			
		||||
        qpos = self.init_qpos.copy()
 | 
			
		||||
        while True:
 | 
			
		||||
            self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
 | 
			
		||||
            # self.goal = self.np_random.uniform(low=0, high=self.n_links / 10, size=2)
 | 
			
		||||
            # self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=[0, self.n_links / 10], size=2)
 | 
			
		||||
            if np.linalg.norm(self.goal) < self.n_links / 10:
 | 
			
		||||
                break
 | 
			
		||||
        qpos[-2:] = self.goal
 | 
			
		||||
        qvel = self.init_qvel.copy()
 | 
			
		||||
        qvel[-2:] = 0
 | 
			
		||||
        self.set_state(qpos, qvel)
 | 
			
		||||
        self._steps = 0
 | 
			
		||||
 | 
			
		||||
        return self._get_obs()
 | 
			
		||||
 | 
			
		||||
    # def reset_model(self):
 | 
			
		||||
    #     qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
 | 
			
		||||
    #     while True:
 | 
			
		||||
@ -78,30 +107,15 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
 | 
			
		||||
    #
 | 
			
		||||
    #     return self._get_obs()
 | 
			
		||||
 | 
			
		||||
    def reset_model(self):
 | 
			
		||||
        qpos = self.init_qpos
 | 
			
		||||
        if not hasattr(self, "goal"):
 | 
			
		||||
            while True:
 | 
			
		||||
                self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
 | 
			
		||||
                if np.linalg.norm(self.goal) < self.n_links / 10:
 | 
			
		||||
                    break
 | 
			
		||||
        qpos[-2:] = self.goal
 | 
			
		||||
        qvel = self.init_qvel
 | 
			
		||||
        qvel[-2:] = 0
 | 
			
		||||
        self.set_state(qpos, qvel)
 | 
			
		||||
        self._steps = 0
 | 
			
		||||
 | 
			
		||||
        return self._get_obs()
 | 
			
		||||
 | 
			
		||||
    def _get_obs(self):
 | 
			
		||||
        theta = self.sim.data.qpos.flat[:self.n_links]
 | 
			
		||||
        target = self.get_body_com("target")
 | 
			
		||||
        return np.concatenate([
 | 
			
		||||
            np.cos(theta),
 | 
			
		||||
            np.sin(theta),
 | 
			
		||||
            self.sim.data.qpos.flat[self.n_links:],  # this is goal position
 | 
			
		||||
            self.sim.data.qvel.flat[:self.n_links],  # this is angular velocity
 | 
			
		||||
            self.get_body_com("fingertip") - self.get_body_com("target"),
 | 
			
		||||
            # self.get_body_com("target"),  # only return target to make problem harder
 | 
			
		||||
            target[:2],  # x-y of goal position
 | 
			
		||||
            self.sim.data.qvel.flat[:self.n_links],  # angular velocity
 | 
			
		||||
            self.get_body_com("fingertip") - target,  # goal distance
 | 
			
		||||
            [self._steps],
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,54 +1,57 @@
 | 
			
		||||
<mujoco model="reacher">
 | 
			
		||||
	<compiler angle="radian" inertiafromgeom="true"/>
 | 
			
		||||
	<default>
 | 
			
		||||
		<joint armature="1" damping="1" limited="true"/>
 | 
			
		||||
		<geom contype="0" friction="1 0.1 0.1" rgba="0.7 0.7 0 1"/>
 | 
			
		||||
	</default>
 | 
			
		||||
	<option gravity="0 0 -9.81" integrator="RK4" timestep="0.01"/>
 | 
			
		||||
	<worldbody>
 | 
			
		||||
		<!-- Arena -->
 | 
			
		||||
		<geom conaffinity="0" contype="0" name="ground" pos="0 0 0" rgba="0.9 0.9 0.9 1" size="1 1 10" type="plane"/>
 | 
			
		||||
		<geom conaffinity="0" fromto="-.6 -.6 .01 .6 -.6 .01" name="sideS" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
		<geom conaffinity="0" fromto=" .6 -.6 .01 .6  .6 .01" name="sideE" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
		<geom conaffinity="0" fromto="-.6  .6 .01 .6  .6 .01" name="sideN" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
		<geom conaffinity="0" fromto="-.6 -.6 .01 -.6 .6 .01" name="sideW" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
		<!-- Arm -->
 | 
			
		||||
		<geom conaffinity="0" contype="0" fromto="0 0 0 0 0 0.02" name="root" rgba="0.9 0.4 0.6 1" size=".011" type="cylinder"/>
 | 
			
		||||
		<body name="body0" pos="0 0 .01">
 | 
			
		||||
			<geom fromto="0 0 0 0.1 0 0" name="link0" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
			<joint axis="0 0 1" limited="false" name="joint0" pos="0 0 0" type="hinge"/>
 | 
			
		||||
			<body name="body1" pos="0.1 0 0">
 | 
			
		||||
				<joint axis="0 0 1" limited="false" name="joint1" pos="0 0 0" type="hinge"/>
 | 
			
		||||
				<geom fromto="0 0 0 0.1 0 0" name="link1" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
				<body name="body2" pos="0.1 0 0">
 | 
			
		||||
					<joint axis="0 0 1" limited="false" name="joint2" pos="0 0 0" type="hinge"/>
 | 
			
		||||
					<geom fromto="0 0 0 0.1 0 0" name="link2" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
					<body name="body3" pos="0.1 0 0">
 | 
			
		||||
						<joint axis="0 0 1" limited="false" name="joint3" pos="0 0 0" type="hinge"/>
 | 
			
		||||
						<geom fromto="0 0 0 0.1 0 0" name="link3" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
						<body name="body4" pos="0.1 0 0">
 | 
			
		||||
							<joint axis="0 0 1" limited="true" name="joint4" pos="0 0 0" range="-3.0 3.0" type="hinge"/>
 | 
			
		||||
							<geom fromto="0 0 0 0.1 0 0" name="link4" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
							<body name="fingertip" pos="0.11 0 0">
 | 
			
		||||
								<geom contype="0" name="fingertip" pos="0 0 0" rgba="0.0 0.8 0.6 1" size=".01" type="sphere"/>
 | 
			
		||||
							</body>
 | 
			
		||||
						</body>
 | 
			
		||||
					</body>
 | 
			
		||||
				</body>
 | 
			
		||||
			</body>
 | 
			
		||||
		</body>
 | 
			
		||||
		<!-- Target -->
 | 
			
		||||
		<body name="target" pos=".1 -.1 .01">
 | 
			
		||||
			<joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.27 .27" ref=".1" stiffness="0" type="slide"/>
 | 
			
		||||
			<joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.27 .27" ref="-.1" stiffness="0" type="slide"/>
 | 
			
		||||
			<geom conaffinity="0" contype="0" name="target" pos="0 0 0" rgba="0.9 0.2 0.2 1" size=".009" type="sphere"/>
 | 
			
		||||
		</body>
 | 
			
		||||
	</worldbody>
 | 
			
		||||
	<actuator>
 | 
			
		||||
		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint0"/>
 | 
			
		||||
		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint1"/>
 | 
			
		||||
		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint2"/>
 | 
			
		||||
		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint3"/>
 | 
			
		||||
		<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint4"/>
 | 
			
		||||
	</actuator>
 | 
			
		||||
   <compiler angle="radian" inertiafromgeom="true"/>
 | 
			
		||||
   <default>
 | 
			
		||||
      <joint armature="1" damping="1" limited="true"/>
 | 
			
		||||
      <geom contype="0" friction="1 0.1 0.1" rgba="0.7 0.7 0 1"/>
 | 
			
		||||
   </default>
 | 
			
		||||
   <option gravity="0 0 -9.81" integrator="RK4" timestep="0.01"/>
 | 
			
		||||
   <worldbody>
 | 
			
		||||
      <!-- Arena -->
 | 
			
		||||
      <geom conaffinity="0" contype="0" name="ground" pos="0 0 0" rgba="0.9 0.9 0.9 1" size="1 1 10" type="plane"/>
 | 
			
		||||
      <geom conaffinity="0" fromto="-.6 -.6 .01 .6 -.6 .01" name="sideS" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
      <geom conaffinity="0" fromto=" .6 -.6 .01 .6  .6 .01" name="sideE" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
      <geom conaffinity="0" fromto="-.6  .6 .01 .6  .6 .01" name="sideN" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
      <geom conaffinity="0" fromto="-.6 -.6 .01 -.6 .6 .01" name="sideW" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
 | 
			
		||||
      <!-- Arm -->
 | 
			
		||||
      <geom conaffinity="0" contype="0" fromto="0 0 0 0 0 0.02" name="root" rgba="0.9 0.4 0.6 1" size=".011" type="cylinder"/>
 | 
			
		||||
      <body name="body0" pos="0 0 .01">
 | 
			
		||||
         <geom fromto="0 0 0 0.1 0 0" name="link0" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
         <joint axis="0 0 1" limited="false" name="joint0" pos="0 0 0" type="hinge"/>
 | 
			
		||||
         <body name="body1" pos="0.1 0 0">
 | 
			
		||||
            <joint axis="0 0 1" limited="false" name="joint1" pos="0 0 0" type="hinge"/>
 | 
			
		||||
            <geom fromto="0 0 0 0.1 0 0" name="link1" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
            <body name="body2" pos="0.1 0 0">
 | 
			
		||||
               <joint axis="0 0 1" limited="false" name="joint2" pos="0 0 0" type="hinge"/>
 | 
			
		||||
               <geom fromto="0 0 0 0.1 0 0" name="link2" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
               <body name="body3" pos="0.1 0 0">
 | 
			
		||||
                  <joint axis="0 0 1" limited="false" name="joint3" pos="0 0 0" type="hinge"/>
 | 
			
		||||
                  <geom fromto="0 0 0 0.1 0 0" name="link3" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
                  <body name="body4" pos="0.1 0 0">
 | 
			
		||||
                     <joint axis="0 0 1" limited="true" name="joint4" pos="0 0 0" range="-3.0 3.0" type="hinge"/>
 | 
			
		||||
                     <geom fromto="0 0 0 0.1 0 0" name="link4" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
 | 
			
		||||
                     <body name="fingertip" pos="0.11 0 0">
 | 
			
		||||
                        <geom contype="0" name="fingertip" pos="0 0 0" rgba="0.0 0.8 0.6 1" size=".01" type="sphere"/>
 | 
			
		||||
                     </body>
 | 
			
		||||
                  </body>
 | 
			
		||||
               </body>
 | 
			
		||||
            </body>
 | 
			
		||||
         </body>
 | 
			
		||||
      </body>
 | 
			
		||||
      <!-- Target -->
 | 
			
		||||
      <body name="target" pos=".1 -.1 .01">
 | 
			
		||||
<!--         <joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.27 .27" ref=".1" stiffness="0" type="slide"/>-->
 | 
			
		||||
<!--         <joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.27 .27" ref="-.1" stiffness="0" type="slide"/>-->
 | 
			
		||||
            <joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.7 .7" ref=".1" stiffness="0" type="slide"/>
 | 
			
		||||
         <joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.7 .7" ref="-.1" stiffness="0" type="slide"/>
 | 
			
		||||
         <geom conaffinity="0" contype="0" name="target" pos="0 0 0" rgba="0.9 0.2 0.2 1" size=".009" type="sphere"/>
 | 
			
		||||
      </body>
 | 
			
		||||
        <site name="context_space" pos="0 0.0 0.0" euler="0 0 0" size="0.5 0.5 0.01" rgba="0 1 0 0.1" type="box"/>
 | 
			
		||||
   </worldbody>
 | 
			
		||||
   <actuator>
 | 
			
		||||
      <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint0"/>
 | 
			
		||||
      <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint1"/>
 | 
			
		||||
      <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint2"/>
 | 
			
		||||
      <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint3"/>
 | 
			
		||||
      <motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint4"/>
 | 
			
		||||
   </actuator>
 | 
			
		||||
</mujoco>
 | 
			
		||||
@ -9,15 +9,27 @@ class MPWrapper(MPEnvWrapper):
 | 
			
		||||
    @property
 | 
			
		||||
    def active_obs(self):
 | 
			
		||||
        return np.concatenate([
 | 
			
		||||
            [True] * self.n_links,  # cos
 | 
			
		||||
            [True] * self.n_links,  # sin
 | 
			
		||||
            [False] * self.n_links,  # cos
 | 
			
		||||
            [False] * self.n_links,  # sin
 | 
			
		||||
            [True] * 2,  # goal position
 | 
			
		||||
            [True] * self.n_links,  # angular velocity
 | 
			
		||||
            [True] * 3,  # goal distance
 | 
			
		||||
            [False] * self.n_links,  # angular velocity
 | 
			
		||||
            [False] * 3,  # goal distance
 | 
			
		||||
            # self.get_body_com("target"),  # only return target to make problem harder
 | 
			
		||||
            [False], # step
 | 
			
		||||
            [False],  # step
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
    # @property
 | 
			
		||||
    # def active_obs(self):
 | 
			
		||||
    #     return np.concatenate([
 | 
			
		||||
    #         [True] * self.n_links,  # cos, True
 | 
			
		||||
    #         [True] * self.n_links,  # sin, True
 | 
			
		||||
    #         [True] * 2,  # goal position
 | 
			
		||||
    #         [True] * self.n_links,  # angular velocity, True
 | 
			
		||||
    #         [True] * 3,  # goal distance
 | 
			
		||||
    #         # self.get_body_com("target"),  # only return target to make problem harder
 | 
			
		||||
    #         [False],  # step
 | 
			
		||||
    #     ])
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def current_vel(self) -> Union[float, int, np.ndarray]:
 | 
			
		||||
        return self.sim.data.qvel.flat[:self.n_links]
 | 
			
		||||
 | 
			
		||||
@ -12,34 +12,38 @@ def visualize(env):
 | 
			
		||||
    plt.plot(t, pos_features)
 | 
			
		||||
    plt.show()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# This might work for some environments, however, please verify either way the correct trajectory information
 | 
			
		||||
# for your environment are extracted below
 | 
			
		||||
SEED = 1
 | 
			
		||||
# env_id = "ball_in_cup-catch"
 | 
			
		||||
env_id = "ALRReacherSparse-v0"
 | 
			
		||||
env_id = "button-press-v2"
 | 
			
		||||
wrappers = [mujoco.reacher.MPWrapper]
 | 
			
		||||
wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper]
 | 
			
		||||
 | 
			
		||||
mp_kwargs = {
 | 
			
		||||
    "num_dof": 5,
 | 
			
		||||
    "num_basis": 8,
 | 
			
		||||
    "duration": 4,
 | 
			
		||||
    "policy_type": "motor",
 | 
			
		||||
    "weights_scale": 1,
 | 
			
		||||
    "num_dof": 4,
 | 
			
		||||
    "num_basis": 5,
 | 
			
		||||
    "duration": 6.25,
 | 
			
		||||
    "policy_type": "metaworld",
 | 
			
		||||
    "weights_scale": 10,
 | 
			
		||||
    "zero_start": True,
 | 
			
		||||
    "policy_kwargs": {
 | 
			
		||||
        "p_gains": 1,
 | 
			
		||||
        "d_gains": 0.1
 | 
			
		||||
    }
 | 
			
		||||
    # "policy_kwargs": {
 | 
			
		||||
    #     "p_gains": 1,
 | 
			
		||||
    #     "d_gains": 0.1
 | 
			
		||||
    # }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
# kwargs = dict(time_limit=4, episode_length=200)
 | 
			
		||||
kwargs = {}
 | 
			
		||||
 | 
			
		||||
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
 | 
			
		||||
env.action_space.seed(SEED)
 | 
			
		||||
 | 
			
		||||
# Plot difference between real trajectory and target MP trajectory
 | 
			
		||||
env.reset()
 | 
			
		||||
w = env.action_space.sample() * 10
 | 
			
		||||
w = env.action_space.sample()  # N(0,1)
 | 
			
		||||
visualize(env)
 | 
			
		||||
pos, vel = env.mp_rollout(w)
 | 
			
		||||
 | 
			
		||||
@ -48,14 +52,24 @@ actual_pos = np.zeros((len(pos), *base_shape))
 | 
			
		||||
actual_vel = np.zeros((len(pos), *base_shape))
 | 
			
		||||
act = np.zeros((len(pos), *base_shape))
 | 
			
		||||
 | 
			
		||||
plt.ion()
 | 
			
		||||
fig = plt.figure()
 | 
			
		||||
ax = fig.add_subplot(1, 1, 1)
 | 
			
		||||
img = ax.imshow(env.env.render("rgb_array"))
 | 
			
		||||
fig.show()
 | 
			
		||||
 | 
			
		||||
for t, pos_vel in enumerate(zip(pos, vel)):
 | 
			
		||||
    actions = env.policy.get_action(pos_vel[0], pos_vel[1])
 | 
			
		||||
    actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
 | 
			
		||||
    _, _, _, _ = env.env.step(actions)
 | 
			
		||||
    if t % 15 == 0:
 | 
			
		||||
        img.set_data(env.env.render("rgb_array"))
 | 
			
		||||
        fig.canvas.draw()
 | 
			
		||||
        fig.canvas.flush_events()
 | 
			
		||||
    act[t, :] = actions
 | 
			
		||||
    # TODO verify for your environment
 | 
			
		||||
    actual_pos[t, :] = env.current_pos
 | 
			
		||||
    actual_vel[t, :] = env.current_vel
 | 
			
		||||
    actual_vel[t, :] = 0  # env.current_vel
 | 
			
		||||
 | 
			
		||||
plt.figure(figsize=(15, 5))
 | 
			
		||||
 | 
			
		||||
@ -79,7 +93,7 @@ plt.plot(vel, c='C1', label="MP")
 | 
			
		||||
plt.xlabel("Episode steps")
 | 
			
		||||
 | 
			
		||||
plt.subplot(133)
 | 
			
		||||
plt.title("Actions")
 | 
			
		||||
plt.title(f"Actions {np.std(act, axis=0)}")
 | 
			
		||||
plt.plot(act, c="C0"),  # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
 | 
			
		||||
plt.xlabel("Episode steps")
 | 
			
		||||
# plt.legend()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user