fancy_gym/test/utils.py

103 lines
3.7 KiB
Python
Raw Normal View History

from typing import List, Optional, Type
2023-01-12 17:21:56 +01:00
import gymnasium as gym
2022-09-23 09:40:35 +02:00
import numpy as np
from fancy_gym import make
def run_env(env_id: str, iterations: Optional[int] = None, seed: int = 0,
            wrappers: Optional[List[Type[gym.Wrapper]]] = None, render: bool = False):
    """
    Roll out random actions in an environment created by `make` and sanity-check every transition.

    The environment is wrapped with `wrappers` (in list order), reset once, and then stepped with actions
    drawn from `env.action_space.sample()` until `iterations` steps have run or the episode terminates or
    truncates. Every observation, reward and done flag is validated with `verify_observations`,
    `verify_reward` and `verify_done`. The environment is always closed, even when a check fails.

    Args:
        env_id: Environment id passed to `make`, e.g. `dmc:domain_name-task_name` or
            `dmc:manipulation-environment_name`
        iterations: Maximum number of rollout steps. A falsy value (None or 0) falls back to
            `env.spec.max_episode_steps`, or 1 if that is not set either.
        seed: Seed passed to `make`.
        wrappers: Wrapper classes applied to the environment, in order. None means no wrappers.
        render: Call `env.render` after every step.

    Returns: observations, rewards, terminations, truncations, actions - each as a numpy array.
        `observations` has one more entry than the others because it also holds the final observation.

    Raises:
        AssertionError: if a verification check fails, or if the rollout ended without termination or
            truncation for an environment that has no `replanning_schedule` attribute.
    """
    env: gym.Env = make(env_id, seed=seed)
    try:
        for wrapper in wrappers or []:
            env = wrapper(env)

        rewards = []
        observations = []
        actions = []
        terminations = []
        truncations = []

        obs, _ = env.reset()
        verify_observations(obs, env.observation_space, "reset()")

        iterations = iterations or (env.spec.max_episode_steps or 1)

        for i in range(iterations):
            # Record the observation the action is taken from, so observations[t] pairs with actions[t].
            observations.append(obs)
            ac = env.action_space.sample()
            actions.append(ac)
            obs, reward, terminated, truncated, _ = env.step(ac)

            verify_observations(obs, env.observation_space, "step()")
            verify_reward(reward)
            verify_done(terminated)
            verify_done(truncated)

            rewards.append(reward)
            terminations.append(terminated)
            truncations.append(truncated)

            if render:
                # NOTE(review): gymnasium's `Env.render()` takes no mode argument (the mode is chosen at
                # construction); this call presumably relies on the env under test accepting one - confirm.
                env.render("human")

            if terminated or truncated:
                break

        # Environments exposing a `replanning_schedule` attribute are exempt from this end-of-episode check.
        if not hasattr(env, "replanning_schedule"):
            assert terminated or truncated, f"Termination or truncation flag is not True after {i + 1} iterations."

        # Keep the final observation, so a rollout of n steps yields n + 1 observations.
        observations.append(obs)
    finally:
        # Release the environment even when a verification assertion fails mid-rollout.
        env.close()

    return np.array(observations), np.array(rewards), np.array(terminations), np.array(truncations), np.array(actions)
2022-09-23 09:40:35 +02:00
def run_env_determinism(env_id: str, seed: int, iterations: Optional[int] = None,
                        wrappers: Optional[List[Type[gym.Wrapper]]] = None):
    """
    Assert that two rollouts of the same environment with the same seed produce identical trajectories.

    Both rollouts come from `run_env` with identical arguments. The two trajectories must have the same
    length; observations are compared with `np.allclose`, while actions, rewards, termination and
    truncation flags must match exactly. The final observation is compared as well.

    Args:
        env_id: Environment id forwarded to `run_env`.
        seed: Seed used for both rollouts.
        iterations: Number of rollout steps forwarded to `run_env`.
        wrappers: Wrapper classes forwarded to `run_env`. None means no wrappers.

    Raises:
        AssertionError: on the first length or value mismatch between the two trajectories.
    """
    wrappers = [] if wrappers is None else wrappers
    traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
    traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)

    # zip() below stops at the shorter trajectory, so a length mismatch has to be caught explicitly.
    names = ("Observations", "Rewards", "Terminateds", "Truncateds", "Actions")
    for name, part1, part2 in zip(names, traj1, traj2):
        assert len(part1) == len(part2), f"{name} have different lengths: {len(part1)} and {len(part2)}."

    # Iterate over two trajectories, which should have the same state and action sequence
    for i, time_step in enumerate(zip(*traj1, *traj2)):
        obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
        assert np.allclose(
            obs1, obs2), f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match."
        assert np.array_equal(
            ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
        assert np.array_equal(
            rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
        assert np.array_equal(
            term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
        assert np.array_equal(
            trunc1, trunc2), f"Truncateds [{i}] {trunc1} and {trunc2} do not match."

    # Observations hold one extra (final) entry that the per-step zip above never reaches.
    last = len(traj1[0]) - 1
    if last >= 0:
        obs1, obs2 = traj1[0][last], traj2[0][last]
        assert np.allclose(
            obs1, obs2), f"Observations [{last}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match."
2022-09-23 09:40:35 +02:00
2022-09-26 09:46:53 +02:00
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
    """
    Assert that an observation lies inside the given observation space.

    Args:
        obs: Observation returned by the environment.
        observation_space: Space the observation is expected to belong to; its `contains` method decides.
        obs_type: Label of the call that produced `obs`; used only in the failure message.

    Raises:
        AssertionError: if `observation_space.contains(obs)` is falsy.
    """
    assert observation_space.contains(obs), (
        f"Observation {obs} received from {obs_type} "
        f"not contained in observation space {observation_space}."
    )
2022-09-26 08:39:54 +02:00
def verify_reward(reward):
    """
    Assert that a reward is a plain Python number.

    Args:
        reward: Reward value returned by the environment's `step()`.

    Raises:
        AssertionError: if `reward` is neither a `float` nor an `int` instance (`bool` counts as `int`).
    """
    numeric_types = (float, int)
    assert isinstance(reward, numeric_types), \
        f"Returned type {type(reward)} as reward, expected float or int."
2022-09-23 09:40:35 +02:00
2022-09-26 08:39:54 +02:00
def verify_done(done):
    """
    Assert that a termination or truncation flag is a `bool`.

    Args:
        done: Flag returned by the environment, e.g. `terminated` or `truncated` from `step()`.

    Raises:
        AssertionError: if `done` is not an instance of `bool`.
    """
    expected_type = bool
    assert isinstance(done, expected_type), \
        f"Returned {done} as done flag, expected bool."