fancy_gym/alr_envs/examples/pd_control_gain_tuning.py
2022-01-25 15:23:57 +01:00

87 lines
2.4 KiB
Python

import numpy as np
from matplotlib import pyplot as plt
from alr_envs import dmc, meta
from alr_envs.alr import mujoco
from alr_envs.utils.make_env_helpers import make_promp_env
def visualize(env):
    """Plot the basis functions of the environment's motion primitive.

    Evaluates ``env.mp.basis_generator.basis`` on the time vector ``env.t``
    and displays the resulting curves with matplotlib.

    Args:
        env: Environment exposing a ``t`` attribute and an
            ``mp.basis_generator`` object with a ``basis`` method.
    """
    time_steps = env.t
    basis_values = env.mp.basis_generator.basis(time_steps)
    plt.plot(time_steps, basis_values)
    plt.show()
# This might work for some environments; however, please verify in any case that the correct trajectory
# information for your environment is extracted below.
# Seed forwarded to make_promp_env.
SEED = 1

# Environment under test; the commented-out id is kept as an alternative.
# env_id = "ball_in_cup-catch"
env_id = "ALRReacherSparse-v0"
wrappers = [mujoco.reacher.MPWrapper]

# Motion-primitive settings passed to make_promp_env. The nested
# "policy_kwargs" hold the p_gains / d_gains values
# (presumably the gains this example is meant to help tune).
mp_kwargs = dict(
    num_dof=5,
    num_basis=8,
    duration=4,
    policy_type="motor",
    weights_scale=1,
    zero_start=True,
    policy_kwargs=dict(
        p_gains=1,
        d_gains=0.1,
    ),
)

# Extra keyword arguments for make_promp_env; empty for this environment.
# kwargs = dict(time_limit=4, episode_length=200)
kwargs = dict()

env = make_promp_env(
    env_id,
    wrappers,
    seed=SEED,
    mp_kwargs=mp_kwargs,
    **kwargs
)
# Record the difference between the real trajectory and the target MP trajectory.
env.reset()
# Sample a random weight vector and scale it up by a factor of 10.
w = env.action_space.sample() * 10
visualize(env)
# Desired positions and velocities generated from the sampled weights.
pos, vel = env.mp_rollout(w)

# Buffers with one row per trajectory step, shaped like the full action space.
base_shape = env.full_action_space.shape
n_steps = len(pos)
actual_pos = np.zeros((n_steps, *base_shape))
actual_vel = np.zeros((n_steps, *base_shape))
act = np.zeros((n_steps, *base_shape))

for step, (des_pos, des_vel) in enumerate(zip(pos, vel)):
    # Turn the desired state into an action and clip it to the action bounds.
    raw_action = env.policy.get_action(des_pos, des_vel)
    actions = np.clip(raw_action, env.full_action_space.low, env.full_action_space.high)
    # Step the wrapped environment directly; the four return values are not used.
    _obs, _reward, _done, _info = env.env.step(actions)
    act[step, :] = actions
    # TODO verify for your environment
    actual_pos[step, :] = env.current_pos
    actual_vel[step, :] = env.current_vel
# Compare the recorded trajectory with the MP reference in three side-by-side panels.
plt.figure(figsize=(15, 5))

# Panel 1: recorded vs. desired positions.
plt.subplot(131)
plt.title("Position")
plt.plot(actual_pos, c='C0', label="true")
# plt.plot(actual_pos_ball, label="true pos ball")
plt.plot(pos, c='C1', label="MP")
plt.xlabel("Episode steps")
# Several lines can share one label, so keep a single legend entry per label.
# A plain dict preserves insertion order (guaranteed since Python 3.7), which
# makes a separate OrderedDict import unnecessary.
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())

# Panel 2: recorded vs. desired velocities.
plt.subplot(132)
plt.title("Velocity")
plt.plot(actual_vel, c='C0', label="true")
plt.plot(vel, c='C1', label="MP")
plt.xlabel("Episode steps")

# Panel 3: the clipped actions that were sent to the environment.
plt.subplot(133)
plt.title("Actions")
plt.plot(act, c="C0")
plt.xlabel("Episode steps")

plt.show()