diff --git a/test.py b/test.py index 8f0e7fc..2edce78 100755 --- a/test.py +++ b/test.py @@ -20,11 +20,11 @@ root_path = '.' def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0): env = gym.make(env_name) - use_sde = False + use_sde = True ppo = PPO( MlpPolicyPPO, env, - projection=FrobeniusProjectionLayer(), + projection=BaseProjectionLayer(), policy_kwargs={'dist_kwargs': {'neural_strength': Strength.FULL, 'cov_strength': Strength.FULL, 'parameterization_type': ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}}, verbose=0, @@ -149,5 +149,5 @@ if __name__ == '__main__': # main('ColumbusJustState-v0') # main('ColumbusStateWithBarriers-v0') # full('ColumbusEasierObstacles-v0') - full('ColumbusStateWithBarriers-v0') + main('ColumbusSingle-v0') # full('LunarLanderContinuous-v2')