reacher adjustments
This commit is contained in:
parent
c763f89d60
commit
d313795cec
@ -357,13 +357,13 @@ for _v in _versions:
|
|||||||
"num_basis": 5,
|
"num_basis": 5,
|
||||||
"duration": 2,
|
"duration": 2,
|
||||||
"policy_type": "velocity",
|
"policy_type": "velocity",
|
||||||
"weights_scale": 0.1,
|
"weights_scale": 5,
|
||||||
"zero_start": True
|
"zero_start": True
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["ProMP"].append(_env_id)
|
||||||
|
|
||||||
## ALRReacher
|
## ALRReacher
|
||||||
_versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"]
|
_versions = ["ALRReacher-v0", "ALRLongReacher-v0", "ALRReacherSparse-v0", "ALRLongReacherSparse-v0"]
|
||||||
for _v in _versions:
|
for _v in _versions:
|
||||||
@ -378,12 +378,12 @@ for _v in _versions:
|
|||||||
"wrappers": [mujoco.reacher.MPWrapper],
|
"wrappers": [mujoco.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 5 if "long" not in _v.lower() else 7,
|
"num_dof": 5 if "long" not in _v.lower() else 7,
|
||||||
"num_basis": 5,
|
"num_basis": 2,
|
||||||
"duration": 4,
|
"duration": 4,
|
||||||
"alpha_phase": 2,
|
"alpha_phase": 2,
|
||||||
"learn_goal": True,
|
"learn_goal": True,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 1,
|
"weights_scale": 5,
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": 1,
|
"p_gains": 1,
|
||||||
"d_gains": 0.1
|
"d_gains": 0.1
|
||||||
@ -402,10 +402,10 @@ for _v in _versions:
|
|||||||
"wrappers": [mujoco.reacher.MPWrapper],
|
"wrappers": [mujoco.reacher.MPWrapper],
|
||||||
"mp_kwargs": {
|
"mp_kwargs": {
|
||||||
"num_dof": 5 if "long" not in _v.lower() else 7,
|
"num_dof": 5 if "long" not in _v.lower() else 7,
|
||||||
"num_basis": 5,
|
"num_basis": 1,
|
||||||
"duration": 4,
|
"duration": 4,
|
||||||
"policy_type": "motor",
|
"policy_type": "motor",
|
||||||
"weights_scale": 1,
|
"weights_scale": 5,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"policy_kwargs": {
|
"policy_kwargs": {
|
||||||
"p_gains": 1,
|
"p_gains": 1,
|
||||||
|
@ -44,9 +44,9 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
reward_dist -= self.reward_weight * np.linalg.norm(vec)
|
reward_dist -= self.reward_weight * np.linalg.norm(vec)
|
||||||
if self.steps_before_reward > 0:
|
if self.steps_before_reward > 0:
|
||||||
# avoid giving this penalty for normal step based case
|
# avoid giving this penalty for normal step based case
|
||||||
angular_vel -= np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
|
# angular_vel -= 10 * np.linalg.norm(self.sim.data.qvel.flat[:self.n_links])
|
||||||
# angular_vel -= np.square(self.sim.data.qvel.flat[:self.n_links]).sum()
|
angular_vel -= 10 * np.square(self.sim.data.qvel.flat[:self.n_links]).sum()
|
||||||
reward_ctrl = - np.square(a).sum()
|
reward_ctrl = - 10 * np.square(a).sum()
|
||||||
|
|
||||||
if self.balance:
|
if self.balance:
|
||||||
reward_balance -= self.balance_weight * np.abs(
|
reward_balance -= self.balance_weight * np.abs(
|
||||||
@ -64,6 +64,35 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
def viewer_setup(self):
|
def viewer_setup(self):
|
||||||
self.viewer.cam.trackbodyid = 0
|
self.viewer.cam.trackbodyid = 0
|
||||||
|
|
||||||
|
# def reset_model(self):
|
||||||
|
# qpos = self.init_qpos
|
||||||
|
# if not hasattr(self, "goal"):
|
||||||
|
# self.goal = np.array([-0.25, 0.25])
|
||||||
|
# # self.goal = self.init_qpos.copy()[:2] + 0.05
|
||||||
|
# qpos[-2:] = self.goal
|
||||||
|
# qvel = self.init_qvel
|
||||||
|
# qvel[-2:] = 0
|
||||||
|
# self.set_state(qpos, qvel)
|
||||||
|
# self._steps = 0
|
||||||
|
#
|
||||||
|
# return self._get_obs()
|
||||||
|
|
||||||
|
def reset_model(self):
|
||||||
|
qpos = self.init_qpos.copy()
|
||||||
|
while True:
|
||||||
|
self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
|
||||||
|
# self.goal = self.np_random.uniform(low=0, high=self.n_links / 10, size=2)
|
||||||
|
# self.goal = np.random.uniform(low=[-self.n_links / 10, 0], high=[0, self.n_links / 10], size=2)
|
||||||
|
if np.linalg.norm(self.goal) < self.n_links / 10:
|
||||||
|
break
|
||||||
|
qpos[-2:] = self.goal
|
||||||
|
qvel = self.init_qvel.copy()
|
||||||
|
qvel[-2:] = 0
|
||||||
|
self.set_state(qpos, qvel)
|
||||||
|
self._steps = 0
|
||||||
|
|
||||||
|
return self._get_obs()
|
||||||
|
|
||||||
# def reset_model(self):
|
# def reset_model(self):
|
||||||
# qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
|
# qpos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq) + self.init_qpos
|
||||||
# while True:
|
# while True:
|
||||||
@ -78,30 +107,15 @@ class ALRReacherEnv(MujocoEnv, utils.EzPickle):
|
|||||||
#
|
#
|
||||||
# return self._get_obs()
|
# return self._get_obs()
|
||||||
|
|
||||||
def reset_model(self):
|
|
||||||
qpos = self.init_qpos
|
|
||||||
if not hasattr(self, "goal"):
|
|
||||||
while True:
|
|
||||||
self.goal = self.np_random.uniform(low=-self.n_links / 10, high=self.n_links / 10, size=2)
|
|
||||||
if np.linalg.norm(self.goal) < self.n_links / 10:
|
|
||||||
break
|
|
||||||
qpos[-2:] = self.goal
|
|
||||||
qvel = self.init_qvel
|
|
||||||
qvel[-2:] = 0
|
|
||||||
self.set_state(qpos, qvel)
|
|
||||||
self._steps = 0
|
|
||||||
|
|
||||||
return self._get_obs()
|
|
||||||
|
|
||||||
def _get_obs(self):
|
def _get_obs(self):
|
||||||
theta = self.sim.data.qpos.flat[:self.n_links]
|
theta = self.sim.data.qpos.flat[:self.n_links]
|
||||||
|
target = self.get_body_com("target")
|
||||||
return np.concatenate([
|
return np.concatenate([
|
||||||
np.cos(theta),
|
np.cos(theta),
|
||||||
np.sin(theta),
|
np.sin(theta),
|
||||||
self.sim.data.qpos.flat[self.n_links:], # this is goal position
|
target[:2], # x-y of goal position
|
||||||
self.sim.data.qvel.flat[:self.n_links], # this is angular velocity
|
self.sim.data.qvel.flat[:self.n_links], # angular velocity
|
||||||
self.get_body_com("fingertip") - self.get_body_com("target"),
|
self.get_body_com("fingertip") - target, # goal distance
|
||||||
# self.get_body_com("target"), # only return target to make problem harder
|
|
||||||
[self._steps],
|
[self._steps],
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -122,4 +136,4 @@ if __name__ == '__main__':
|
|||||||
if d:
|
if d:
|
||||||
env.reset()
|
env.reset()
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
|
@ -1,54 +1,57 @@
|
|||||||
<mujoco model="reacher">
|
<mujoco model="reacher">
|
||||||
<compiler angle="radian" inertiafromgeom="true"/>
|
<compiler angle="radian" inertiafromgeom="true"/>
|
||||||
<default>
|
<default>
|
||||||
<joint armature="1" damping="1" limited="true"/>
|
<joint armature="1" damping="1" limited="true"/>
|
||||||
<geom contype="0" friction="1 0.1 0.1" rgba="0.7 0.7 0 1"/>
|
<geom contype="0" friction="1 0.1 0.1" rgba="0.7 0.7 0 1"/>
|
||||||
</default>
|
</default>
|
||||||
<option gravity="0 0 -9.81" integrator="RK4" timestep="0.01"/>
|
<option gravity="0 0 -9.81" integrator="RK4" timestep="0.01"/>
|
||||||
<worldbody>
|
<worldbody>
|
||||||
<!-- Arena -->
|
<!-- Arena -->
|
||||||
<geom conaffinity="0" contype="0" name="ground" pos="0 0 0" rgba="0.9 0.9 0.9 1" size="1 1 10" type="plane"/>
|
<geom conaffinity="0" contype="0" name="ground" pos="0 0 0" rgba="0.9 0.9 0.9 1" size="1 1 10" type="plane"/>
|
||||||
<geom conaffinity="0" fromto="-.6 -.6 .01 .6 -.6 .01" name="sideS" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
<geom conaffinity="0" fromto="-.6 -.6 .01 .6 -.6 .01" name="sideS" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
||||||
<geom conaffinity="0" fromto=" .6 -.6 .01 .6 .6 .01" name="sideE" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
<geom conaffinity="0" fromto=" .6 -.6 .01 .6 .6 .01" name="sideE" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
||||||
<geom conaffinity="0" fromto="-.6 .6 .01 .6 .6 .01" name="sideN" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
<geom conaffinity="0" fromto="-.6 .6 .01 .6 .6 .01" name="sideN" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
||||||
<geom conaffinity="0" fromto="-.6 -.6 .01 -.6 .6 .01" name="sideW" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
<geom conaffinity="0" fromto="-.6 -.6 .01 -.6 .6 .01" name="sideW" rgba="0.9 0.4 0.6 1" size=".02" type="capsule"/>
|
||||||
<!-- Arm -->
|
<!-- Arm -->
|
||||||
<geom conaffinity="0" contype="0" fromto="0 0 0 0 0 0.02" name="root" rgba="0.9 0.4 0.6 1" size=".011" type="cylinder"/>
|
<geom conaffinity="0" contype="0" fromto="0 0 0 0 0 0.02" name="root" rgba="0.9 0.4 0.6 1" size=".011" type="cylinder"/>
|
||||||
<body name="body0" pos="0 0 .01">
|
<body name="body0" pos="0 0 .01">
|
||||||
<geom fromto="0 0 0 0.1 0 0" name="link0" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
<geom fromto="0 0 0 0.1 0 0" name="link0" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
||||||
<joint axis="0 0 1" limited="false" name="joint0" pos="0 0 0" type="hinge"/>
|
<joint axis="0 0 1" limited="false" name="joint0" pos="0 0 0" type="hinge"/>
|
||||||
<body name="body1" pos="0.1 0 0">
|
<body name="body1" pos="0.1 0 0">
|
||||||
<joint axis="0 0 1" limited="false" name="joint1" pos="0 0 0" type="hinge"/>
|
<joint axis="0 0 1" limited="false" name="joint1" pos="0 0 0" type="hinge"/>
|
||||||
<geom fromto="0 0 0 0.1 0 0" name="link1" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
<geom fromto="0 0 0 0.1 0 0" name="link1" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
||||||
<body name="body2" pos="0.1 0 0">
|
<body name="body2" pos="0.1 0 0">
|
||||||
<joint axis="0 0 1" limited="false" name="joint2" pos="0 0 0" type="hinge"/>
|
<joint axis="0 0 1" limited="false" name="joint2" pos="0 0 0" type="hinge"/>
|
||||||
<geom fromto="0 0 0 0.1 0 0" name="link2" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
<geom fromto="0 0 0 0.1 0 0" name="link2" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
||||||
<body name="body3" pos="0.1 0 0">
|
<body name="body3" pos="0.1 0 0">
|
||||||
<joint axis="0 0 1" limited="false" name="joint3" pos="0 0 0" type="hinge"/>
|
<joint axis="0 0 1" limited="false" name="joint3" pos="0 0 0" type="hinge"/>
|
||||||
<geom fromto="0 0 0 0.1 0 0" name="link3" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
<geom fromto="0 0 0 0.1 0 0" name="link3" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
||||||
<body name="body4" pos="0.1 0 0">
|
<body name="body4" pos="0.1 0 0">
|
||||||
<joint axis="0 0 1" limited="true" name="joint4" pos="0 0 0" range="-3.0 3.0" type="hinge"/>
|
<joint axis="0 0 1" limited="true" name="joint4" pos="0 0 0" range="-3.0 3.0" type="hinge"/>
|
||||||
<geom fromto="0 0 0 0.1 0 0" name="link4" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
<geom fromto="0 0 0 0.1 0 0" name="link4" rgba="0.0 0.4 0.6 1" size=".01" type="capsule"/>
|
||||||
<body name="fingertip" pos="0.11 0 0">
|
<body name="fingertip" pos="0.11 0 0">
|
||||||
<geom contype="0" name="fingertip" pos="0 0 0" rgba="0.0 0.8 0.6 1" size=".01" type="sphere"/>
|
<geom contype="0" name="fingertip" pos="0 0 0" rgba="0.0 0.8 0.6 1" size=".01" type="sphere"/>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
</body>
|
</body>
|
||||||
<!-- Target -->
|
<!-- Target -->
|
||||||
<body name="target" pos=".1 -.1 .01">
|
<body name="target" pos=".1 -.1 .01">
|
||||||
<joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.27 .27" ref=".1" stiffness="0" type="slide"/>
|
<!-- <joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.27 .27" ref=".1" stiffness="0" type="slide"/>-->
|
||||||
<joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.27 .27" ref="-.1" stiffness="0" type="slide"/>
|
<!-- <joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.27 .27" ref="-.1" stiffness="0" type="slide"/>-->
|
||||||
<geom conaffinity="0" contype="0" name="target" pos="0 0 0" rgba="0.9 0.2 0.2 1" size=".009" type="sphere"/>
|
<joint armature="0" axis="1 0 0" damping="0" limited="true" name="target_x" pos="0 0 0" range="-.7 .7" ref=".1" stiffness="0" type="slide"/>
|
||||||
</body>
|
<joint armature="0" axis="0 1 0" damping="0" limited="true" name="target_y" pos="0 0 0" range="-.7 .7" ref="-.1" stiffness="0" type="slide"/>
|
||||||
</worldbody>
|
<geom conaffinity="0" contype="0" name="target" pos="0 0 0" rgba="0.9 0.2 0.2 1" size=".009" type="sphere"/>
|
||||||
<actuator>
|
</body>
|
||||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint0"/>
|
<site name="context_space" pos="0 0.0 0.0" euler="0 0 0" size="0.5 0.5 0.01" rgba="0 1 0 0.1" type="box"/>
|
||||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint1"/>
|
</worldbody>
|
||||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint2"/>
|
<actuator>
|
||||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint3"/>
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint0"/>
|
||||||
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint4"/>
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint1"/>
|
||||||
</actuator>
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint2"/>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint3"/>
|
||||||
|
<motor ctrllimited="true" ctrlrange="-1.0 1.0" gear="200.0" joint="joint4"/>
|
||||||
|
</actuator>
|
||||||
</mujoco>
|
</mujoco>
|
@ -9,15 +9,27 @@ class MPWrapper(MPEnvWrapper):
|
|||||||
@property
|
@property
|
||||||
def active_obs(self):
|
def active_obs(self):
|
||||||
return np.concatenate([
|
return np.concatenate([
|
||||||
[True] * self.n_links, # cos
|
[False] * self.n_links, # cos
|
||||||
[True] * self.n_links, # sin
|
[False] * self.n_links, # sin
|
||||||
[True] * 2, # goal position
|
[True] * 2, # goal position
|
||||||
[True] * self.n_links, # angular velocity
|
[False] * self.n_links, # angular velocity
|
||||||
[True] * 3, # goal distance
|
[False] * 3, # goal distance
|
||||||
# self.get_body_com("target"), # only return target to make problem harder
|
# self.get_body_com("target"), # only return target to make problem harder
|
||||||
[False], # step
|
[False], # step
|
||||||
])
|
])
|
||||||
|
|
||||||
|
# @property
|
||||||
|
# def active_obs(self):
|
||||||
|
# return np.concatenate([
|
||||||
|
# [True] * self.n_links, # cos, True
|
||||||
|
# [True] * self.n_links, # sin, True
|
||||||
|
# [True] * 2, # goal position
|
||||||
|
# [True] * self.n_links, # angular velocity, True
|
||||||
|
# [True] * 3, # goal distance
|
||||||
|
# # self.get_body_com("target"), # only return target to make problem harder
|
||||||
|
# [False], # step
|
||||||
|
# ])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_vel(self) -> Union[float, int, np.ndarray]:
|
def current_vel(self) -> Union[float, int, np.ndarray]:
|
||||||
return self.sim.data.qvel.flat[:self.n_links]
|
return self.sim.data.qvel.flat[:self.n_links]
|
||||||
|
@ -12,34 +12,38 @@ def visualize(env):
|
|||||||
plt.plot(t, pos_features)
|
plt.plot(t, pos_features)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
# This might work for some environments, however, please verify either way the correct trajectory information
|
# This might work for some environments, however, please verify either way the correct trajectory information
|
||||||
# for your environment are extracted below
|
# for your environment are extracted below
|
||||||
SEED = 1
|
SEED = 1
|
||||||
# env_id = "ball_in_cup-catch"
|
# env_id = "ball_in_cup-catch"
|
||||||
env_id = "ALRReacherSparse-v0"
|
env_id = "ALRReacherSparse-v0"
|
||||||
|
env_id = "button-press-v2"
|
||||||
wrappers = [mujoco.reacher.MPWrapper]
|
wrappers = [mujoco.reacher.MPWrapper]
|
||||||
|
wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper]
|
||||||
|
|
||||||
mp_kwargs = {
|
mp_kwargs = {
|
||||||
"num_dof": 5,
|
"num_dof": 4,
|
||||||
"num_basis": 8,
|
"num_basis": 5,
|
||||||
"duration": 4,
|
"duration": 6.25,
|
||||||
"policy_type": "motor",
|
"policy_type": "metaworld",
|
||||||
"weights_scale": 1,
|
"weights_scale": 10,
|
||||||
"zero_start": True,
|
"zero_start": True,
|
||||||
"policy_kwargs": {
|
# "policy_kwargs": {
|
||||||
"p_gains": 1,
|
# "p_gains": 1,
|
||||||
"d_gains": 0.1
|
# "d_gains": 0.1
|
||||||
}
|
# }
|
||||||
}
|
}
|
||||||
|
|
||||||
# kwargs = dict(time_limit=4, episode_length=200)
|
# kwargs = dict(time_limit=4, episode_length=200)
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
|
env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs)
|
||||||
|
env.action_space.seed(SEED)
|
||||||
|
|
||||||
# Plot difference between real trajectory and target MP trajectory
|
# Plot difference between real trajectory and target MP trajectory
|
||||||
env.reset()
|
env.reset()
|
||||||
w = env.action_space.sample() * 10
|
w = env.action_space.sample() # N(0,1)
|
||||||
visualize(env)
|
visualize(env)
|
||||||
pos, vel = env.mp_rollout(w)
|
pos, vel = env.mp_rollout(w)
|
||||||
|
|
||||||
@ -48,14 +52,24 @@ actual_pos = np.zeros((len(pos), *base_shape))
|
|||||||
actual_vel = np.zeros((len(pos), *base_shape))
|
actual_vel = np.zeros((len(pos), *base_shape))
|
||||||
act = np.zeros((len(pos), *base_shape))
|
act = np.zeros((len(pos), *base_shape))
|
||||||
|
|
||||||
|
plt.ion()
|
||||||
|
fig = plt.figure()
|
||||||
|
ax = fig.add_subplot(1, 1, 1)
|
||||||
|
img = ax.imshow(env.env.render("rgb_array"))
|
||||||
|
fig.show()
|
||||||
|
|
||||||
for t, pos_vel in enumerate(zip(pos, vel)):
|
for t, pos_vel in enumerate(zip(pos, vel)):
|
||||||
actions = env.policy.get_action(pos_vel[0], pos_vel[1])
|
actions = env.policy.get_action(pos_vel[0], pos_vel[1])
|
||||||
actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
|
actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
|
||||||
_, _, _, _ = env.env.step(actions)
|
_, _, _, _ = env.env.step(actions)
|
||||||
|
if t % 15 == 0:
|
||||||
|
img.set_data(env.env.render("rgb_array"))
|
||||||
|
fig.canvas.draw()
|
||||||
|
fig.canvas.flush_events()
|
||||||
act[t, :] = actions
|
act[t, :] = actions
|
||||||
# TODO verify for your environment
|
# TODO verify for your environment
|
||||||
actual_pos[t, :] = env.current_pos
|
actual_pos[t, :] = env.current_pos
|
||||||
actual_vel[t, :] = env.current_vel
|
actual_vel[t, :] = 0 # env.current_vel
|
||||||
|
|
||||||
plt.figure(figsize=(15, 5))
|
plt.figure(figsize=(15, 5))
|
||||||
|
|
||||||
@ -79,7 +93,7 @@ plt.plot(vel, c='C1', label="MP")
|
|||||||
plt.xlabel("Episode steps")
|
plt.xlabel("Episode steps")
|
||||||
|
|
||||||
plt.subplot(133)
|
plt.subplot(133)
|
||||||
plt.title("Actions")
|
plt.title(f"Actions {np.std(act, axis=0)}")
|
||||||
plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
|
plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
|
||||||
plt.xlabel("Episode steps")
|
plt.xlabel("Episode steps")
|
||||||
# plt.legend()
|
# plt.legend()
|
||||||
|
Loading…
Reference in New Issue
Block a user