Making MultivariateNormal Policies work (and porting Normal to

Independent)
This commit is contained in:
Dominik Moritz Roth 2022-07-15 15:03:51 +02:00
parent b1ed9fc2b8
commit ab557a8856
8 changed files with 84 additions and 57 deletions

View File

@ -4,7 +4,7 @@ from enum import Enum
import gym
import torch as th
from torch import nn
from torch.distributions import Normal, MultivariateNormal
from torch.distributions import Normal, Independent, MultivariateNormal
from math import pi
from stable_baselines3.common.preprocessing import get_action_dim
@ -37,6 +37,7 @@ class Strength(Enum):
class ParametrizationType(Enum):
NONE = 0
CHOL = 1
SPHERICAL_CHOL = 2
# Not (yet?) implemented:
@ -46,6 +47,7 @@ class ParametrizationType(Enum):
class EnforcePositiveType(Enum):
# TODO: Allow custom params for softplus?
NONE = (0, nn.Identity())
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
ABS = (2, th.abs)
RELU = (3, nn.ReLU(inplace=False))
@ -89,14 +91,14 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
# TODO: Implement
continue
if ps == Strength.NONE:
yield (ps, cs, None, None)
yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
else:
for ept in allowedEPTs:
if cs == Strength.FULL:
for pt in allowedPTs:
yield (ps, cs, ept, pt)
else:
yield (ps, cs, ept, None)
yield (ps, cs, ept, ProbSquashingType.NONE)
def make_proba_distribution(
@ -138,7 +140,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
:param action_dim: Dimension of the action space.
"""
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
super(UniversalGaussianDistribution, self).__init__()
self.action_dim = action_dim
self.par_strength = neural_strength
@ -155,16 +157,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
if use_sde:
raise Exception('SDE is not yet implemented')
assert (parameterization_type != ParametrizationType.NONE) == (
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
p = self.distribution
if isinstance(p, th.distributions.Normal):
if isinstance(p, Independent):
if p.stddev.shape != chol.shape:
chol = th.diagonal(chol, dim1=1, dim2=2)
np = th.distributions.Normal(mean, chol)
elif isinstance(p, th.distributions.MultivariateNormal):
np = th.distributions.MultivariateNormal(mean, scale_tril=chol)
np = Independent(Normal(mean, chol), 1)
elif isinstance(p, MultivariateNormal):
np = MultivariateNormal(mean, scale_tril=chol)
new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
new.distribution = np
return new
@ -202,9 +207,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
# TODO: latent_pi is for SDE, implement.
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
self.distribution = Normal(mean_actions, chol)
self.distribution = Independent(Normal(mean_actions, chol), 1)
elif self.cov_strength in [Strength.FULL]:
self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
self.distribution = MultivariateNormal(
mean_actions, scale_tril=chol)
if self.distribution == None:
raise Exception('Unable to create torch distribution')
return self
@ -218,10 +224,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
:return:
"""
log_prob = self.distribution.log_prob(actions)
return sum_independent_dims(log_prob)
return log_prob
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
return self.distribution.entropy()
def sample(self) -> th.Tensor:
# Reparametrization trick to pass gradients

View File

@ -6,7 +6,7 @@ from ..distributions import UniversalGaussianDistribution, AnyDistribution
def get_mean_and_chol(p: AnyDistribution, expand=False):
if isinstance(p, th.distributions.Normal):
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
if expand:
return p.mean, th.diag_embed(p.stddev)
else:
@ -32,7 +32,7 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution):
def get_cov(p: AnyDistribution):
if isinstance(p, th.distributions.Normal):
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
return th.diag_embed(p.variance)
elif isinstance(p, th.distributions.MultivariateNormal):
return p.covariance_matrix
@ -45,7 +45,7 @@ def get_cov(p: AnyDistribution):
def has_diag_cov(p: AnyDistribution, numerical_check=True):
if isinstance(p, SB3_Distribution):
return has_diag_cov(p.distribution, numerical_check=numerical_check)
if isinstance(p, th.distributions.Normal):
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
return True
if not numerical_check:
return False
@ -67,11 +67,15 @@ def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
if isinstance(orig_p, UniversalGaussianDistribution):
return orig_p.new_list_like_me(mean, chol)
return orig_p.new_dist_like_me(mean, chol)
elif isinstance(orig_p, th.distributions.Normal):
if orig_p.stddev.shape != chol.shape:
chol = th.diagonal(chol, dim1=1, dim2=2)
return th.distributions.Normal(mean, chol)
elif isinstance(orig_p, th.distributions.Independent):
if orig_p.stddev.shape != chol.shape:
chol = th.diagonal(chol, dim1=1, dim2=2)
return th.distributions.Independent(th.distributions.Normal(mean, chol), 1)
elif isinstance(orig_p, th.distributions.MultivariateNormal):
return th.distributions.MultivariateNormal(mean, scale_tril=chol)
elif isinstance(orig_p, SB3_Distribution):
@ -79,6 +83,10 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
if isinstance(p, th.distributions.Normal):
p_out = orig_p.__class__(orig_p.action_dim)
p_out.distribution = th.distributions.Normal(mean, chol)
elif isinstance(p, th.distributions.Independent):
p_out = orig_p.__class__(orig_p.action_dim)
p_out.distribution = th.distributions.Independent(
th.distributions.Normal(mean, chol), 1)
elif isinstance(p, th.distributions.MultivariateNormal):
p_out = orig_p.__class__(orig_p.action_dim)
p_out.distribution = th.distributions.MultivariateNormal(

View File

@ -1,4 +1,5 @@
from typing import Any, Dict, Optional, Type, Union, NamedTuple
from more_itertools import distribute
import numpy as np
import torch as th
@ -10,6 +11,8 @@ from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.utils import obs_as_tensor
from ..misc.distTools import get_mean_and_chol
from ..distributions.distributions import Strength, UniversalGaussianDistribution
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
@ -120,6 +123,12 @@ class GaussianRolloutCollectorAuxclass():
def _setup_model(self) -> None:
super()._setup_model()
cov_shape = self.action_space.shape
if isinstance(self.policy.action_dist, UniversalGaussianDistribution):
if self.policy.action_dist.cov_strength == Strength.FULL:
cov_shape = cov_shape + cov_shape
self.rollout_buffer = GaussianRolloutBuffer(
self.n_steps,
self.observation_space,
@ -128,6 +137,7 @@ class GaussianRolloutCollectorAuxclass():
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
cov_shape=cov_shape,
)
def collect_rollouts(

View File

@ -1,2 +1,2 @@
from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from ..trl_pg.trl_pg import TRL_PG
from ..ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
from .ppo import PPO

View File

@ -95,6 +95,7 @@ class ActorCriticPolicy(BasePolicy):
normalize_images: bool = True,
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
dist_kwargs: Optional[Dict[str, Any]] = None,
):
if optimizer_kwargs is None:
@ -130,15 +131,16 @@ class ActorCriticPolicy(BasePolicy):
self.normalize_images = normalize_images
self.log_std_init = log_std_init
dist_kwargs = None
# Keyword arguments for gSDE distribution
if use_sde:
dist_kwargs = {
add_dist_kwargs = {
"full_std": full_std,
"squash_output": squash_output,
"use_expln": use_expln,
"learn_features": False,
}
for k in add_dist_kwargs:
dist_kwargs[k] = add_dist_kwargs[k]
if sde_net_arch is not None:
warnings.warn(

View File

@ -26,7 +26,7 @@ from ..projections.kl_projection_layer import KLProjectionLayer
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
"""
Differential Trust Region Layer (TRL) for Policy Gradient (PG)
@ -248,7 +248,12 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
q_dist = new_dist_like(
p_dist, rollout_data.means, rollout_data.stds)
proj_p = self.projection(p_dist, q_dist, self._global_steps)
if isinstance(p_dist, th.distributions.Normal):
# Normal uses a weird mapping from dimensions into batch_shape
log_prob = proj_p.log_prob(actions).sum(dim=1)
else:
# UniversalGaussianDistribution instead uses Independent (or MultivariateNormal), which has a more rational dim mapping
log_prob = proj_p.log_prob(actions)
values = self.policy.value_net(latent_vf)
entropy = proj_p.entropy()
@ -373,10 +378,10 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
eval_env: Optional[GymEnv] = None,
eval_freq: int = -1,
n_eval_episodes: int = 5,
tb_log_name: str = "TRL_PG",
tb_log_name: str = "PPO",
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> "TRL_PG":
) -> "PPO":
return super().learn(
total_timesteps=total_timesteps,

View File

@ -6,11 +6,10 @@ import os
import time
import datetime
from stable_baselines3 import SAC, PPO, A2C
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
from metastable_baselines.trl_pg import TRL_PG
from metastable_baselines.ppo import PPO
import columbus
@ -26,14 +25,8 @@ def main(load_path, n_eval_episodes=0):
use_sde = file_name.find('sde') != -1
print(env_name, alg_name, alg_deriv, use_sde)
env = gym.make(env_name)
if alg_name == 'ppo':
Model = PPO
elif alg_name == 'trl' and alg_deriv == 'pg':
Model = TRL_PG
else:
raise Exception('Algorithm not implemented for replay')
model = Model.load(load_path, env=env)
model = PPO.load(load_path, env=env)
if n_eval_episodes:
mean_reward, std_reward = evaluate_policy(

47
test.py
View File

@ -6,25 +6,28 @@ import os
import time
import datetime
from stable_baselines3 import
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
from metastable_baselines.ppo import PPO
# from metastable_baselines.sac import SAC
from metastable_baselines.ppo.policies import MlpPolicy
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
import columbus
#root_path = os.getcwd()
from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
root_path = '.'
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
env = gym.make(env_name)
use_sde = False
ppo = PPO(
MlpPolicy,
env,
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
verbose=0,
tensorboard_log=root_path+"/logs_tb/" +
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
@ -37,29 +40,29 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
use_sde=use_sde, # False
clip_range=0.2,
)
trl_frob = PPO(
MlpPolicy,
env,
projection=FrobeniusProjectionLayer(),
verbose=0,
tensorboard_log=root_path+"/logs_tb/"+env_name +
"/trl_frob"+(['', '_sde'][use_sde])+"/",
learning_rate=3e-4,
gamma=0.99,
gae_lambda=0.95,
normalize_advantage=True,
ent_coef=0.03, # 0.1
vf_coef=0.5,
use_sde=use_sde,
clip_range=2, # 0.2
)
# trl_frob = PPO(
# MlpPolicy,
# env,
# projection=FrobeniusProjectionLayer(),
# verbose=0,
# tensorboard_log=root_path+"/logs_tb/"+env_name +
# "/trl_frob"+(['', '_sde'][use_sde])+"/",
# learning_rate=3e-4,
# gamma=0.99,
# gae_lambda=0.95,
# normalize_advantage=True,
# ent_coef=0.03, # 0.1
# vf_coef=0.5,
# use_sde=use_sde,
# clip_range=2, # 0.2
# )
print('PPO:')
testModel(ppo, timesteps, showRes,
saveModel, n_eval_episodes)
print('TRL_frob:')
testModel(trl_frob, timesteps, showRes,
saveModel, n_eval_episodes)
# print('TRL_frob:')
# testModel(trl_frob, timesteps, showRes,
# saveModel, n_eval_episodes)
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):