# fancy_gym/alr_envs/examples/pd_control_gain_tuning.py

from collections import OrderedDict
import numpy as np
import alr_envs
from matplotlib import pyplot as plt

# This might work for some environments; however, please verify either way that the correct
# trajectory information for your environment is extracted below.
SEED = 1
env_id = "Reacher5dProMP-v0"
env = alr_envs.make(env_id, seed=SEED, controller_kwargs={'p_gains': 0.05, 'd_gains': 0.05}).env
env.action_space.seed(SEED)
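
# `p_gains` / `d_gains` above are the PD tracking-controller gains this script helps to tune; the
# trailing `.env` unwraps one wrapper layer so that the MP wrapper's `get_trajectory`,
# `tracking_controller` and `current_pos` / `current_vel` used below are directly accessible,
# while `env.env` refers to the underlying step-based environment.
# A minimal sketch of how one might compare candidate gains (values here are purely illustrative):
# for p, d in [(0.01, 0.01), (0.05, 0.05), (0.5, 0.5)]:
#     env = alr_envs.make(env_id, seed=SEED, controller_kwargs={'p_gains': p, 'd_gains': d}).env
#     ...  # repeat the rollout and plots below for each setting
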
# Plot difference between real trajectory and target MP trajectory
env.reset()
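# Sample MP parameters and query the desired trajectory the movement primitive generates from them;
# the PD controller below then tries to track this trajectory step by step.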
w = env.action_space.sample()
pos, vel = env.get_trajectory(w)
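
# Buffers for the positions / velocities actually reached and the actions actually applied, so the
# executed rollout can be compared against the desired MP trajectory.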
base_shape = env.env.action_space.shape
actual_pos = np.zeros((len(pos), *base_shape))
actual_vel = np.zeros((len(pos), *base_shape))
act = np.zeros((len(pos), *base_shape))
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
img = ax.imshow(env.env.render(mode="rgb_array"))
fig.show()
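
# Roll out the desired trajectory step by step. The tracking controller presumably applies a PD law of
# roughly the form action = p_gains * (des_pos - cur_pos) + d_gains * (des_vel - cur_vel), so larger
# gains track the MP trajectory more aggressively but are ultimately limited by the action clipping.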
for t, pos_vel in enumerate(zip(pos, vel)):
    # PD action towards the desired position/velocity, clipped to the valid action range
    actions = env.tracking_controller.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos)
    actions = np.clip(actions, env.env.action_space.low, env.env.action_space.high)
    _, _, _, _ = env.env.step(actions)
    if t % 15 == 0:
        # render only every 15th step to keep the rollout fast
        img.set_data(env.env.render(mode="rgb_array"))
        fig.canvas.draw()
        fig.canvas.flush_events()
    act[t, :] = actions
    # TODO verify for your environment
    actual_pos[t, :] = env.current_pos
    actual_vel[t, :] = env.current_vel
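
# Optional addition (not part of the original example): a scalar summary of tracking quality.
# Assumes `pos` / `vel` are arrays with the same shape as the recorded buffers.
pos_error = np.mean(np.abs(actual_pos - pos))
vel_error = np.mean(np.abs(actual_vel - vel))
print(f"Mean |position error|: {pos_error:.4f}, mean |velocity error|: {vel_error:.4f}")
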
plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.title("Position")
p1 = plt.plot(actual_pos, c='C0', label="true")
p2 = plt.plot(pos, c='C1', label="MP")
plt.xlabel("Episode steps")
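# plt.plot on a 2D array draws one line per column (one per DoF), so every label is repeated; the
# OrderedDict below keeps a single legend entry per unique label.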
handles, labels = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys())

plt.subplot(132)
plt.title("Velocity")
plt.plot(actual_vel, c='C0', label="true")
plt.plot(vel, c='C1', label="MP")
plt.xlabel("Episode steps")

plt.subplot(133)
plt.title(f"Actions {np.std(act, axis=0)}")
plt.plot(act, c="C0")
plt.xlabel("Episode steps")
plt.show()
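
# Reading the plots (general PD-tuning guidance, not specific to this environment): if the "true"
# curves lag behind the "MP" targets, increase p_gains; if they overshoot or oscillate, increase
# d_gains or reduce p_gains. The action std shown in the third panel indicates how hard the
# controller has to work.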