Making MultivariateNormal Policies work (and porting Normal to
Independent)
This commit is contained in:
parent
b1ed9fc2b8
commit
ab557a8856
@ -4,7 +4,7 @@ from enum import Enum
|
||||
import gym
|
||||
import torch as th
|
||||
from torch import nn
|
||||
from torch.distributions import Normal, MultivariateNormal
|
||||
from torch.distributions import Normal, Independent, MultivariateNormal
|
||||
from math import pi
|
||||
|
||||
from stable_baselines3.common.preprocessing import get_action_dim
|
||||
@ -37,6 +37,7 @@ class Strength(Enum):
|
||||
|
||||
|
||||
class ParametrizationType(Enum):
|
||||
NONE = 0
|
||||
CHOL = 1
|
||||
SPHERICAL_CHOL = 2
|
||||
# Not (yet?) implemented:
|
||||
@ -46,6 +47,7 @@ class ParametrizationType(Enum):
|
||||
|
||||
class EnforcePositiveType(Enum):
|
||||
# TODO: Allow custom params for softplus?
|
||||
NONE = (0, nn.Identity())
|
||||
SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
|
||||
ABS = (2, th.abs)
|
||||
RELU = (3, nn.ReLU(inplace=False))
|
||||
@ -89,14 +91,14 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
|
||||
# TODO: Implement
|
||||
continue
|
||||
if ps == Strength.NONE:
|
||||
yield (ps, cs, None, None)
|
||||
yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
|
||||
else:
|
||||
for ept in allowedEPTs:
|
||||
if cs == Strength.FULL:
|
||||
for pt in allowedPTs:
|
||||
yield (ps, cs, ept, pt)
|
||||
else:
|
||||
yield (ps, cs, ept, None)
|
||||
yield (ps, cs, ept, ProbSquashingType.NONE)
|
||||
|
||||
|
||||
def make_proba_distribution(
|
||||
@ -138,7 +140,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
:param action_dim: Dimension of the action space.
|
||||
"""
|
||||
|
||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
|
||||
def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
|
||||
super(UniversalGaussianDistribution, self).__init__()
|
||||
self.action_dim = action_dim
|
||||
self.par_strength = neural_strength
|
||||
@ -155,16 +157,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
if use_sde:
|
||||
raise Exception('SDE is not yet implemented')
|
||||
|
||||
assert (parameterization_type != ParametrizationType.NONE) == (
|
||||
cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
|
||||
|
||||
def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
|
||||
p = self.distribution
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
if isinstance(p, Independent):
|
||||
if p.stddev.shape != chol.shape:
|
||||
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||
np = th.distributions.Normal(mean, chol)
|
||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||
np = th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
||||
np = Independent(Normal(mean, chol), 1)
|
||||
elif isinstance(p, MultivariateNormal):
|
||||
np = MultivariateNormal(mean, scale_tril=chol)
|
||||
new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
|
||||
parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
|
||||
parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
|
||||
new.distribution = np
|
||||
|
||||
return new
|
||||
@ -202,9 +207,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
# TODO: latent_pi is for SDE, implement.
|
||||
|
||||
if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
|
||||
self.distribution = Normal(mean_actions, chol)
|
||||
self.distribution = Independent(Normal(mean_actions, chol), 1)
|
||||
elif self.cov_strength in [Strength.FULL]:
|
||||
self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
|
||||
self.distribution = MultivariateNormal(
|
||||
mean_actions, scale_tril=chol)
|
||||
if self.distribution == None:
|
||||
raise Exception('Unable to create torch distribution')
|
||||
return self
|
||||
@ -218,10 +224,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
|
||||
:return:
|
||||
"""
|
||||
log_prob = self.distribution.log_prob(actions)
|
||||
return sum_independent_dims(log_prob)
|
||||
return log_prob
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
return self.distribution.entropy()
|
||||
|
||||
def sample(self) -> th.Tensor:
|
||||
# Reparametrization trick to pass gradients
|
||||
|
@ -6,7 +6,7 @@ from ..distributions import UniversalGaussianDistribution, AnyDistribution
|
||||
|
||||
|
||||
def get_mean_and_chol(p: AnyDistribution, expand=False):
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||
if expand:
|
||||
return p.mean, th.diag_embed(p.stddev)
|
||||
else:
|
||||
@ -32,7 +32,7 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution):
|
||||
|
||||
|
||||
def get_cov(p: AnyDistribution):
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||
return th.diag_embed(p.variance)
|
||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||
return p.covariance_matrix
|
||||
@ -45,7 +45,7 @@ def get_cov(p: AnyDistribution):
|
||||
def has_diag_cov(p: AnyDistribution, numerical_check=True):
|
||||
if isinstance(p, SB3_Distribution):
|
||||
return has_diag_cov(p.distribution, numerical_check=numerical_check)
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
|
||||
return True
|
||||
if not numerical_check:
|
||||
return False
|
||||
@ -67,11 +67,15 @@ def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
|
||||
|
||||
def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
||||
if isinstance(orig_p, UniversalGaussianDistribution):
|
||||
return orig_p.new_list_like_me(mean, chol)
|
||||
return orig_p.new_dist_like_me(mean, chol)
|
||||
elif isinstance(orig_p, th.distributions.Normal):
|
||||
if orig_p.stddev.shape != chol.shape:
|
||||
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||
return th.distributions.Normal(mean, chol)
|
||||
elif isinstance(orig_p, th.distributions.Independent):
|
||||
if orig_p.stddev.shape != chol.shape:
|
||||
chol = th.diagonal(chol, dim1=1, dim2=2)
|
||||
return th.distributions.Independent(th.distributions.Normal(mean, chol), 1)
|
||||
elif isinstance(orig_p, th.distributions.MultivariateNormal):
|
||||
return th.distributions.MultivariateNormal(mean, scale_tril=chol)
|
||||
elif isinstance(orig_p, SB3_Distribution):
|
||||
@ -79,6 +83,10 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
|
||||
if isinstance(p, th.distributions.Normal):
|
||||
p_out = orig_p.__class__(orig_p.action_dim)
|
||||
p_out.distribution = th.distributions.Normal(mean, chol)
|
||||
elif isinstance(p, th.distributions.Independent):
|
||||
p_out = orig_p.__class__(orig_p.action_dim)
|
||||
p_out.distribution = th.distributions.Independent(
|
||||
th.distributions.Normal(mean, chol), 1)
|
||||
elif isinstance(p, th.distributions.MultivariateNormal):
|
||||
p_out = orig_p.__class__(orig_p.action_dim)
|
||||
p_out.distribution = th.distributions.MultivariateNormal(
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Any, Dict, Optional, Type, Union, NamedTuple
|
||||
from more_itertools import distribute
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
@ -10,6 +11,8 @@ from stable_baselines3.common.vec_env import VecEnv
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.utils import obs_as_tensor
|
||||
|
||||
from ..misc.distTools import get_mean_and_chol
|
||||
from ..distributions.distributions import Strength, UniversalGaussianDistribution
|
||||
|
||||
# TRL requires the origina mean and covariance from the policy when the datapoint was created.
|
||||
# GaussianRolloutBuffer extends the RolloutBuffer by these two fields
|
||||
@ -120,6 +123,12 @@ class GaussianRolloutCollectorAuxclass():
|
||||
def _setup_model(self) -> None:
|
||||
super()._setup_model()
|
||||
|
||||
cov_shape = self.action_space.shape
|
||||
|
||||
if isinstance(self.policy.action_dist, UniversalGaussianDistribution):
|
||||
if self.policy.action_dist.cov_strength == Strength.FULL:
|
||||
cov_shape = cov_shape + cov_shape
|
||||
|
||||
self.rollout_buffer = GaussianRolloutBuffer(
|
||||
self.n_steps,
|
||||
self.observation_space,
|
||||
@ -128,6 +137,7 @@ class GaussianRolloutCollectorAuxclass():
|
||||
gamma=self.gamma,
|
||||
gae_lambda=self.gae_lambda,
|
||||
n_envs=self.n_envs,
|
||||
cov_shape=cov_shape,
|
||||
)
|
||||
|
||||
def collect_rollouts(
|
||||
|
@ -1,2 +1,2 @@
|
||||
from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from ..trl_pg.trl_pg import TRL_PG
|
||||
from ..ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
|
||||
from .ppo import PPO
|
||||
|
@ -95,6 +95,7 @@ class ActorCriticPolicy(BasePolicy):
|
||||
normalize_images: bool = True,
|
||||
optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
|
||||
optimizer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
dist_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
|
||||
if optimizer_kwargs is None:
|
||||
@ -130,15 +131,16 @@ class ActorCriticPolicy(BasePolicy):
|
||||
|
||||
self.normalize_images = normalize_images
|
||||
self.log_std_init = log_std_init
|
||||
dist_kwargs = None
|
||||
# Keyword arguments for gSDE distribution
|
||||
if use_sde:
|
||||
dist_kwargs = {
|
||||
add_dist_kwargs = {
|
||||
"full_std": full_std,
|
||||
"squash_output": squash_output,
|
||||
"use_expln": use_expln,
|
||||
"learn_features": False,
|
||||
}
|
||||
for k in add_dist_kwargs:
|
||||
dist_kwargs[k] = add_dist_kwargs[k]
|
||||
|
||||
if sde_net_arch is not None:
|
||||
warnings.warn(
|
||||
|
@ -26,7 +26,7 @@ from ..projections.kl_projection_layer import KLProjectionLayer
|
||||
from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass
|
||||
|
||||
|
||||
class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
"""
|
||||
Differential Trust Region Layer (TRL) for Policy Gradient (PG)
|
||||
|
||||
@ -248,7 +248,12 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
q_dist = new_dist_like(
|
||||
p_dist, rollout_data.means, rollout_data.stds)
|
||||
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
||||
if isinstance(p_dist, th.distributions.Normal):
|
||||
# Normal uses a weird mapping from dimensions into batch_shape
|
||||
log_prob = proj_p.log_prob(actions).sum(dim=1)
|
||||
else:
|
||||
# UniversalGaussianDistribution instead uses Independent (or MultivariateNormal), which has a more rational dim mapping
|
||||
log_prob = proj_p.log_prob(actions)
|
||||
values = self.policy.value_net(latent_vf)
|
||||
entropy = proj_p.entropy()
|
||||
|
||||
@ -373,10 +378,10 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
||||
eval_env: Optional[GymEnv] = None,
|
||||
eval_freq: int = -1,
|
||||
n_eval_episodes: int = 5,
|
||||
tb_log_name: str = "TRL_PG",
|
||||
tb_log_name: str = "PPO",
|
||||
eval_log_path: Optional[str] = None,
|
||||
reset_num_timesteps: bool = True,
|
||||
) -> "TRL_PG":
|
||||
) -> "PPO":
|
||||
|
||||
return super().learn(
|
||||
total_timesteps=total_timesteps,
|
||||
|
11
replay.py
11
replay.py
@ -6,11 +6,10 @@ import os
|
||||
import time
|
||||
import datetime
|
||||
|
||||
from stable_baselines3 import SAC, PPO, A2C
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||
|
||||
from metastable_baselines.trl_pg import TRL_PG
|
||||
from metastable_baselines.ppo import PPO
|
||||
import columbus
|
||||
|
||||
|
||||
@ -26,14 +25,8 @@ def main(load_path, n_eval_episodes=0):
|
||||
use_sde = file_name.find('sde') != -1
|
||||
print(env_name, alg_name, alg_deriv, use_sde)
|
||||
env = gym.make(env_name)
|
||||
if alg_name == 'ppo':
|
||||
Model = PPO
|
||||
elif alg_name == 'trl' and alg_deriv == 'pg':
|
||||
Model = TRL_PG
|
||||
else:
|
||||
raise Exception('Algorithm not implemented for replay')
|
||||
|
||||
model = Model.load(load_path, env=env)
|
||||
model = PPO.load(load_path, env=env)
|
||||
|
||||
if n_eval_episodes:
|
||||
mean_reward, std_reward = evaluate_policy(
|
||||
|
47
test.py
47
test.py
@ -6,25 +6,28 @@ import os
|
||||
import time
|
||||
import datetime
|
||||
|
||||
from stable_baselines3 import
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy
|
||||
|
||||
from metastable_baselines.ppo import PPO
|
||||
# from metastable_baselines.sac import SAC
|
||||
from metastable_baselines.ppo.policies import MlpPolicy
|
||||
from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
|
||||
import columbus
|
||||
|
||||
#root_path = os.getcwd()
|
||||
from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
|
||||
|
||||
root_path = '.'
|
||||
|
||||
|
||||
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
||||
def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
|
||||
env = gym.make(env_name)
|
||||
use_sde = False
|
||||
ppo = PPO(
|
||||
MlpPolicy,
|
||||
env,
|
||||
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
||||
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||
verbose=0,
|
||||
tensorboard_log=root_path+"/logs_tb/" +
|
||||
env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
|
||||
@ -37,29 +40,29 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
|
||||
use_sde=use_sde, # False
|
||||
clip_range=0.2,
|
||||
)
|
||||
trl_frob = PPO(
|
||||
MlpPolicy,
|
||||
env,
|
||||
projection=FrobeniusProjectionLayer(),
|
||||
verbose=0,
|
||||
tensorboard_log=root_path+"/logs_tb/"+env_name +
|
||||
"/trl_frob"+(['', '_sde'][use_sde])+"/",
|
||||
learning_rate=3e-4,
|
||||
gamma=0.99,
|
||||
gae_lambda=0.95,
|
||||
normalize_advantage=True,
|
||||
ent_coef=0.03, # 0.1
|
||||
vf_coef=0.5,
|
||||
use_sde=use_sde,
|
||||
clip_range=2, # 0.2
|
||||
)
|
||||
# trl_frob = PPO(
|
||||
# MlpPolicy,
|
||||
# env,
|
||||
# projection=FrobeniusProjectionLayer(),
|
||||
# verbose=0,
|
||||
# tensorboard_log=root_path+"/logs_tb/"+env_name +
|
||||
# "/trl_frob"+(['', '_sde'][use_sde])+"/",
|
||||
# learning_rate=3e-4,
|
||||
# gamma=0.99,
|
||||
# gae_lambda=0.95,
|
||||
# normalize_advantage=True,
|
||||
# ent_coef=0.03, # 0.1
|
||||
# vf_coef=0.5,
|
||||
# use_sde=use_sde,
|
||||
# clip_range=2, # 0.2
|
||||
# )
|
||||
|
||||
print('PPO:')
|
||||
testModel(ppo, timesteps, showRes,
|
||||
saveModel, n_eval_episodes)
|
||||
print('TRL_frob:')
|
||||
testModel(trl_frob, timesteps, showRes,
|
||||
saveModel, n_eval_episodes)
|
||||
# print('TRL_frob:')
|
||||
# testModel(trl_frob, timesteps, showRes,
|
||||
# saveModel, n_eval_episodes)
|
||||
|
||||
|
||||
def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
|
||||
|
Loading…
Reference in New Issue
Block a user