Making MultivariateNormal Policies work (and porting Normal to
Independent)
This commit is contained in:
parent
b1ed9fc2b8
commit
ab557a8856
@ -4,7 +4,7 @@ from enum import Enum
|
|||||||
import gym
|
import gym
|
||||||
import torch as th
|
import torch as th
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.distributions import Normal, MultivariateNormal
|
from torch.distributions import Normal, Independent, MultivariateNormal
|
||||||
from math import pi
|
from math import pi
|
||||||
|
|
||||||
from stable_baselines3.common.preprocessing import get_action_dim
|
from stable_baselines3.common.preprocessing import get_action_dim
|
||||||
@ -37,6 +37,7 @@ class Strength(Enum):
|
|||||||
|
|
||||||
|
|
||||||
class ParametrizationType(Enum):
|
class ParametrizationType(Enum):
|
||||||
|
NONE = 0
|
||||||
CHOL = 1
|
CHOL = 1
|
||||||
SPHERICAL_CHOL = 2
|
SPHERICAL_CHOL = 2
|
||||||
# Not (yet?) implemented:
|
# Not (yet?) implemented:
|
||||||
@ -46,6 +47,7 @@ class ParametrizationType(Enum):
|
|||||||
|
|
||||||
class EnforcePositiveType(Enum):
|
class EnforcePositiveType(Enum):
|
||||||
# TODO: Allow custom params for softplus?
|
# TODO: Allow custom params for softplus?
|
||||||
|
NONE = (0, nn.Identity())
|
||||||
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
||||||
ABS = (2, th.abs)
|
ABS = (2, th.abs)
|
||||||
RELU = (3, nn.ReLU(inplace=False))
|
RELU = (3, nn.ReLU(inplace=False))
|
||||||
@ -89,14 +91,14 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
|||||||
# TODO: Implement
|
# TODO: Implement
|
||||||
continue
|
continue
|
||||||
if ps == Strength.NONE:
|
if ps == Strength.NONE:
|
||||||
yield (ps, cs, None, None)
|
yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
|
||||||
else:
|
else:
|
||||||
for ept in allowedEPTs:
|
for ept in allowedEPTs:
|
||||||
if cs == Strength.FULL:
|
if cs == Strength.FULL:
|
||||||
for pt in allowedPTs:
|
for pt in allowedPTs:
|
||||||
yield (ps, cs, ept, pt)
|
yield (ps, cs, ept, pt)
|
||||||
else:
|
else:
|
||||||
yield (ps, cs, ept, None)
|
yield (ps, cs, ept, ProbSquashingType.NONE)
|
||||||
|
|
||||||
|
|
||||||
def make_proba_distribution(
|
def make_proba_distribution(
|
||||||
@ -138,7 +140,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:param action_dim: Dimension of the action space.
|
:param action_dim: Dimension of the action space.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
|
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
|
||||||
super(UniversalGaussianDistribution, self).__init__()
|
super(UniversalGaussianDistribution, self).__init__()
|
||||||
self.action_dim = action_dim
|
self.action_dim = action_dim
|
||||||
self.par_strength = neural_strength
|
self.par_strength = neural_strength
|
||||||
@ -155,16 +157,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
if use_sde:
|
if use_sde:
|
||||||
raise Exception('SDE is not yet implemented')
|
raise Exception('SDE is not yet implemented')
|
||||||
|
|
||||||
|
assert (parameterization_type != ParametrizationType.NONE) == (
|
||||||
|
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||||
|
|
||||||
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
|
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
|
||||||
p = self.distribution
|
p = self.distribution
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, Independent):
|
||||||
if p.stddev.shape != chol.shape:
|
if p.stddev.shape != chol.shape:
|
||||||
chol = th.diagonal(chol, dim1=1, dim2=2)
|
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||||
np = th.distributions.Normal(mean, chol)
|
np = Independent(Normal(mean, chol), 1)
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, MultivariateNormal):
|
||||||
np = th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
np = MultivariateNormal(mean, scale_tril=chol)
|
||||||
new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
|
new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
|
||||||
parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
|
parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
|
||||||
new.distribution = np
|
new.distribution = np
|
||||||
|
|
||||||
return new
|
return new
|
||||||
@ -202,9 +207,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
# TODO: latent_pi is for SDE, implement.
|
# TODO: latent_pi is for SDE, implement.
|
||||||
|
|
||||||
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
|
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
|
||||||
self.distribution = Normal(mean_actions, chol)
|
self.distribution = Independent(Normal(mean_actions, chol), 1)
|
||||||
elif self.cov_strength in [Strength.FULL]:
|
elif self.cov_strength in [Strength.FULL]:
|
||||||
self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
|
self.distribution = MultivariateNormal(
|
||||||
|
mean_actions, scale_tril=chol)
|
||||||
if self.distribution == None:
|
if self.distribution == None:
|
||||||
raise Exception('Unable to create torch distribution')
|
raise Exception('Unable to create torch distribution')
|
||||||
return self
|
return self
|
||||||
@ -218,10 +224,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
log_prob = self.distribution.log_prob(actions)
|
log_prob = self.distribution.log_prob(actions)
|
||||||
return sum_independent_dims(log_prob)
|
return log_prob
|
||||||
|
|
||||||
def entropy(self) -> th.Tensor:
|
def entropy(self) -> th.Tensor:
|
||||||
return sum_independent_dims(self.distribution.entropy())
|
return self.distribution.entropy()
|
||||||
|
|
||||||
def sample(self) -> th.Tensor:
|
def sample(self) -> th.Tensor:
|
||||||
# Reparametrization trick to pass gradients
|
# Reparametrization trick to pass gradients
|
||||||
|
@ -6,7 +6,7 @@ from ..distributions import UniversalGaussianDistribution, AnyDistribution
|
|||||||
|
|
||||||
|
|
||||||
def get_mean_and_chol(p: AnyDistribution, expand=False):
|
def get_mean_and_chol(p: AnyDistribution, expand=False):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||||
if expand:
|
if expand:
|
||||||
return p.mean, th.diag_embed(p.stddev)
|
return p.mean, th.diag_embed(p.stddev)
|
||||||
else:
|
else:
|
||||||
@ -32,7 +32,7 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
|||||||
|
|
||||||
|
|
||||||
def get_cov(p: AnyDistribution):
|
def get_cov(p: AnyDistribution):
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||||
return th.diag_embed(p.variance)
|
return th.diag_embed(p.variance)
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
return p.covariance_matrix
|
return p.covariance_matrix
|
||||||
@ -45,7 +45,7 @@ def get_cov(p: AnyDistribution):
|
|||||||
def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
||||||
if isinstance(p, SB3_Distribution):
|
if isinstance(p, SB3_Distribution):
|
||||||
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||||
return True
|
return True
|
||||||
if not numerical_check:
|
if not numerical_check:
|
||||||
return False
|
return False
|
||||||
@ -67,11 +67,15 @@ def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
|
|||||||
|
|
||||||
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
||||||
if isinstance(orig_p, UniversalGaussianDistribution):
|
if isinstance(orig_p, UniversalGaussianDistribution):
|
||||||
return orig_p.new_list_like_me(mean, chol)
|
return orig_p.new_dist_like_me(mean, chol)
|
||||||
elif isinstance(orig_p, th.distributions.Normal):
|
elif isinstance(orig_p, th.distributions.Normal):
|
||||||
if orig_p.stddev.shape != chol.shape:
|
if orig_p.stddev.shape != chol.shape:
|
||||||
chol = th.diagonal(chol, dim1=1, dim2=2)
|
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||||
return th.distributions.Normal(mean, chol)
|
return th.distributions.Normal(mean, chol)
|
||||||
|
elif isinstance(orig_p, th.distributions.Independent):
|
||||||
|
if orig_p.stddev.shape != chol.shape:
|
||||||
|
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||||
|
return th.distributions.Independent(th.distributions.Normal(mean, chol), 1)
|
||||||
elif isinstance(orig_p, th.distributions.MultivariateNormal):
|
elif isinstance(orig_p, th.distributions.MultivariateNormal):
|
||||||
return th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
return th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
||||||
elif isinstance(orig_p, SB3_Distribution):
|
elif isinstance(orig_p, SB3_Distribution):
|
||||||
@ -79,6 +83,10 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
|||||||
if isinstance(p, th.distributions.Normal):
|
if isinstance(p, th.distributions.Normal):
|
||||||
p_out = orig_p.__class__(orig_p.action_dim)
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
p_out.distribution = th.distributions.Normal(mean, chol)
|
p_out.distribution = th.distributions.Normal(mean, chol)
|
||||||
|
elif isinstance(p, th.distributions.Independent):
|
||||||
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
|
p_out.distribution = th.distributions.Independent(
|
||||||
|
th.distributions.Normal(mean, chol), 1)
|
||||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||||
p_out = orig_p.__class__(orig_p.action_dim)
|
p_out = orig_p.__class__(orig_p.action_dim)
|
||||||
p_out.distribution = th.distributions.MultivariateNormal(
|
p_out.distribution = th.distributions.MultivariateNormal(
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from typing import Any, Dict, Optional, Type, Union, NamedTuple
|
from typing import Any, Dict, Optional, Type, Union, NamedTuple
|
||||||
|
from more_itertools import distribute
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
@ -10,6 +11,8 @@ from stable_baselines3.common.vec_env import VecEnv
|
|||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.utils import obs_as_tensor
|
from stable_baselines3.common.utils import obs_as_tensor
|
||||||
|
|
||||||
|
from ..misc.distTools import get_mean_and_chol
|
||||||
|
from ..distributions.distributions import Strength, UniversalGaussianDistribution
|
||||||
|
|
||||||
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
||||||
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
||||||
@ -120,6 +123,12 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
def _setup_model(self) -> None:
|
def _setup_model(self) -> None:
|
||||||
super()._setup_model()
|
super()._setup_model()
|
||||||
|
|
||||||
|
cov_shape = self.action_space.shape
|
||||||
|
|
||||||
|
if isinstance(self.policy.action_dist, UniversalGaussianDistribution):
|
||||||
|
if self.policy.action_dist.cov_strength == Strength.FULL:
|
||||||
|
cov_shape = cov_shape + cov_shape
|
||||||
|
|
||||||
self.rollout_buffer = GaussianRolloutBuffer(
|
self.rollout_buffer = GaussianRolloutBuffer(
|
||||||
self.n_steps,
|
self.n_steps,
|
||||||
self.observation_space,
|
self.observation_space,
|
||||||
@ -128,6 +137,7 @@ class GaussianRolloutCollectorAuxclass():
|
|||||||
gamma=self.gamma,
|
gamma=self.gamma,
|
||||||
gae_lambda=self.gae_lambda,
|
gae_lambda=self.gae_lambda,
|
||||||
n_envs=self.n_envs,
|
n_envs=self.n_envs,
|
||||||
|
cov_shape=cov_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
def collect_rollouts(
|
def collect_rollouts(
|
||||||
|
@ -1,2 +1,2 @@
|
|||||||
from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
from ..ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||||
from ..trl_pg.trl_pg import TRL_PG
|
from .ppo import PPO
|
||||||
|
@ -95,6 +95,7 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
normalize_images: bool = True,
|
normalize_images: bool = True,
|
||||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
if optimizer_kwargs is None:
|
if optimizer_kwargs is None:
|
||||||
@ -130,15 +131,16 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
|
|
||||||
self.normalize_images = normalize_images
|
self.normalize_images = normalize_images
|
||||||
self.log_std_init = log_std_init
|
self.log_std_init = log_std_init
|
||||||
dist_kwargs = None
|
|
||||||
# Keyword arguments for gSDE distribution
|
# Keyword arguments for gSDE distribution
|
||||||
if use_sde:
|
if use_sde:
|
||||||
dist_kwargs = {
|
add_dist_kwargs = {
|
||||||
"full_std": full_std,
|
"full_std": full_std,
|
||||||
"squash_output": squash_output,
|
"squash_output": squash_output,
|
||||||
"use_expln": use_expln,
|
"use_expln": use_expln,
|
||||||
"learn_features": False,
|
"learn_features": False,
|
||||||
}
|
}
|
||||||
|
for k in add_dist_kwargs:
|
||||||
|
dist_kwargs[k] = add_dist_kwargs[k]
|
||||||
|
|
||||||
if sde_net_arch is not None:
|
if sde_net_arch is not None:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
@ -26,7 +26,7 @@ from ..projections.kl_projection_layer import KLProjectionLayer
|
|||||||
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
||||||
|
|
||||||
|
|
||||||
class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||||
"""
|
"""
|
||||||
Differential Trust Region Layer (TRL) for Policy Gradient (PG)
|
Differential Trust Region Layer (TRL) for Policy Gradient (PG)
|
||||||
|
|
||||||
@ -248,7 +248,12 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
q_dist = new_dist_like(
|
q_dist = new_dist_like(
|
||||||
p_dist, rollout_data.means, rollout_data.stds)
|
p_dist, rollout_data.means, rollout_data.stds)
|
||||||
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
||||||
|
if isinstance(p_dist, th.distributions.Normal):
|
||||||
|
# Normal uses a weird mapping from dimensions into batch_shape
|
||||||
log_prob = proj_p.log_prob(actions).sum(dim=1)
|
log_prob = proj_p.log_prob(actions).sum(dim=1)
|
||||||
|
else:
|
||||||
|
# UniversalGaussianDistribution instead uses Independent (or MultivariateNormal), which has a more rational dim mapping
|
||||||
|
log_prob = proj_p.log_prob(actions)
|
||||||
values = self.policy.value_net(latent_vf)
|
values = self.policy.value_net(latent_vf)
|
||||||
entropy = proj_p.entropy()
|
entropy = proj_p.entropy()
|
||||||
|
|
||||||
@ -373,10 +378,10 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
eval_env: Optional[GymEnv] = None,
|
eval_env: Optional[GymEnv] = None,
|
||||||
eval_freq: int = -1,
|
eval_freq: int = -1,
|
||||||
n_eval_episodes: int = 5,
|
n_eval_episodes: int = 5,
|
||||||
tb_log_name: str = "TRL_PG",
|
tb_log_name: str = "PPO",
|
||||||
eval_log_path: Optional[str] = None,
|
eval_log_path: Optional[str] = None,
|
||||||
reset_num_timesteps: bool = True,
|
reset_num_timesteps: bool = True,
|
||||||
) -> "TRL_PG":
|
) -> "PPO":
|
||||||
|
|
||||||
return super().learn(
|
return super().learn(
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
|
11
replay.py
11
replay.py
@ -6,11 +6,10 @@ import os
|
|||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from stable_baselines3 import SAC, PPO, A2C
|
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||||
|
|
||||||
from metastable_baselines.trl_pg import TRL_PG
|
from metastable_baselines.ppo import PPO
|
||||||
import columbus
|
import columbus
|
||||||
|
|
||||||
|
|
||||||
@ -26,14 +25,8 @@ def main(load_path, n_eval_episodes=0):
|
|||||||
use_sde = file_name.find('sde') != -1
|
use_sde = file_name.find('sde') != -1
|
||||||
print(env_name, alg_name, alg_deriv, use_sde)
|
print(env_name, alg_name, alg_deriv, use_sde)
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
if alg_name == 'ppo':
|
|
||||||
Model = PPO
|
|
||||||
elif alg_name == 'trl' and alg_deriv == 'pg':
|
|
||||||
Model = TRL_PG
|
|
||||||
else:
|
|
||||||
raise Exception('Algorithm not implemented for replay')
|
|
||||||
|
|
||||||
model = Model.load(load_path, env=env)
|
model = PPO.load(load_path, env=env)
|
||||||
|
|
||||||
if n_eval_episodes:
|
if n_eval_episodes:
|
||||||
mean_reward, std_reward = evaluate_policy(
|
mean_reward, std_reward = evaluate_policy(
|
||||||
|
47
test.py
47
test.py
@ -6,25 +6,28 @@ import os
|
|||||||
import time
|
import time
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from stable_baselines3 import
|
|
||||||
from stable_baselines3.common.evaluation import evaluate_policy
|
from stable_baselines3.common.evaluation import evaluate_policy
|
||||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||||
|
|
||||||
from metastable_baselines.ppo import PPO
|
from metastable_baselines.ppo import PPO
|
||||||
|
# from metastable_baselines.sac import SAC
|
||||||
from metastable_baselines.ppo.policies import MlpPolicy
|
from metastable_baselines.ppo.policies import MlpPolicy
|
||||||
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||||
import columbus
|
import columbus
|
||||||
|
|
||||||
#root_path = os.getcwd()
|
from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
|
||||||
|
|
||||||
root_path = '.'
|
root_path = '.'
|
||||||
|
|
||||||
|
|
||||||
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
||||||
env = gym.make(env_name)
|
env = gym.make(env_name)
|
||||||
use_sde = False
|
use_sde = False
|
||||||
ppo = PPO(
|
ppo = PPO(
|
||||||
MlpPolicy,
|
MlpPolicy,
|
||||||
env,
|
env,
|
||||||
|
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
||||||
|
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||||
verbose=0,
|
verbose=0,
|
||||||
tensorboard_log=root_path+"/logs_tb/" +
|
tensorboard_log=root_path+"/logs_tb/" +
|
||||||
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
|
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
|
||||||
@ -37,29 +40,29 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
|
|||||||
use_sde=use_sde, # False
|
use_sde=use_sde, # False
|
||||||
clip_range=0.2,
|
clip_range=0.2,
|
||||||
)
|
)
|
||||||
trl_frob = PPO(
|
# trl_frob = PPO(
|
||||||
MlpPolicy,
|
# MlpPolicy,
|
||||||
env,
|
# env,
|
||||||
projection=FrobeniusProjectionLayer(),
|
# projection=FrobeniusProjectionLayer(),
|
||||||
verbose=0,
|
# verbose=0,
|
||||||
tensorboard_log=root_path+"/logs_tb/"+env_name +
|
# tensorboard_log=root_path+"/logs_tb/"+env_name +
|
||||||
"/trl_frob"+(['', '_sde'][use_sde])+"/",
|
# "/trl_frob"+(['', '_sde'][use_sde])+"/",
|
||||||
learning_rate=3e-4,
|
# learning_rate=3e-4,
|
||||||
gamma=0.99,
|
# gamma=0.99,
|
||||||
gae_lambda=0.95,
|
# gae_lambda=0.95,
|
||||||
normalize_advantage=True,
|
# normalize_advantage=True,
|
||||||
ent_coef=0.03, # 0.1
|
# ent_coef=0.03, # 0.1
|
||||||
vf_coef=0.5,
|
# vf_coef=0.5,
|
||||||
use_sde=use_sde,
|
# use_sde=use_sde,
|
||||||
clip_range=2, # 0.2
|
# clip_range=2, # 0.2
|
||||||
)
|
# )
|
||||||
|
|
||||||
print('PPO:')
|
print('PPO:')
|
||||||
testModel(ppo, timesteps, showRes,
|
testModel(ppo, timesteps, showRes,
|
||||||
saveModel, n_eval_episodes)
|
saveModel, n_eval_episodes)
|
||||||
print('TRL_frob:')
|
# print('TRL_frob:')
|
||||||
testModel(trl_frob, timesteps, showRes,
|
# testModel(trl_frob, timesteps, showRes,
|
||||||
saveModel, n_eval_episodes)
|
# saveModel, n_eval_episodes)
|
||||||
|
|
||||||
|
|
||||||
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
|
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
|
||||||
|
Loading…
Reference in New Issue
Block a user