Removed old TODOs

Dominik Moritz Roth 2022-08-28 12:07:19 +02:00
parent eb881559d6
commit 4080ad8135
3 changed files with 7 additions and 9 deletions


@@ -200,7 +200,6 @@ class UniversalGaussianDistribution(SB3_Distribution):
         assert std_init >= 0.0, "std can not be initialized to a negative value."
-        # TODO: Implement SDE
         self.latent_sde_dim = latent_sde_dim
         mean_actions = nn.Linear(latent_dim, self.action_dim)
@@ -348,7 +347,6 @@ class UniversalGaussianDistribution(SB3_Distribution):
     def get_noise(self, latent_sde: th.Tensor) -> th.Tensor:
         latent_sde = latent_sde if self.learn_features else latent_sde.detach()
-        # # TODO: Good idea?
         latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
         # Default case: only one exploration matrix
         if len(latent_sde) == 1 or len(latent_sde) != len(self.exploration_matrices):
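
For orientation: the get_noise hunk above follows the usual gSDE pattern from Stable-Baselines3, where the (optionally detached) latent features are normalized and then multiplied with a learned exploration matrix. A minimal sketch of that default single-matrix case, assuming SB3-style shapes and an exploration_mat tensor (illustrative names, not necessarily this repository's exact implementation):

import torch as th

def sample_sde_noise(latent_sde: th.Tensor, exploration_mat: th.Tensor) -> th.Tensor:
    # Normalizing the latent features bounds the noise scale regardless of
    # feature magnitude (the idea behind the removed "# TODO: Good idea?").
    latent_sde = th.nn.functional.normalize(latent_sde, dim=-1)
    # Default case: one shared exploration matrix for the whole batch.
    return th.mm(latent_sde, exploration_mat)

# Assumed shapes: batch of 4 states, 8 latent features, 2 action dimensions.
noise = sample_sde_noise(th.randn(4, 8), th.randn(8, 2))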
@@ -579,7 +577,6 @@ class CholNet(nn.Module):
                              dim2=-1)).diag_embed() + chol.triu(1)
     def string(self):
-        # TODO
         return '<CholNet />'
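
The CholNet line above (a positivity transform on the diagonal via diag_embed, plus the remaining triangle of the raw output) is the standard way to map an unconstrained network output to a valid Cholesky factor; it corresponds to the ParametrizationType.CHOL / EnforcePositiveType.ABS options selected in test.py below. A minimal sketch of the idea, assuming a lower-triangular factor and the ABS transform (the actual CholNet may differ in detail):

import torch as th

def build_chol(raw: th.Tensor) -> th.Tensor:
    # Positive diagonal (abs, as in EnforcePositiveType.ABS) plus free
    # strictly-off-diagonal entries gives a valid Cholesky factor.
    diag = raw.diagonal(dim1=-2, dim2=-1).abs()
    return diag.diag_embed() + raw.tril(-1)

chol = build_chol(th.randn(3, 3))
dist = th.distributions.MultivariateNormal(th.zeros(3), scale_tril=chol)
action = dist.rsample()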


@@ -79,8 +79,6 @@ class ActorCriticPolicy(BasePolicy):
         excluding the learning rate, to pass to the optimizer
     """
-    # TODO: Allow passing of dist_kwargs into dist
     def __init__(
         self,
         observation_space: gym.spaces.Space,

test.py

@@ -15,18 +15,21 @@ import columbus
 from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
+import torch as th
 root_path = '.'
 def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
     env = gym.make(env_name)
     use_sde = False
+    # th.autograd.set_detect_anomaly(True)
     ppo = PPO(
         MlpPolicyPPO,
         env,
         projection=BaseProjectionLayer(), # KLProjectionLayer(trust_region_coeff=0.01),
-        policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
-                       ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
+        policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.FULL, 'parameterization_type':
+                       ParametrizationType.CHOL, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/" +
         env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
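
The hunk above also adds a commented-out th.autograd.set_detect_anomaly(True). For reference, this is PyTorch's built-in switch for tracing NaNs produced in backward() back to the forward operation that caused them, at a noticeable speed cost. A small self-contained illustration (not taken from the repository):

import torch as th

th.autograd.set_detect_anomaly(True)  # debug runs only; slows training down

x = th.zeros(1, requires_grad=True)
z = th.zeros(1) * th.sqrt(x)  # forward is fine (0), but d(sqrt)/dx at 0 is inf
try:
    z.backward()              # 0 * inf = NaN gradient -> anomaly mode raises
except RuntimeError as err:
    print('anomaly detected:', err)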
@@ -37,7 +40,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
         ent_coef=0.1, # 0.1
         vf_coef=0.5,
         use_sde=use_sde, # False
-        clip_range=0.2 # 1 # 0.2,
+        clip_range=None # 1 # 0.2,
     )
     # trl_frob = PPO(
     #     MlpPolicy,
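
Net effect of the test.py changes: the policy now requests a full covariance matrix parameterized by its Cholesky factor instead of a diagonal one. Spelled out as plain keyword arguments (values copied from the diff; only the layout is rearranged):

from metastable_baselines.distributions import (
    Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType)

dist_kwargs = {
    'neural_strength': Strength.NONE,
    'cov_strength': Strength.FULL,                      # was Strength.DIAG
    'parameterization_type': ParametrizationType.CHOL,  # was ParametrizationType.NONE
    'enforce_positive_type': EnforcePositiveType.ABS,
    'prob_squashing_type': ProbSquashingType.NONE,
}
policy_kwargs = {'dist_kwargs': dist_kwargs}  # passed to PPO(..., policy_kwargs=policy_kwargs)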