2023-01-17 08:27:29 +01:00
from typing import List, Optional, Type
2023-01-12 17:21:56 +01:00
import gymnasium as gym
2022-09-23 09:40:35 +02:00
import numpy as np
2023-07-23 11:13:15 +02:00
from gymnasium import make
2022-09-23 09:40:35 +02:00
2023-01-17 08:27:29 +01:00
def run_env(env_id: str, iterations: Optional[int] = None, seed: int = 0,
            wrappers: Optional[List[Type[gym.Wrapper]]] = None, render: bool = False):
    """
    Roll out an environment with randomly sampled actions, validating every step.

    The env is created via ``make(env_id, seed=seed)``. For DMC-based envs the id has the form
    ``dmc:domain_name-task_name`` (or ``dmc:manipulation-environment_name`` for manipulation tasks).

    Args:
        env_id: environment id passed to ``make``.
        iterations: maximum number of rollout steps. Falls back to ``env.spec.max_episode_steps``
            (or 1) when ``None`` or 0.
        seed: seed forwarded to both ``make`` and ``env.reset``.
        wrappers: wrapper classes applied in order, each as ``w(env)``. Defaults to none.
        render: if True, the env is created with ``render_mode="human"`` and ``env.render()``
            is called after every step.

    Returns:
        A tuple of numpy arrays ``(observations, rewards, terminations, truncations, actions)``.
        ``observations`` has one more entry than the others because it also contains the
        observation returned by the final step.

    Raises:
        AssertionError: if an observation, reward or flag fails validation. Also raised if the
            episode neither terminated nor truncated within ``iterations`` steps; this check is
            skipped when the env has a ``replanning_schedule`` attribute.
    """
    # gymnasium's ``Env.render()`` takes no mode argument: the mode is fixed at construction.
    # Only pass ``render_mode`` when rendering, so envs that do not accept it are unaffected.
    make_kwargs = {"render_mode": "human"} if render else {}
    env: gym.Env = make(env_id, seed=seed, **make_kwargs)
    try:
        for w in wrappers or []:
            env = w(env)

        rewards = []
        observations = []
        actions = []
        terminations = []
        truncations = []

        obs, _ = env.reset(seed=seed)
        verify_observations(obs, env.observation_space, "reset()")

        iterations = iterations or (env.spec.max_episode_steps or 1)

        # Initialised so the post-loop check is well-defined even if the loop body never runs.
        terminated = truncated = False
        # number of samples (multiple environment steps)
        for _ in range(iterations):
            observations.append(obs)
            ac = env.action_space.sample()
            actions.append(ac)
            obs, reward, terminated, truncated, _info = env.step(ac)

            verify_observations(obs, env.observation_space, "step()")
            verify_reward(reward)
            verify_done(terminated)
            verify_done(truncated)

            rewards.append(reward)
            terminations.append(terminated)
            truncations.append(truncated)

            if render:
                env.render()

            if terminated or truncated:
                break

        # Envs exposing ``replanning_schedule`` are exempt from the "episode must end" check.
        if not hasattr(env, "replanning_schedule"):
            assert terminated or truncated, \
                f"Termination or truncation flag is not True after {len(actions)} iterations."

        # Keep the observation returned by the last step as well.
        observations.append(obs)
    finally:
        # Release the env (e.g. viewers / simulators) even when a validation assert fails.
        env.close()

    return (np.array(observations), np.array(rewards), np.array(terminations),
            np.array(truncations), np.array(actions))
2022-09-23 09:40:35 +02:00
2023-01-17 08:27:29 +01:00
def run_env_determinism(env_id: str, seed: int, iterations: Optional[int] = None,
                        wrappers: Optional[List[Type[gym.Wrapper]]] = None):
    """
    Assert that two rollouts of ``env_id`` with the same seed produce identical trajectories.

    Both rollouts come from ``run_env`` with identical arguments. The per-step comparisons are:
    observations (``np.allclose``), and actions, rewards, termination flags and truncation flags
    (``np.array_equal``).

    Args:
        env_id: environment id passed to ``run_env``.
        seed: seed used for both rollouts.
        iterations: maximum number of rollout steps, forwarded to ``run_env``.
        wrappers: wrapper classes forwarded to ``run_env``. Defaults to none.

    Raises:
        AssertionError: if the two trajectories differ in length or in any compared value.
    """
    # Normalise to a fresh list instead of using a shared mutable default argument.
    wrappers = [] if wrappers is None else list(wrappers)
    traj1 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)
    traj2 = run_env(env_id, iterations=iterations, seed=seed, wrappers=wrappers)

    # ``zip`` silently stops at the shortest input. Check lengths first, so an episode that
    # ends earlier in one rollout is reported instead of being ignored.
    names = ("Observations", "Rewards", "Terminateds", "Truncateds", "Actions")
    for name, seq1, seq2 in zip(names, traj1, traj2):
        assert len(seq1) == len(seq2), \
            f"{name} have different lengths: {len(seq1)} and {len(seq2)}."

    def obs_mismatch(i, obs1, obs2):
        # Only called to build the message when the assertion fails.
        diff = np.abs(obs1 - obs2)
        return (f"Observations [{i}] {obs1} ({obs1.shape}) and {obs2} ({obs2.shape}) do not match: "
                f"Biggest difference is {diff.max()} at index {diff.argmax()}.")

    # Iterate over two trajectories, which should have the same state and action sequence
    for i, time_step in enumerate(zip(*traj1, *traj2)):
        obs1, rwd1, term1, trunc1, ac1, obs2, rwd2, term2, trunc2, ac2 = time_step
        assert np.allclose(obs1, obs2), obs_mismatch(i, obs1, obs2)
        assert np.array_equal(ac1, ac2), f"Actions [{i}] {ac1} and {ac2} do not match."
        assert np.array_equal(rwd1, rwd2), f"Rewards [{i}] {rwd1} and {rwd2} do not match."
        assert np.array_equal(term1, term2), f"Terminateds [{i}] {term1} and {term2} do not match."
        assert np.array_equal(trunc1, trunc2), f"Truncateds [{i}] {trunc1} and {trunc2} do not match."

    # The observation array holds one extra (final) entry that the per-step zip never reaches.
    last = len(traj1[0]) - 1
    final_obs1, final_obs2 = traj1[0][last], traj2[0][last]
    assert np.allclose(final_obs1, final_obs2), obs_mismatch(last, final_obs1, final_obs2)
2022-09-23 09:40:35 +02:00
2022-09-26 09:46:53 +02:00
def verify_observations(obs, observation_space: gym.Space, obs_type="reset()"):
    """
    Assert that ``obs`` is contained in ``observation_space``.

    Args:
        obs: observation returned by the environment.
        observation_space: space whose ``contains`` method performs the check.
        obs_type: label naming where the observation came from (e.g. ``"reset()"`` or
            ``"step()"``); used only in the failure message.

    Raises:
        AssertionError: if ``observation_space.contains(obs)`` is falsy.
    """
    # ``getattr`` keeps the message itself from raising AttributeError for observations
    # without ``.shape`` (e.g. dict or plain scalar observations), which would mask the failure.
    assert observation_space.contains(obs), \
        f"Observation {obs} ({getattr(obs, 'shape', None)}) received from {obs_type} not contained in observation space {observation_space}."
2022-09-23 09:40:35 +02:00
2022-09-26 08:39:54 +02:00
def verify_reward(reward):
    """
    Check that a reward returned by ``step()`` is a Python ``float`` or ``int``.

    Raises:
        AssertionError: if ``reward`` is an instance of neither type.
    """
    reward_type = type(reward)
    is_number = isinstance(reward, (int, float))
    assert is_number, f"Returned type {reward_type} as reward, expected float or int."
2022-09-23 09:40:35 +02:00
2022-09-26 08:39:54 +02:00
def verify_done(done):
    """
    Check that an episode-end flag (terminated / truncated) is a Python ``bool``.

    Raises:
        AssertionError: if ``done`` is not a ``bool`` instance.
    """
    flag_is_bool = isinstance(done, bool)
    assert flag_is_bool, f"Returned {done} as done flag, expected bool."
2023-06-18 17:47:54 +02:00
def ugly_hack_to_mitigate_metaworld_bug(env):
    """
    Set ``curr_path_length = 0`` on ``env`` and on up to 15 environments nested below it.

    The chain is walked by following each layer's ``env`` attribute, covering at most 16 layers
    in total. The walk stops quietly at the first layer without an ``env`` attribute, i.e. the
    innermost environment.

    Args:
        env: the (possibly wrapped) environment to patch in place.
    """
    head = env
    for _ in range(16):
        try:
            head.curr_path_length = 0
            head = head.env
        except AttributeError:
            # Reached the bottom of the wrapper chain. Only this case is swallowed; any other
            # error (and KeyboardInterrupt/SystemExit) is no longer hidden by a bare ``except``.
            break