75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
|
|
from alr_envs import dmc
|
|
from alr_envs.utils.make_env_helpers import make_detpmp_env
|
|
|
|
# This might work for some environments, however, please verify either way the correct trajectory information
|
|
# for your environment are extracted below
|
|
SEED = 10
|
|
env_id = "cartpole-swingup"
|
|
wrappers = [dmc.suite.cartpole.MPWrapper]
|
|
|
|
mp_kwargs = {
|
|
"num_dof": 1,
|
|
"num_basis": 5,
|
|
"duration": 2,
|
|
"width": 0.025,
|
|
"policy_type": "motor",
|
|
"weights_scale": 0.2,
|
|
"zero_start": True,
|
|
"policy_kwargs": {
|
|
"p_gains": 10,
|
|
"d_gains": 10 # a good starting point is the sqrt of p_gains
|
|
}
|
|
}
|
|
|
|
kwargs = dict(time_limit=2, episode_length=200)
|
|
|
|
env = make_detpmp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
|
|
**kwargs)
|
|
|
|
# Plot difference between real trajectory and target MP trajectory
|
|
env.reset()
|
|
pos, vel = env.mp_rollout(env.action_space.sample())
|
|
|
|
base_shape = env.full_action_space.shape
|
|
actual_pos = np.zeros((len(pos), *base_shape))
|
|
actual_pos_ball = np.zeros((len(pos), *base_shape))
|
|
actual_vel = np.zeros((len(pos), *base_shape))
|
|
act = np.zeros((len(pos), *base_shape))
|
|
|
|
for t, pos_vel in enumerate(zip(pos, vel)):
|
|
actions = env.policy.get_action(pos_vel[0], pos_vel[1])
|
|
actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
|
|
_, _, _, _ = env.env.step(actions)
|
|
act[t, :] = actions
|
|
# TODO verify for your environment
|
|
actual_pos[t, :] = env.current_pos
|
|
# actual_pos_ball[t, :] = env.physics.data.qpos[2:]
|
|
actual_vel[t, :] = env.current_vel
|
|
|
|
plt.figure(figsize=(15, 5))
|
|
|
|
plt.subplot(131)
|
|
plt.title("Position")
|
|
plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
|
|
# plt.plot(actual_pos_ball, label="true pos ball")
|
|
plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
|
|
plt.xlabel("Episode steps")
|
|
plt.legend()
|
|
|
|
plt.subplot(132)
|
|
plt.title("Velocity")
|
|
plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
|
|
plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
|
|
plt.xlabel("Episode steps")
|
|
plt.legend()
|
|
|
|
plt.subplot(133)
|
|
plt.title("Actions")
|
|
plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
|
|
plt.xlabel("Episode steps")
|
|
# plt.legend()
|
|
plt.show()
|