added more fine-grained test cases
This commit is contained in:
parent
c443b06fef
commit
5e33259cb1
@ -12,7 +12,7 @@ MANIPULATION_SPECS = [f'manipulation-{task}' for task in manipulation.ALL if tas
|
|||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
class TestEnvironments(unittest.TestCase):
|
class TestStepDMCEnvironments(unittest.TestCase):
|
||||||
|
|
||||||
def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
|
def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
|
||||||
"""
|
"""
|
||||||
|
@ -10,7 +10,7 @@ ALL_SPECS = list(spec for spec in gym.envs.registry.all() if "alr_envs" in spec.
|
|||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
class TestEnvironments(unittest.TestCase):
|
class TestMPEnvironments(unittest.TestCase):
|
||||||
|
|
||||||
def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
|
def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
|
||||||
"""
|
"""
|
||||||
@ -68,6 +68,18 @@ class TestEnvironments(unittest.TestCase):
|
|||||||
del env
|
del env
|
||||||
return np.array(observations), np.array(rewards), np.array(dones)
|
return np.array(observations), np.array(rewards), np.array(dones)
|
||||||
|
|
||||||
|
def _run_env_determinism(self, ids):
|
||||||
|
seed = 0
|
||||||
|
for env_id in ids:
|
||||||
|
with self.subTest(msg=env_id):
|
||||||
|
traj1 = self._run_env(env_id, seed=seed)
|
||||||
|
traj2 = self._run_env(env_id, seed=seed)
|
||||||
|
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
||||||
|
obs1, rwd1, done1, obs2, rwd2, done2 = time_step
|
||||||
|
self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
|
||||||
|
self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
|
||||||
|
self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")
|
||||||
|
|
||||||
def _verify_observations(self, obs, observation_space, obs_type="reset()"):
|
def _verify_observations(self, obs, observation_space, obs_type="reset()"):
|
||||||
self.assertTrue(observation_space.contains(obs),
|
self.assertTrue(observation_space.contains(obs),
|
||||||
f"Observation {obs} received from {obs_type} "
|
f"Observation {obs} received from {obs_type} "
|
||||||
@ -79,31 +91,81 @@ class TestEnvironments(unittest.TestCase):
|
|||||||
def _verify_done(self, done):
|
def _verify_done(self, done):
|
||||||
self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")
|
self.assertIsInstance(done, bool, f"Returned {done} as done flag, expected bool.")
|
||||||
|
|
||||||
def test_environment_functionality(self):
|
def test_alr_environment_functionality(self):
|
||||||
"""Tests that environments runs without errors using random actions."""
|
"""Tests that environments runs without errors using random actions for ALR MP envs."""
|
||||||
for spec in ALL_SPECS:
|
with self.subTest(msg="DMP"):
|
||||||
with self.subTest(msg=spec.id):
|
for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['DMP']:
|
||||||
self._run_env(spec.id)
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
def test_environment_determinism(self):
|
with self.subTest(msg="DetPMP"):
|
||||||
"""Tests that identical seeds produce identical trajectories."""
|
for env_id in alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
|
||||||
seed = 0
|
with self.subTest(msg=env_id):
|
||||||
# Iterate over two trajectories, which should have the same state and action sequence
|
self._run_env(env_id)
|
||||||
for spec in ALL_SPECS:
|
|
||||||
with self.subTest(msg=spec.id):
|
|
||||||
traj1 = self._run_env(spec.id, seed=seed)
|
|
||||||
traj2 = self._run_env(spec.id, seed=seed)
|
|
||||||
for i, time_step in enumerate(zip(*traj1, *traj2)):
|
|
||||||
obs1, rwd1, done1, obs2, rwd2, done2 = time_step
|
|
||||||
self.assertTrue(np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match.")
|
|
||||||
self.assertEqual(rwd1, rwd2, f"Rewards [{i}] {rwd1} and {rwd2} do not match.")
|
|
||||||
self.assertEqual(done1, done2, f"Dones [{i}] {done1} and {done2} do not match.")
|
|
||||||
|
|
||||||
def test_environment_functionality_meta(self):
|
def test_openai_environment_functionality(self):
|
||||||
"""Tests that environments runs without errors using random actions."""
|
"""Tests that environments runs without errors using random actions for OpenAI gym MP envs."""
|
||||||
for id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
|
with self.subTest(msg="DMP"):
|
||||||
with self.subTest(msg=id):
|
for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['DMP']:
|
||||||
self._run_env(id)
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
for env_id in alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
|
||||||
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
|
def test_dmc_environment_functionality(self):
|
||||||
|
"""Tests that environments runs without errors using random actions for DMC MP envs."""
|
||||||
|
with self.subTest(msg="DMP"):
|
||||||
|
for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['DMP']:
|
||||||
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
for env_id in alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
|
||||||
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
|
def test_metaworld_environment_functionality(self):
|
||||||
|
"""Tests that environments runs without errors using random actions for Metaworld MP envs."""
|
||||||
|
with self.subTest(msg="DMP"):
|
||||||
|
for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DMP']:
|
||||||
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
for env_id in alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS['DetPMP']:
|
||||||
|
with self.subTest(msg=env_id):
|
||||||
|
self._run_env(env_id)
|
||||||
|
|
||||||
|
def test_alr_environment_determinism(self):
|
||||||
|
"""Tests that identical seeds produce identical trajectories for ALR MP Envs."""
|
||||||
|
with self.subTest(msg="DMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_ALR_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
|
||||||
|
|
||||||
|
def test_openai_environment_determinism(self):
|
||||||
|
"""Tests that identical seeds produce identical trajectories for OpenAI gym MP Envs."""
|
||||||
|
with self.subTest(msg="DMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_GYM_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
|
||||||
|
|
||||||
|
def test_dmc_environment_determinism(self):
|
||||||
|
"""Tests that identical seeds produce identical trajectories for DMC MP Envs."""
|
||||||
|
with self.subTest(msg="DMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_DEEPMIND_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
|
||||||
|
|
||||||
|
def test_metaworld_environment_determinism(self):
|
||||||
|
"""Tests that identical seeds produce identical trajectories for Metaworld MP Envs."""
|
||||||
|
with self.subTest(msg="DMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DMP"])
|
||||||
|
with self.subTest(msg="DetPMP"):
|
||||||
|
self._run_env_determinism(alr_envs.ALL_METAWORLD_MOTION_PRIMITIVE_ENVIRONMENTS["DetPMP"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -10,7 +10,7 @@ ALL_ENVS = [env.split("-goal-observable")[0] for env, _ in ALL_V2_ENVIRONMENTS_G
|
|||||||
SEED = 1
|
SEED = 1
|
||||||
|
|
||||||
|
|
||||||
class TestEnvironments(unittest.TestCase):
|
class TestStepMetaWorlEnvironments(unittest.TestCase):
|
||||||
|
|
||||||
def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
|
def _run_env(self, env_id, iterations=None, seed=SEED, render=False):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user