delete debugging codes
This commit is contained in:
parent
ca8787f449
commit
811c5df3d1
@ -1,31 +0,0 @@
|
|||||||
import fancy_gym
|
|
||||||
import numpy as np
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
# This is the code that I am using to plot the data
|
|
||||||
|
|
||||||
|
|
||||||
def plot_trajs(desired_traj, actual_traj, dim):
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
ax.plot(desired_traj[:, dim], label='desired')
|
|
||||||
ax.plot(actual_traj[:, dim], label='actual')
|
|
||||||
ax.legend()
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
def compare_desired_and_actual(env_id: str = "TableTennis4DProMP-v0"):
|
|
||||||
env = fancy_gym.make(env_id, seed=0)
|
|
||||||
env.traj_gen.basis_gn.show_basis(plot=True)
|
|
||||||
env.reset()
|
|
||||||
for _ in range(1):
|
|
||||||
env.render(mode=None)
|
|
||||||
action = env.action_space.sample()
|
|
||||||
obs, reward, done, info = env.step(action)
|
|
||||||
for i in range(1):
|
|
||||||
plot_trajs(info['desired_pos_traj'], info['pos_traj'], i)
|
|
||||||
# plot_trajs(info['desired_vel_traj'], info['vel_traj'], i)
|
|
||||||
if done:
|
|
||||||
env.reset()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
compare_desired_and_actual(env_id='TableTennis4DProMP-v0')
|
|
@ -1,22 +0,0 @@
|
|||||||
import fancy_gym
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
env_1 = fancy_gym.make("TableTennis4DProDMP-v0", seed=0)
|
|
||||||
env_2 = fancy_gym.make("TableTennis4DProDMP-v0", seed=0)
|
|
||||||
|
|
||||||
obs_1 = env_1.reset()
|
|
||||||
obs_2 = env_2.reset()
|
|
||||||
assert np.all(obs_1 == obs_2), "The observations should be the same"
|
|
||||||
for i in range(100000):
|
|
||||||
action = env_1.action_space.sample()
|
|
||||||
obs_1, reward_1, done_1, info_1 = env_1.step(action)
|
|
||||||
obs_2, reward_2, done_2, info_2 = env_2.step(action)
|
|
||||||
assert np.all(obs_1 == obs_2), "The observations should be the same"
|
|
||||||
assert np.all(reward_1 == reward_2), "The rewards should be the same"
|
|
||||||
assert np.all(done_1 == done_2), "The done flags should be the same"
|
|
||||||
for key in info_1:
|
|
||||||
assert np.all(info_1[key] == info_2[key]), f"The info fields: {key} should be the same"
|
|
||||||
if done_1 and done_2:
|
|
||||||
obs_1 = env_1.reset()
|
|
||||||
obs_2 = env_2.reset()
|
|
||||||
assert np.all(obs_1 == obs_2), "The observations should be the same"
|
|
Loading…
Reference in New Issue
Block a user