2022-09-23 09:40:35 +02:00
|
|
|
import gym
|
|
|
|
import numpy as np
|
|
|
|
from fancy_gym import make
|
|
|
|
|
|
|
|
|
2022-09-26 08:39:54 +02:00
|
|
|
def run_env(env_id, iterations=None, seed=0, render=False):
    """
    Roll out a single episode of `env_id` with randomly sampled actions,
    validating every observation, reward and done flag along the way.

    The env_id is passed to `make`, e.g. `dmc:domain_name-task_name` or,
    for manipulation tasks, `dmc:manipulation-environment_name`.

    Args:
        env_id: Environment id understood by `make`.
        iterations: Maximum number of rollout steps to run. Defaults to the
            env's `spec.max_episode_steps` (or 1 if that is unset/0).
        seed: Seed forwarded to `make`.
        render: If True, call `env.render("human")` after every step.

    Returns:
        Tuple of numpy arrays `(observations, rewards, dones, actions)`.
        `observations` holds one more entry than the other arrays: it starts
        with the `reset()` observation and ends with the observation returned
        by the last `step()`.

    Raises:
        AssertionError: if an observation, reward or done flag fails
            validation, or if `done` is not True after the last step. An
            explicit `iterations` shorter than the episode therefore fails too.
    """
    env: gym.Env = make(env_id, seed=seed)
    rewards = []
    observations = []
    actions = []
    dones = []

    # Close the env even when a validation assertion fails mid-rollout;
    # previously env.close() was only reached on the success path.
    try:
        obs = env.reset()
        verify_observations(obs, env.observation_space, "reset()")

        iterations = iterations or (env.spec.max_episode_steps or 1)

        # Pre-set so the final check raises a clear AssertionError (instead of
        # a NameError) if the loop body never runs, e.g. for negative iterations.
        done = False

        # number of samples (multiple environment steps)
        for _ in range(iterations):
            observations.append(obs)

            ac = env.action_space.sample()
            actions.append(ac)
            obs, reward, done, _info = env.step(ac)

            verify_observations(obs, env.observation_space, "step()")
            verify_reward(reward)
            verify_done(done)

            rewards.append(reward)
            dones.append(done)

            if render:
                env.render("human")

            if done:
                break

        assert done, "Done flag is not True after end of episode."

        # Keep the observation produced by the final step() as well.
        observations.append(obs)
    finally:
        env.close()

    return np.array(observations), np.array(rewards), np.array(dones), np.array(actions)
|
2022-09-23 09:40:35 +02:00
|
|
|
|
|
|
|
|
2022-09-26 08:39:54 +02:00
|
|
|
def run_env_determinism(env_id: str, seed: int):
    """
    Run `env_id` twice with the same seed and assert both rollouts are identical.

    Compares observations, actions, rewards and done flags step by step. It also
    checks that both rollouts have the same length and the same final observation.

    Args:
        env_id: Environment id forwarded to `run_env`.
        seed: Seed used for both rollouts.

    Raises:
        AssertionError: on the first detected mismatch between the two rollouts.
    """
    traj1 = run_env(env_id, seed=seed)
    traj2 = run_env(env_id, seed=seed)

    # zip() silently truncates to its shortest input, so compare lengths
    # explicitly; otherwise rollouts of different length could still "match".
    for name, seq1, seq2 in zip(("Observations", "Rewards", "Dones", "Actions"), traj1, traj2):
        assert len(seq1) == len(seq2), f"{name} have different lengths: {len(seq1)} and {len(seq2)}."

    # Iterate over two trajectories, which should have the same state and action sequence
    for i, time_step in enumerate(zip(*traj1, *traj2)):
        obs1, rwd1, done1, ac1, obs2, rwd2, done2, ac2 = time_step
        assert np.array_equal(obs1, obs2), f"Observations [{i}] {obs1} and {obs2} do not match."
        assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
        assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
        assert np.array_equal(done1, done2), f"Dones [{i}] {done1} and {done2} do not match."

    # Observations carry one extra trailing entry (the final step() result).
    # The zip above stops at the shorter reward/done/action arrays, so it is
    # never reached there and must be compared separately.
    last_idx = len(traj1[0]) - 1
    final1, final2 = traj1[0][-1], traj2[0][-1]
    assert np.array_equal(final1, final2), f"Observations [{last_idx}] {final1} and {final2} do not match."
|
2022-09-23 09:40:35 +02:00
|
|
|
|
|
|
|
|
2022-09-26 09:46:53 +02:00
|
|
|
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
    """Assert that `obs` is a member of `observation_space`.

    Args:
        obs: Observation produced by the environment.
        observation_space: Space the observation is expected to lie in; its
            `contains` method performs the membership test.
        obs_type: Label of the call that produced `obs` (e.g. "reset()" or
            "step()"); only used in the failure message.

    Raises:
        AssertionError: if `observation_space.contains(obs)` is falsy.
    """
    assert observation_space.contains(obs), (
        f"Observation {obs} received from {obs_type} "
        f"not contained in observation space {observation_space}."
    )
|
|
|
|
|
|
|
|
|
2022-09-26 08:39:54 +02:00
|
|
|
def verify_reward(reward):
    """Assert that `reward` is an instance of Python `float` or `int`.

    Because the check is a plain isinstance test, `bool` also passes (it
    subclasses `int`), while types that do not subclass `float`/`int` fail.

    Raises:
        AssertionError: if `reward` is neither a `float` nor an `int`.
    """
    accepted_types = (float, int)
    assert isinstance(reward, accepted_types), f"Returned type {type(reward)} as reward, expected float or int."
|
2022-09-23 09:40:35 +02:00
|
|
|
|
|
|
|
|
2022-09-26 08:39:54 +02:00
|
|
|
def verify_done(done):
    """Assert that the episode-termination flag is a Python `bool`.

    Only genuine `bool` instances are accepted; truthy/falsy values of other
    types (e.g. the integers 0 and 1) are rejected.

    Raises:
        AssertionError: if `done` is not a `bool`.
    """
    is_bool = isinstance(done, bool)
    assert is_bool, f"Returned {done} as done flag, expected bool."
|