2022-09-26 16:11:41 +02:00
from itertools import chain
2022-10-24 09:24:12 +02:00
from typing import Tuple , Type , Union , Optional , Callable
2022-09-26 16:11:41 +02:00
2023-05-18 19:07:19 +02:00
import gymnasium as gym
2022-09-26 16:11:41 +02:00
import numpy as np
import pytest
2023-07-23 12:20:49 +02:00
from gymnasium import register , make
2023-05-18 19:07:19 +02:00
from gymnasium . core import ActType , ObsType
2022-09-26 16:11:41 +02:00
import fancy_gym
2022-09-30 15:07:48 +02:00
from fancy_gym . black_box . raw_interface_wrapper import RawInterfaceWrapper
2023-06-18 14:25:20 +02:00
from fancy_gym . utils . wrappers import TimeAwareObservation
2023-06-18 17:47:54 +02:00
from test . utils import ugly_hack_to_mitigate_metaworld_bug
2022-09-26 16:11:41 +02:00
# Seed handed to ``env.reset`` by the tests in this module.
SEED = 1

# Step-based environment ids and their matching MP wrappers. The two lists are
# paired position-by-position through ``zip(ENV_IDS, WRAPPERS)`` below.
ENV_IDS = [
    'fancy/Reacher5d-v0',
    'dm_control/ball_in_cup-catch-v0',
    'metaworld/reach-v2',
    'Reacher-v2',
]
WRAPPERS = [
    fancy_gym.envs.mujoco.reacher.MPWrapper,
    fancy_gym.dmc.suite.ball_in_cup.MPWrapper,
    fancy_gym.meta.goal_object_change_mp_wrapper.MPWrapper,
    fancy_gym.open_ai.mujoco.reacher_v2.MPWrapper,
]

# NOTE: ``chain`` produces a one-shot iterator; it can only be consumed once.
ALL_MP_ENVS = chain(*fancy_gym.ALL_MOVEMENT_PRIMITIVE_ENVIRONMENTS.values())

# Passed to ``make_bb`` as ``fallback_max_steps`` in ``test_length``.
MAX_STEPS_FALLBACK = 100
2023-06-18 14:25:20 +02:00
2022-09-26 16:11:41 +02:00
2022-09-30 15:07:48 +02:00
class Object:
    """Opaque placeholder type, used to pass an arbitrary object as an env kwarg."""
2022-09-26 16:11:41 +02:00
class ToyEnv(gym.Env):
    """Minimal constant environment used as a stand-in for a real task.

    ``reset`` and ``step`` always yield the observation ``[-1]``; ``step``
    always yields reward ``1`` and never signals termination or truncation
    itself. The constructor arguments are stored as attributes so that tests
    can check how keyword arguments are forwarded to the environment.
    """

    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float64)
    dt = 0.02

    def __init__(self, a: int = 0, b: float = 0.0, c: Optional[list] = None,
                 d: Optional[dict] = None, e: Optional[Object] = None):
        """Store the given arguments on the instance.

        ``c``, ``d`` and ``e`` default to ``None`` and are replaced by a fresh
        ``[]``, ``{}`` and ``Object()`` per instance, so separate environments
        never share one mutable default object.
        """
        self.a = a
        self.b = b
        self.c = [] if c is None else c
        self.d = {} if d is None else d
        self.e = Object() if e is None else e

    def reset(self, *, seed: Optional[int] = None, return_info: bool = False,
              options: Optional[dict] = None) -> Tuple[ObsType, dict]:
        """Return the constant observation together with an empty info dict.

        ``seed`` is forwarded to ``gym.Env.reset``. ``return_info`` and
        ``options`` are accepted for signature compatibility and ignored.
        """
        super().reset(seed=seed)
        obs, info = np.array([-1]), {}
        return obs, info

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        """Ignore ``action`` and return ``(obs, reward, terminated, truncated, info)``."""
        obs, reward, terminated, truncated, info = np.array([-1]), 1, False, False, {}
        return obs, reward, terminated, truncated, info

    def render(self, mode="human"):
        """Rendering is a no-op for this environment."""
        pass
class ToyWrapper(RawInterfaceWrapper):
    """Raw-interface wrapper that reports a fixed position and velocity."""

    @property
    def current_pos(self) -> Union[float, int, np.ndarray, Tuple]:
        """All-ones array with the shape of the action space."""
        dof_shape = self.action_space.shape
        return np.ones(dof_shape)

    @property
    def current_vel(self) -> Union[float, int, np.ndarray, Tuple]:
        """All-zeros array with the shape of the action space."""
        dof_shape = self.action_space.shape
        return np.zeros(dof_shape)
2022-09-26 16:11:41 +02:00
@pytest.fixture(scope="session", autouse=True)
def setup():
    """Register ``ToyEnv`` under the id ``toy-v0`` (session-scoped, autouse)."""
    toy_registration = {
        'id': 'toy-v0',
        'entry_point': 'test.test_black_box:ToyEnv',
        'max_episode_steps': 50,
    }
    register(**toy_registration)
@pytest.mark.parametrize('env_id', ENV_IDS)
def test_missing_wrapper(env_id: str):
    """``make_bb`` has to raise ``ValueError`` when the wrapper list is empty."""
    no_wrappers = []
    # Five independent empty kwarg dicts for the remaining positional arguments.
    empty_kwargs = [{} for _ in range(5)]
    with pytest.raises(ValueError):
        fancy_gym.make_bb(env_id, no_wrappers, *empty_kwargs)
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
def test_missing_local_state(mp_type: str):
    """Stepping with only the bare ``RawInterfaceWrapper`` must raise ``NotImplementedError``.

    NOTE(review): presumably because the base wrapper does not provide the
    position/velocity accessors - confirm against ``RawInterfaceWrapper``.
    """
    # 'prodmp' needs its own basis generator; every other type uses 'rbf'.
    basis_generator_type = {'prodmp': 'prodmp'}.get(mp_type, 'rbf')

    black_box_kwargs = {}
    traj_gen_kwargs = {'trajectory_generator_type': mp_type}
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {'phase_generator_type': 'exp'}
    basis_kwargs = {'basis_generator_type': basis_generator_type}

    env = fancy_gym.make_bb('toy-v0', [RawInterfaceWrapper], black_box_kwargs,
                            traj_gen_kwargs, controller_kwargs, phase_kwargs, basis_kwargs)
    env.reset(seed=SEED)

    with pytest.raises(NotImplementedError):
        random_action = env.action_space.sample()
        env.step(random_action)
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
@pytest.mark.parametrize('verbose', [1, 2])
def test_verbosity(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]], verbose: int):
    """Check which info keys a black-box step exposes for a given verbosity.

    The info dict of the black-box environment must contain every key that
    the underlying step-based environment reports, plus ``trajectory_length``.
    With ``verbose >= 2`` it must also contain the trajectory and per-step keys.
    """
    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'

    env_id, wrapper_class = env_wrap
    env = fancy_gym.make_bb(env_id, [wrapper_class], {'verbose': verbose},
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': 'exp'},
                            {'basis_generator_type': basis_generator_type})
    env.reset(seed=SEED)
    _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
    info_keys = list(info.keys())

    # Reference: info keys of the plain step-based environment. The step has to
    # be taken on ``env_step`` itself; stepping ``env`` again would only compare
    # the black-box environment with itself.
    env_step = make(env_id)
    env_step.reset()
    _obs, _reward, _terminated, _truncated, step_info = env_step.step(env_step.action_space.sample())
    info_keys_step = step_info.keys()

    assert all(e in info_keys for e in info_keys_step)
    assert 'trajectory_length' in info_keys

    if verbose >= 2:
        mp_keys = ['positions', 'velocities', 'step_actions', 'step_observations', 'step_rewards']
        assert all(e in info_keys for e in mp_keys)
2022-09-30 15:07:48 +02:00
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
def test_length(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
    """One black-box step must simulate exactly ``env.spec.max_episode_steps`` steps.

    The check is repeated for five consecutive resets of the same environment.
    """
    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'

    env_id, wrapper_class = env_wrap
    env = fancy_gym.make_bb(env_id, [wrapper_class], {},
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': 'exp'},
                            {'basis_generator_type': basis_generator_type}, fallback_max_steps=MAX_STEPS_FALLBACK)

    for i in range(5):
        env.reset(seed=SEED)
        ugly_hack_to_mitigate_metaworld_bug(env)  # TODO: Remove, when metaworld fixed it upstream
        _obs, _reward, _terminated, _truncated, info = env.step(env.action_space.sample())
        length = info['trajectory_length']

        assert length == env.spec.max_episode_steps, f'Expected total simulation length ({length}) to be equal to spec.max_episode_steps ({env.spec.max_episode_steps}), but was not during test nr. {i}'
2022-09-26 16:11:41 +02:00
2022-10-21 16:16:49 +02:00
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('reward_aggregation', [np.sum, np.mean, np.median, lambda x: np.mean(x[::2])])
def test_aggregation(mp_type: str, reward_aggregation: Callable[[np.ndarray], float]):
    """The black-box reward must equal the aggregation applied to the step rewards."""
    if mp_type == 'prodmp':
        basis_generator_type = 'prodmp'
    else:
        basis_generator_type = 'rbf'

    black_box_kwargs = {'reward_aggregation': reward_aggregation}
    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], black_box_kwargs,
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': 'exp'},
                            {'basis_generator_type': basis_generator_type})
    env.reset(seed=SEED)

    # ToyEnv only returns 1 as reward; 50 is the ``max_episode_steps`` that
    # 'toy-v0' is registered with.
    expected_reward = reward_aggregation(np.ones(50, ))
    _obs, reward, _terminated, _truncated, _info = env.step(env.action_space.sample())
    assert reward == expected_reward
2022-09-26 16:11:41 +02:00
2022-10-21 16:16:49 +02:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp'])
@pytest.mark.parametrize('env_wrap', zip(ENV_IDS, WRAPPERS))
def test_context_space(mp_type: str, env_wrap: Tuple[str, Type[RawInterfaceWrapper]]):
    """The black-box observation space must match the ``True`` entries of the context mask."""
    env_id, wrapper_class = env_wrap

    traj_gen_kwargs = {'trajectory_generator_type': mp_type}
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {'phase_generator_type': 'exp'}
    basis_kwargs = {'basis_generator_type': 'rbf'}
    env = fancy_gym.make_bb(env_id, [wrapper_class], {},
                            traj_gen_kwargs, controller_kwargs, phase_kwargs, basis_kwargs)

    # check if observation space matches with the specified mask values which are true
    env_step = make(env_id)
    wrapper = wrapper_class(env_step)
    expected_shape = wrapper.context_mask[wrapper.context_mask].shape
    assert env.observation_space.shape == expected_shape
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('num_dof', [0, 1, 2, 5])
@pytest.mark.parametrize('num_basis', [
    pytest.param(0, marks=pytest.mark.xfail(reason="Basis Length 0 is not yet implemented.")),
    1, 2, 5])
@pytest.mark.parametrize('learn_tau', [True, False])
@pytest.mark.parametrize('learn_delay', [True, False])
def test_action_space(mp_type: str, num_dof: int, num_basis: int, learn_tau: bool, learn_delay: bool):
    """The action dimension must equal the number of trajectory parameters.

    Expected size: ``num_dof * num_basis`` weights, one extra entry per DoF
    when ``'dmp'`` is part of ``mp_type``, plus one entry each for a learned
    tau and a learned delay.
    """
    basis_generator_type = {'prodmp': 'prodmp'}.get(mp_type, 'rbf')

    traj_gen_kwargs = {
        'trajectory_generator_type': mp_type,
        'action_dim': num_dof,
    }
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {
        'phase_generator_type': 'exp',
        'learn_tau': learn_tau,
        'learn_delay': learn_delay,
    }
    basis_kwargs = {
        'basis_generator_type': basis_generator_type,
        'num_basis': num_basis,
    }
    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
                            traj_gen_kwargs, controller_kwargs, phase_kwargs, basis_kwargs)

    base_dims = num_dof * num_basis
    if 'dmp' in mp_type:
        additional_dims = num_dof
    else:
        additional_dims = 0
    traj_modification_dims = int(learn_tau) + int(learn_delay)

    expected_dims = base_dims + traj_modification_dims + additional_dims
    assert env.action_space.shape[0] == expected_dims
2022-09-30 15:07:48 +02:00
2022-11-09 17:54:34 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'dmp', 'prodmp'])
@pytest.mark.parametrize('a', [1])
@pytest.mark.parametrize('b', [1.0])
@pytest.mark.parametrize('c', [[1], [1.0], ['str'], [{'a': 'b'}], [np.ones(3, )]])
@pytest.mark.parametrize('d', [{'a': 1}, {1: 2.0}, {'a': [1.0]}, {'a': np.ones(3, )}, {'a': {'a': 'b'}}])
@pytest.mark.parametrize('e', [Object()])
def test_change_env_kwargs(mp_type: str, a: int, b: float, c: list, d: dict, e: Object):
    """Extra keyword arguments given to ``make_bb`` must reach the environment."""
    if mp_type == 'prodmp':
        basis_generator_type = 'prodmp'
    else:
        basis_generator_type = 'rbf'

    env_kwargs = {'a': a, 'b': b, 'c': c, 'd': d, 'e': e}
    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {},
                            {'trajectory_generator_type': mp_type},
                            {'controller_type': 'motor'},
                            {'phase_generator_type': 'exp'},
                            {'basis_generator_type': basis_generator_type},
                            **env_kwargs
                            )

    # Non-dict arguments are expected to arrive as the very same objects.
    assert a is env.a
    assert b is env.b
    assert c is env.c
    # Due to how gym works dict kwargs need to be copied and hence can only be checked to have the same content
    assert d == env.d
    assert e is env.e
2022-11-13 16:59:13 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('tau', [0.25, 0.5, 0.75, 1])
def test_learn_tau(mp_type: str, tau: float):
    """With a learned tau, the trajectory must be active only for the first ``tau`` seconds.

    The first action entry is set to ``tau``. Five black-box steps are run and
    the positions/velocities reported in ``info`` are inspected after each.
    """
    is_prodmp = mp_type == 'prodmp'
    phase_generator_type = 'exp' if is_prodmp else 'linear'
    basis_generator_type = 'prodmp' if is_prodmp else 'rbf'

    traj_gen_kwargs = {'trajectory_generator_type': mp_type}
    controller_kwargs = {'controller_type': 'motor'}
    phase_kwargs = {
        'phase_generator_type': phase_generator_type,
        'learn_tau': True,
        'learn_delay': False,
    }
    basis_kwargs = {'basis_generator_type': basis_generator_type}
    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'verbose': 2},
                            traj_gen_kwargs, controller_kwargs, phase_kwargs, basis_kwargs)

    env.reset(seed=SEED)
    needs_reset = True
    for _ in range(5):
        if needs_reset:
            env.reset(seed=SEED)
        action = env.action_space.sample()
        action[0] = tau

        _obs, _reward, terminated, truncated, info = env.step(action)
        needs_reset = terminated or truncated

        assert info['trajectory_length'] == env.spec.max_episode_steps

        # Number of simulation steps covered by the requested tau.
        tau_time_steps = int(np.round(tau / env.dt))

        pos = info['positions'].flatten()
        vel = info['velocities'].flatten()
        final_pos, final_vel = pos[-1], vel[-1]

        # Check end is all same (only true for linear basis)
        if phase_generator_type == "linear":
            assert np.all(pos[tau_time_steps:] == final_pos)
            assert np.all(vel[tau_time_steps:] == final_vel)

        # Check active trajectory section is different to end values
        assert np.all(pos[:tau_time_steps - 1] != final_pos)
        assert np.all(vel[:tau_time_steps - 2] != final_vel)
2022-11-13 16:59:13 +01:00
#
#
2023-05-18 17:31:40 +02:00
2022-11-13 16:59:13 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('delay', [0, 0.25, 0.5, 0.75])
def test_learn_delay(mp_type: str, delay: float):
    """With a learned delay, the trajectory must stay constant for the first ``delay`` seconds.

    The first action entry is set to ``delay``. Five black-box steps are run
    and the positions/velocities reported in ``info`` are inspected after each.
    """
    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'
    phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'

    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'verbose': 2},
                            {'trajectory_generator_type': mp_type,
                             },
                            {'controller_type': 'motor'},
                            {'phase_generator_type': phase_generator_type,
                             'learn_tau': False,
                             'learn_delay': True
                             },
                            {'basis_generator_type': basis_generator_type,
                             })

    env.reset(seed=SEED)
    done = True
    for i in range(5):
        if done:
            env.reset(seed=SEED)
        action = env.action_space.sample()
        action[0] = delay

        _obs, _reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        length = info['trajectory_length']
        assert length == env.spec.max_episode_steps

        # Number of simulation steps covered by the requested delay.
        delay_time_steps = int(np.round(delay / env.dt))

        pos = info['positions'].flatten()
        vel = info['velocities'].flatten()

        # Check beginning is all same (only true for linear basis)
        assert np.all(pos[:max(1, delay_time_steps - 1)] == pos[0])
        assert np.all(vel[:max(1, delay_time_steps - 2)] == vel[0])

        # Check active trajectory section is different to beginning values.
        # Both checks slice to the end of the trajectory; the velocity check
        # used to index a single sample only.
        assert np.all(pos[max(1, delay_time_steps):] != pos[0])
        assert np.all(vel[max(1, delay_time_steps):] != vel[0])
2022-11-13 16:59:13 +01:00
#
#
2023-05-18 17:31:40 +02:00
2022-11-13 16:59:13 +01:00
@pytest.mark.parametrize('mp_type', ['promp', 'prodmp'])
@pytest.mark.parametrize('tau', [0.25, 0.5, 0.75, 1])
@pytest.mark.parametrize('delay', [0.25, 0.5, 0.75, 1])
def test_learn_tau_and_delay(mp_type: str, tau: float, delay: float):
    """With learned tau and delay, the trajectory must be active only in ``[delay, delay + tau]``.

    The first two action entries are set to ``tau`` and ``delay``. Parameter
    combinations whose ``delay + tau`` exceeds the episode duration are
    skipped, so they are reported as skipped instead of silently passing.
    """
    phase_generator_type = 'exp' if mp_type == 'prodmp' else 'linear'
    basis_generator_type = 'prodmp' if mp_type == 'prodmp' else 'rbf'

    env = fancy_gym.make_bb('toy-v0', [ToyWrapper], {'verbose': 2},
                            {'trajectory_generator_type': mp_type,
                             },
                            {'controller_type': 'motor'},
                            {'phase_generator_type': phase_generator_type,
                             'learn_tau': True,
                             'learn_delay': True
                             },
                            {'basis_generator_type': basis_generator_type,
                             })
    env.reset(seed=SEED)

    if env.spec.max_episode_steps * env.dt < delay + tau:
        pytest.skip('delay + tau does not fit into one episode')

    done = True
    for i in range(5):
        if done:
            env.reset(seed=SEED)
            ugly_hack_to_mitigate_metaworld_bug(env)
        action = env.action_space.sample()
        action[0] = tau
        action[1] = delay

        _obs, _reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        length = info['trajectory_length']
        assert length == env.spec.max_episode_steps

        # Step counts covered by tau, by the delay, and by both together.
        tau_time_steps = int(np.round(tau / env.dt))
        delay_time_steps = int(np.round(delay / env.dt))
        joint_time_steps = delay_time_steps + tau_time_steps

        pos = info['positions'].flatten()
        vel = info['velocities'].flatten()

        # Check end is all same (only true for linear basis)
        if phase_generator_type == "linear":
            assert np.all(pos[joint_time_steps:] == pos[-1])
            assert np.all(vel[joint_time_steps:] == vel[-1])

        # Check beginning is all same (only true for linear basis)
        assert np.all(pos[:delay_time_steps - 1] == pos[0])
        assert np.all(vel[:delay_time_steps - 2] == vel[0])

        # Check active trajectory section is different to beginning and end values
        active_pos = pos[delay_time_steps: joint_time_steps - 1]
        active_vel = vel[delay_time_steps: joint_time_steps - 2]
        assert np.all(active_pos != pos[-1]) and np.all(active_pos != pos[0])
        assert np.all(active_vel != vel[-1]) and np.all(active_vel != vel[0])