diff --git a/alr_envs/examples/examples_movement_primitives.py b/alr_envs/examples/examples_movement_primitives.py index 62ab91c..49b4ee0 100644 --- a/alr_envs/examples/examples_movement_primitives.py +++ b/alr_envs/examples/examples_movement_primitives.py @@ -157,10 +157,10 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True): if __name__ == '__main__': render = True # DMP - example_mp("alr_envs:HoleReacherDMP-v0", seed=10, iterations=5, render=render) + example_mp("HoleReacherDMP-v0", seed=10, iterations=5, render=render) # # # ProMP - example_mp("alr_envs:HoleReacherProMP-v0", seed=10, iterations=5, render=render) + example_mp("HoleReacherProMP-v0", seed=10, iterations=5, render=render) # Altered basis functions obs1 = example_custom_mp("Reacher5dProMP-v0", seed=10, iterations=5, render=render) diff --git a/alr_envs/examples/pd_control_gain_tuning.py b/alr_envs/examples/pd_control_gain_tuning.py index d4c5201..8a74e6a 100644 --- a/alr_envs/examples/pd_control_gain_tuning.py +++ b/alr_envs/examples/pd_control_gain_tuning.py @@ -1,55 +1,24 @@ from collections import OrderedDict import numpy as np +import alr_envs from matplotlib import pyplot as plt -from alr_envs import make_bb, dmc, meta -from alr_envs.envs import mujoco - - -def visualize(env): - t = env.t - pos_features = env.traj_gen.basis_generator.basis(t) - plt.plot(t, pos_features) - plt.show() - - # This might work for some environments, however, please verify either way the correct trajectory information # for your environment are extracted below SEED = 1 -# env_id = "dmc:ball_in_cup-catch" -# wrappers = [dmc.suite.ball_in_cup.MPWrapper] -env_id = "Reacher5dSparse-v0" -wrappers = [mujoco.reacher.MPWrapper] -# env_id = "metaworld:button-press-v2" -# wrappers = [meta.goal_object_change_mp_wrapper.MPWrapper] -mp_kwargs = { - "num_dof": 4, - "num_basis": 5, - "duration": 6.25, - "policy_type": "metaworld", - "weights_scale": 10, - "zero_start": True, - # "policy_kwargs": { - # "p_gains": 1, - # "d_gains": 0.1 - # } -} +env_id = "Reacher5dProMP-v0" -# kwargs = dict(time_limit=4, episode_length=200) -kwargs = {} - -env = make_bb(env_id, wrappers, seed=SEED, mp_kwargs=mp_kwargs, **kwargs) +env = alr_envs.make(env_id, seed=SEED, controller_kwargs={'p_gains': 0.05, 'd_gains': 0.05}).env env.action_space.seed(SEED) # Plot difference between real trajectory and target MP trajectory env.reset() -w = env.action_space.sample() # N(0,1) -visualize(env) -pos, vel = env.mp_rollout(w) +w = env.action_space.sample() +pos, vel = env.get_trajectory(w) -base_shape = env.full_action_space.shape +base_shape = env.env.action_space.shape actual_pos = np.zeros((len(pos), *base_shape)) actual_vel = np.zeros((len(pos), *base_shape)) act = np.zeros((len(pos), *base_shape)) @@ -57,31 +26,30 @@ act = np.zeros((len(pos), *base_shape)) plt.ion() fig = plt.figure() ax = fig.add_subplot(1, 1, 1) -img = ax.imshow(env.env.render("rgb_array")) + +img = ax.imshow(env.env.render(mode="rgb_array")) fig.show() for t, pos_vel in enumerate(zip(pos, vel)): - actions = env.policy.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos) - actions = np.clip(actions, env.full_action_space.low, env.full_action_space.high) + actions = env.tracking_controller.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos) + actions = np.clip(actions, env.env.action_space.low, env.env.action_space.high) _, _, _, _ = env.env.step(actions) if t % 15 == 0: - img.set_data(env.env.render("rgb_array")) + img.set_data(env.env.render(mode="rgb_array")) fig.canvas.draw() fig.canvas.flush_events() act[t, :] = actions # TODO verify for your environment actual_pos[t, :] = env.current_pos - actual_vel[t, :] = 0 # env.current_vel + actual_vel[t, :] = env.current_vel plt.figure(figsize=(15, 5)) plt.subplot(131) plt.title("Position") p1 = plt.plot(actual_pos, c='C0', label="true") -# plt.plot(actual_pos_ball, label="true pos ball") -p2 = plt.plot(pos, c='C1', label="MP") # , label=["MP" if i == 0 else None for i in range(np.prod(base_shape))]) +p2 = plt.plot(pos, c='C1', label="MP") plt.xlabel("Episode steps") -# plt.legend() handles, labels = plt.gca().get_legend_handles_labels() by_label = OrderedDict(zip(labels, handles)) @@ -95,7 +63,6 @@ plt.xlabel("Episode steps") plt.subplot(133) plt.title(f"Actions {np.std(act, axis=0)}") -plt.plot(act, c="C0"), # label=[f"actions" if i == 0 else "" for i in range(np.prod(base_action_shape))]) +plt.plot(act, c="C0"), plt.xlabel("Episode steps") -# plt.legend() plt.show()