Updated examples to new API

commit fbe3ef4a4b (parent 0c7ac838bf)
@@ -26,10 +26,10 @@ def example_dmc(env_id="dmc:fish-swim", seed=1, iterations=1000, render=True):
         ac = env.action_space.sample()
         if render:
             env.render(mode="human")
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward

-        if done:
+        if terminated or truncated:
             print(env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -102,10 +102,10 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward

-        if done:
+        if terminated or truncated:
             print(base_env_id, rewards)
             rewards = 0
             obs = env.reset()
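For orientation, this is the shape of the migrated loop as a standalone script. A minimal sketch, assuming `fancy_gym.make` follows the full Gymnasium convention (five-tuple `step`, two-tuple `reset`, as the updated `test/utils.py` further down assumes); the env id is taken from the hunk header:

```python
import fancy_gym

env = fancy_gym.make("dmc:fish-swim", seed=1)
obs, _ = env.reset()  # reset() returns (obs, info) under the new API
rewards = 0
for _ in range(1000):
    ac = env.action_space.sample()
    # step() now returns five values: the old `done` is split into two flags
    obs, reward, terminated, truncated, info = env.step(ac)
    rewards += reward
    if terminated or truncated:  # terminated: task end; truncated: time-limit cut-off
        print(rewards)
        rewards = 0
        obs, _ = env.reset()
env.close()
```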
@@ -1,6 +1,6 @@
 from collections import defaultdict

-import gym
+import gymnasium as gym
 import numpy as np

 import fancy_gym
@@ -29,13 +29,13 @@ def example_general(env_id="Pendulum-v1", seed=1, iterations=1000, render=True):

     # number of environment steps
     for i in range(iterations):
-        obs, reward, done, info = env.step(env.action_space.sample())
+        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
         rewards += reward

         if render:
             env.render()

-        if done:
+        if terminated or truncated:
             print(rewards)
             rewards = 0
             obs = env.reset()
@@ -69,12 +69,15 @@ def example_async(env_id="HoleReacher-v0", n_cpu=4, seed=int('533D', 16), n_samp
     # this would generate more samples than requested if n_samples % num_envs != 0
     repeat = int(np.ceil(n_samples / env.num_envs))
     for i in range(repeat):
-        obs, reward, done, info = env.step(env.action_space.sample())
+        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
         buffer['obs'].append(obs)
         buffer['reward'].append(reward)
-        buffer['done'].append(done)
+        buffer['terminated'].append(terminated)
+        buffer['truncated'].append(truncated)
         buffer['info'].append(info)
         rewards += reward

+        done = np.logical_or(terminated, truncated)  # element-wise; plain `or` fails on per-env arrays
         if np.any(done):
             print(f"Reward at iteration {i}: {rewards[done]}")
             rewards[done] = 0
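A note on the vectorized loop above: with several sub-environments, `terminated` and `truncated` come back as boolean arrays (one entry per worker), so Python's scalar `or` raises on them; the element-wise `np.logical_or` is the safe combination. A small self-contained illustration:

```python
import numpy as np

# Flags as a vectorized step would return them for four sub-environments
terminated = np.array([False, True, False, False])
truncated = np.array([False, False, True, False])

# `terminated or truncated` would raise "truth value of an array is ambiguous";
# combine the flags element-wise instead:
done = np.logical_or(terminated, truncated)

rewards = np.array([1.0, 2.0, 3.0, 4.0])
print(rewards[done])  # running returns of the finished sub-envs: [2. 3.]
rewards[done] = 0     # reset those returns, as in the example above
```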
@@ -29,9 +29,9 @@ def example_dmc(env_id="fish-swim", seed=1, iterations=1000, render=True):
             # THIS NEEDS TO BE SET TO FALSE FOR NOW, BECAUSE THE INTERFACE FOR RENDERING IS DIFFERENT TO BASIC GYM
             # TODO: Remove this, when Metaworld fixes its interface.
             env.render(False)
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward
-        if done:
+        if terminated or truncated:
             print(env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -103,10 +103,10 @@ def example_custom_dmc_and_mp(seed=1, iterations=1, render=True):
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward

-        if done:
+        if terminated or truncated:
             print(base_env_id, rewards)
             rewards = 0
             obs = env.reset()
@@ -131,4 +131,3 @@ if __name__ == '__main__':
     #
     # # Custom MetaWorld task
     example_custom_dmc_and_mp(seed=10, iterations=1, render=render)
-
@@ -41,11 +41,11 @@ def example_mp(env_name="HoleReacherProMP-v0", seed=1, iterations=1, render=True
         # This executes a full trajectory and gives back the context (obs) of the last step in the trajectory, or the
         # full observation space of the last step, if replanning/sub-trajectory learning is used. The 'reward' is equal
         # to the return of a trajectory. Default is the sum over the step-wise rewards.
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         # Aggregated returns
         returns += reward

-        if done:
+        if terminated or truncated:
             print(reward)
             obs = env.reset()

@@ -79,10 +79,10 @@ def example_custom_mp(env_name="Reacher5dProMP-v0", seed=1, iterations=1, render
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         returns += reward

-        if done:
+        if terminated or truncated:
             print(i, reward)
             obs = env.reset()

@@ -145,10 +145,10 @@ def example_fully_custom_mp(seed=1, iterations=1, render=True):
     # number of samples/full trajectories (multiple environment steps)
     for i in range(iterations):
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         rewards += reward

-        if done:
+        if terminated or truncated:
             print(rewards)
             rewards = 0
             obs = env.reset()
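The comment in the first hunk above is the key contract of these movement-primitive environments: a single `step()` consumes one parameter vector, rolls out the whole trajectory internally, and reports the trajectory's return as `reward`. A minimal sketch of that usage (ids taken from the hunk headers; reset handling follows the updated test utilities and is an assumption for these envs):

```python
import fancy_gym

env = fancy_gym.make("HoleReacherProMP-v0", seed=1)
obs, _ = env.reset()
for i in range(5):
    # one "action" here is a full set of movement-primitive parameters
    ac = env.action_space.sample()
    # a single step executes the entire trajectory; `reward` is its return
    obs, reward, terminated, truncated, info = env.step(ac)
    print(i, reward)
    if terminated or truncated:
        obs, _ = env.reset()
env.close()
```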
@@ -24,10 +24,10 @@ def example_mp(env_name, seed=1, render=True):
         else:
             env.render(mode=None)
         ac = env.action_space.sample()
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)
         returns += reward

-        if done:
+        if terminated or truncated:
             print(returns)
             obs = env.reset()

@@ -34,7 +34,7 @@ fig.show()
 for t, pos_vel in enumerate(zip(pos, vel)):
     actions = env.tracking_controller.get_action(pos_vel[0], pos_vel[1], env.current_vel, env.current_pos)
     actions = np.clip(actions, env.env.action_space.low, env.env.action_space.high)
-    _, _, _, _ = env.env.step(actions)
+    env.env.step(actions)
     if t % 15 == 0:
         img.set_data(env.env.render(mode="rgb_array"))
         fig.canvas.draw()
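Dropping the unpacking is the minimal fix in this hunk: `_, _, _, _ = env.env.step(...)` hard-codes the old four-value arity and raises once `step` returns five values, whereas ignoring the return value works with either API. A tiny illustration of the failure mode (the stub stands in for a Gymnasium-style `step`):

```python
def step_stub(action):
    # stand-in for a Gymnasium-style env.step() returning five values
    return None, 0.0, False, False, {}

try:
    _, _, _, _ = step_stub(0)  # old four-value unpacking
except ValueError as err:
    print(err)  # too many values to unpack (expected 4)

step_stub(0)  # ignoring the return value is robust to the arity change
```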
setup.py (8 lines changed)
@@ -7,8 +7,10 @@ extras = {
     "dmc": ["dm_control>=1.0.1"],
     "metaworld": ["metaworld @ git+https://github.com/rlworkgroup/metaworld.git@master#egg=metaworld",
                   'mujoco-py<2.2,>=2.1',
-                  'scipy'
+                  'scipy',
+                  'gym>=0.15.4',
                   ],
+    "mujoco": ["gymnasium[mujoco]"],
 }

 # All dependencies
@@ -18,7 +20,7 @@ extras["all"] = list(set(itertools.chain.from_iterable(map(lambda group: extras[
 setup(
     author='Fabian Otto, Onur Celik',
     name='fancy_gym',
-    version='0.2',
+    version='0.3',
     classifiers=[
         # Python 3.7 is minimally supported
         "Programming Language :: Python :: 3",
@@ -29,7 +31,7 @@ setup(
     ],
     extras_require=extras,
     install_requires=[
-        'gym[mujoco]<0.25.0,>=0.24.0',
+        'gymnasium',
         'mp_pytorch @ git+https://github.com/ALRhub/MP_PyTorch.git@main'
     ],
     packages=[package for package in find_packages() if package.startswith("fancy_gym")],
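The truncated hunk header shows how `extras['all']` is built: the per-feature requirement lists are chained together and deduplicated. A simplified standalone sketch of that pattern (the `extras` content is abridged from the diff, and the exact `map(lambda group: ...)` form is elided in the source):

```python
import itertools

extras = {
    "dmc": ["dm_control>=1.0.1"],
    "mujoco": ["gymnasium[mujoco]"],
}

# Union of all per-feature requirements, deduplicated via set()
extras["all"] = list(set(itertools.chain.from_iterable(extras.values())))
print(extras["all"])  # both requirement strings, in arbitrary order
```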
@@ -1,39 +1,43 @@
 from itertools import chain
+from typing import Callable

+import gymnasium as gym
 import pytest
 from dm_control import suite, manipulation

 import fancy_gym
 from test.utils import run_env, run_env_determinism

-SUITE_IDS = [f'dmc:{env}-{task}' for env, task in suite.ALL_TASKS if env != "lqr"]
-MANIPULATION_IDS = [f'dmc:manipulation-{task}' for task in manipulation.ALL if task.endswith('_features')]
+# SUITE_IDS = [f'dmc:{env}-{task}' for env, task in suite.ALL_TASKS if env != "lqr"]
+# MANIPULATION_IDS = [f'dmc:manipulation-{task}' for task in manipulation.ALL if task.endswith('_features')]
+DM_CONTROL_IDS = [spec.id for spec in gym.envs.registry.values() if
+                  not isinstance(spec.entry_point, Callable) and spec.entry_point.startswith('dm_control/')]
 DMC_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 SEED = 1


-@pytest.mark.parametrize('env_id', SUITE_IDS)
-def test_step_suite_functionality(env_id: str):
+@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
+def test_step_dm_control_functionality(env_id: str):
     """Tests that suite step environments run without errors using random actions."""
     run_env(env_id)


-@pytest.mark.parametrize('env_id', SUITE_IDS)
-def test_step_suite_determinism(env_id: str):
+@pytest.mark.parametrize('env_id', DM_CONTROL_IDS)
+def test_step_dm_control_determinism(env_id: str):
     """Tests that for step environments identical seeds produce identical trajectories."""
     run_env_determinism(env_id, SEED)


-@pytest.mark.parametrize('env_id', MANIPULATION_IDS)
-def test_step_manipulation_functionality(env_id: str):
-    """Tests that manipulation step environments run without errors using random actions."""
-    run_env(env_id)
-
-
-@pytest.mark.parametrize('env_id', MANIPULATION_IDS)
-def test_step_manipulation_determinism(env_id: str):
-    """Tests that for step environments identical seeds produce identical trajectories."""
-    run_env_determinism(env_id, SEED)
+# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
+# def test_step_manipulation_functionality(env_id: str):
+#     """Tests that manipulation step environments run without errors using random actions."""
+#     run_env(env_id)
+#
+#
+# @pytest.mark.parametrize('env_id', MANIPULATION_IDS)
+# def test_step_manipulation_determinism(env_id: str):
+#     """Tests that for step environments identical seeds produce identical trajectories."""
+#     run_env_determinism(env_id, SEED)


 @pytest.mark.parametrize('env_id', DMC_MP_IDS)
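Two details carry the `DM_CONTROL_IDS` rewrite: Gymnasium's registry maps ids to `EnvSpec`s whose `entry_point` may be either a `"module:attr"` string or a factory callable (hence the `isinstance` guard before calling `startswith`), and the dm_control tasks appear in that registry under a `dm_control/` id prefix (registered by the Shimmy bindings, if I read the setup correctly). A minimal sketch of the filter:

```python
from typing import Callable

import gymnasium as gym

dm_control_ids = [
    spec.id
    for spec in gym.envs.registry.values()
    # entry_point can be a callable; .startswith() would fail on those
    if not isinstance(spec.entry_point, Callable)
    and spec.entry_point.startswith('dm_control/')
]
print(len(dm_control_ids), dm_control_ids[:3])
```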
@@ -1,12 +1,14 @@
 import itertools
+from typing import Callable

 import fancy_gym
-import gym
+import gymnasium as gym
 import pytest

 from test.utils import run_env, run_env_determinism

 CUSTOM_IDS = [id for id, spec in gym.envs.registry.items() if
+              not isinstance(spec.entry_point, Callable) and
               "fancy_gym" in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 CUSTOM_MP_IDS = itertools.chain(*fancy_gym.ALL_FANCY_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 SEED = 1
@@ -1,12 +1,12 @@
 from itertools import chain

-import gym
+import gymnasium as gym
 import pytest

 import fancy_gym
 from test.utils import run_env, run_env_determinism

-GYM_IDS = [spec.id for spec in gym.envs.registry.all() if
+GYM_IDS = [spec.id for spec in gym.envs.registry.values() if
            "fancy_gym" not in spec.entry_point and 'make_bb_env_helper' not in spec.entry_point]
 GYM_MP_IDS = chain(*fancy_gym.ALL_DMC_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())
 SEED = 1
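The `GYM_IDS` change reflects the registry API difference: old `gym` exposed an `EnvRegistry` object with an `.all()` method, while `gymnasium` keeps a plain dict from env id to `EnvSpec`, iterated with the usual dict methods. Side by side:

```python
import gymnasium as gym

# gymnasium: the registry is a plain dict, id -> EnvSpec
assert isinstance(gym.envs.registry, dict)

# old gym:       all_specs = gym.envs.registry.all()
# gymnasium:
all_specs = gym.envs.registry.values()
all_ids = list(gym.envs.registry.keys())
print(all_ids[:3])
```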
@@ -1,4 +1,4 @@
-import gym
+import gymnasium as gym
 import numpy as np
 from fancy_gym import make

@@ -15,16 +15,16 @@ def run_env(env_id, iterations=None, seed=0, render=False):
         seed: random seeding
         render: Render the episode

-    Returns: observations, rewards, dones, actions
+    Returns: observations, rewards, terminations, truncations, actions

     """
     env: gym.Env = make(env_id, seed=seed)
     rewards = []
     observations = []
     actions = []
-    dones = []
-    obs = env.reset()
-    print(obs.dtype)
+    terminations = []
+    truncations = []
+    obs, _ = env.reset()
     verify_observations(obs, env.observation_space, "reset()")

     iterations = iterations or (env.spec.max_episode_steps or 1)
@@ -36,26 +36,28 @@ def run_env(env_id, iterations=None, seed=0, render=False):
         ac = env.action_space.sample()
         actions.append(ac)
         # ac = np.random.uniform(env.action_space.low, env.action_space.high, env.action_space.shape)
-        obs, reward, done, info = env.step(ac)
+        obs, reward, terminated, truncated, info = env.step(ac)

         verify_observations(obs, env.observation_space, "step()")
         verify_reward(reward)
-        verify_done(done)
+        verify_done(terminated)
+        verify_done(truncated)

         rewards.append(reward)
-        dones.append(done)
+        terminations.append(terminated)
+        truncations.append(truncated)

         if render:
             env.render("human")

-        if done:
+        if terminated or truncated:
             break

-    assert done, "Done flag is not True after end of episode."
+    assert terminated or truncated, "Termination or truncation flag is not True after end of episode."
     observations.append(obs)
     env.close()
     del env
-    return np.array(observations), np.array(rewards), np.array(dones), np.array(actions)
+    return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)

@@ -63,11 +65,12 @@ def run_env_determinism(env_id: str, seed: int):
     traj2 = run_env(env_id, seed=seed)
     # Iterate over two trajectories, which should have the same state and action sequence
     for i, time_step in enumerate(zip(*traj1, *traj2)):
-        obs1, rwd1, done1, ac1, obs2, rwd2, done2, ac2 = time_step
+        obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
         assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
         assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
         assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
-        assert np.array_equal(done1, done2), f"Dones [{i}] {done1} and {done2} do not match."
+        assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
+        assert np.array_equal(trunc1, trunc2), f"Truncateds [{i}] {trunc1} and {trunc2} do not match."


 def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
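With the widened tuple, callers of `run_env` unpack five arrays instead of four, and the determinism check compares ten values per time step. A short usage sketch (the env id is illustrative):

```python
from test.utils import run_env, run_env_determinism

# run_env now returns five arrays: obs, rewards, terminations, truncations, actions
observations, rewards, terminations, truncations, actions = run_env("Reacher5d-v0", seed=0)
print(observations.shape, rewards.shape, actions.shape)

# identical seeds must reproduce the exact same trajectory
run_env_determinism("Reacher5d-v0", seed=0)
```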