diff --git a/test.py b/test.py index 5042e33..8f0e7fc 100755 --- a/test.py +++ b/test.py @@ -1,8 +1,5 @@ #!/usr/bin/python3 import gym -from gym.envs.registration import register -import numpy as np -import os import time import datetime @@ -67,15 +64,15 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru # saveModel, n_eval_episodes) -def full(env_name='ColumbusCandyland_Aux10-v0', timesteps=35_000, saveModel=True, n_eval_episodes=4): +def full(env_name='ColumbusCandyland_Aux10-v0', timesteps=200_000, saveModel=True, n_eval_episodes=4): env = gym.make(env_name) use_sde = False - skip_num = 8 # 10 (/ start at index) - sac = True + skip_num = 4 # 10 (/ start at index) + sac = False Model = [PPO, SAC][sac] Policy = [MlpPolicyPPO, MlpPolicySAC][sac] - #projection = FrobeniusProjectionLayer() - projection = BaseProjectionLayer() + projection = FrobeniusProjectionLayer() + #projection = BaseProjectionLayer() gen = enumerate(get_legal_setups( allowedEPTs=[EnforcePositiveType.SOFTPLUS, EnforcePositiveType.ABS])) @@ -87,21 +84,21 @@ def full(env_name='ColumbusCandyland_Aux10-v0', timesteps=35_000, saveModel=True model = Model( Policy, env, - # projection=projection, + projection=projection, policy_kwargs={'dist_kwargs': {'neural_strength': ps, 'cov_strength': cs, 'parameterization_type': - pt, 'enforce_positive_type': ept, 'prob_squashing_type': ProbSquashingType.NONE}}, + pt, 'enforce_positive_type': ept, 'prob_squashing_type': ProbSquashingType.TANH}}, verbose=0, tensorboard_log=root_path+"/logs_tb/" + - env_name+"/"+['ppo', 'sac'][sac]+"_" + + env_name+"/"+['ppo', 'sac'][sac]+"_" + 'TANH_' + ("_".join([str(s) for s in setup])+['', '_sde'][use_sde])+"/", - # learning_rate=3e-4, - # gamma=0.99, - # gae_lambda=0.95, - # normalize_advantage=True, - # ent_coef=0.02, # 0.1 - # vf_coef=0.5, + learning_rate=3e-4, + gamma=0.99, + gae_lambda=0.95, + normalize_advantage=True, + ent_coef=0.02, # 0.1 + vf_coef=0.5, use_sde=use_sde, # False - # clip_range=1 # 0.2, + clip_range=1 # 0.2, ) testModel(model, timesteps, False, @@ -152,5 +149,5 @@ if __name__ == '__main__': # main('ColumbusJustState-v0') # main('ColumbusStateWithBarriers-v0') # full('ColumbusEasierObstacles-v0') - # full('ColumbusStateWithBarriers-v0') - full('LunarLanderContinuous-v2') + full('ColumbusStateWithBarriers-v0') + # full('LunarLanderContinuous-v2')