import numpy as np
from matplotlib import pyplot as plt

from alr_envs import dmc, meta
from alr_envs.utils.make_env_helpers import make_promp_env

# This might work for some environments, however, please verify either way the correct trajectory information
# for your environment are extracted below
SEED = 10
env_id = "ball_in_cup-catch"
wrappers = [dmc.ball_in_cup.MPWrapper]

mp_kwargs = {
    "num_dof": 2,
    "num_basis": 10,
    "duration": 2,
    "width": 0.025,
    "policy_type": "motor",
    "weights_scale": 1,
    "zero_start": True,
    "policy_kwargs": {
        "p_gains": 1,
        "d_gains": 1
    }
}

kwargs = dict(time_limit=2, episode_length=100)

env = make_promp_env(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs,
                     **kwargs)

# Plot difference between real trajectory and target MP trajectory
env.reset()
pos, vel = env.mp_rollout(env.action_space.sample())

base_shape = env.full_action_space.shape
actual_pos = np.zeros((len(pos), *base_shape))
actual_vel = np.zeros((len(pos), *base_shape))
act = np.zeros((len(pos), *base_shape))

for t, pos_vel in enumerate(zip(pos, vel)):
    actions = env.policy.get_action(pos_vel[0], pos_vel[1])
    actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high)
    _, _, _, _ = env.env.step(actions)
    act[t, :] = actions
    # TODO verify for your environment
    actual_pos[t, :] = env.current_pos
    actual_vel[t, :] = env.current_vel

plt.figure(figsize=(15, 5))

plt.subplot(131)
plt.title("Position")
plt.plot(actual_pos, c='C0', label=["true" if i == 0 else "" for i in range(np.prod(base_shape))])
# plt.plot(actual_pos_ball, label="true pos ball")
plt.plot(pos, c='C1', label=["MP" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.xlabel("Episode steps")
plt.legend()

plt.subplot(132)
plt.title("Velocity")
plt.plot(actual_vel, c='C0', label=[f"true" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.plot(vel, c='C1', label=[f"MP" if i == 0 else "" for i in range(np.prod(base_shape))])
plt.xlabel("Episode steps")
plt.legend()

plt.subplot(133)
plt.title("Actions")
plt.plot(act, c="C0"),  # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))])
plt.xlabel("Episode steps")
# plt.legend()
plt.show()