Making MultivariateNormal Policies work (and porting Normal to Independent)
commit ab557a8856 (parent b1ed9fc2b8)
@@ -4,7 +4,7 @@ from enum import Enum
 import gym
 import torch as th
 from torch import nn
-from torch.distributions import Normal, MultivariateNormal
+from torch.distributions import Normal, Independent, MultivariateNormal
 from math import pi

 from stable_baselines3.common.preprocessing import get_action_dim
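A minimal sketch (not part of the commit) of why Independent is now imported: wrapping a diagonal Normal in Independent(..., 1) reinterprets the action dimension as part of the event, so downstream log-density calls return one value per action instead of one per component.

    import torch as th
    from torch.distributions import Normal, Independent

    mean, std = th.zeros(32, 4), th.ones(32, 4)  # batch of 32, 4 action dims
    n = Normal(mean, std)
    print(n.batch_shape, n.event_shape)   # torch.Size([32, 4]) torch.Size([])
    d = Independent(n, 1)
    print(d.batch_shape, d.event_shape)   # torch.Size([32]) torch.Size([4])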
@@ -37,6 +37,7 @@ class Strength(Enum):


 class ParametrizationType(Enum):
+    NONE = 0
     CHOL = 1
     SPHERICAL_CHOL = 2
     # Not (yet?) implemented:
@@ -46,6 +47,7 @@ class ParametrizationType(Enum):

 class EnforcePositiveType(Enum):
     # TODO: Allow custom params for softplus?
+    NONE = (0, nn.Identity())
     SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
     ABS = (2, th.abs)
     RELU = (3, nn.ReLU(inplace=False))
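Each EnforcePositiveType member pairs an index with a callable that maps an unconstrained tensor to a valid scale, and the new NONE member makes "no enforcement" expressible the same way. A hedged sketch of applying a member; the apply helper is hypothetical, not part of the commit:

    import torch as th
    from torch import nn
    from enum import Enum

    class EnforcePositiveType(Enum):
        NONE = (0, nn.Identity())
        SOFTPLUS = (1, nn.Softplus(beta=1, threshold=20))
        ABS = (2, th.abs)
        RELU = (3, nn.ReLU(inplace=False))

        def apply(self, x):
            # hypothetical convenience: the second tuple element is the activation
            return self.value[1](x)

    x = th.tensor([-0.5, 2.0])
    print(EnforcePositiveType.SOFTPLUS.apply(x))  # tensor([0.4741, 2.1269])
    print(EnforcePositiveType.NONE.apply(x))      # unchanged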
@@ -89,14 +91,14 @@ def get_legal_setups(allowedEPTs=None, allowedParStrength=None, allowedCovStreng
                 # TODO: Implement
                 continue
             if ps == Strength.NONE:
-                yield (ps, cs, None, None)
+                yield (ps, cs, EnforcePositiveType.NONE, ProbSquashingType.NONE)
             else:
                 for ept in allowedEPTs:
                     if cs == Strength.FULL:
                         for pt in allowedPTs:
                             yield (ps, cs, ept, pt)
                     else:
-                        yield (ps, cs, ept, None)
+                        yield (ps, cs, ept, ProbSquashingType.NONE)


 def make_proba_distribution(
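With explicit NONE members in place of None, every tuple the generator yields can be handled uniformly. A hypothetical caller (module path and default arguments assumed):

    from metastable_baselines.distributions import get_legal_setups

    for ps, cs, ept, pt in get_legal_setups():
        # every element is now an Enum member, so .name is always available
        print(ps.name, cs.name, ept.name, pt.name)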
@@ -138,7 +140,7 @@ class UniversalGaussianDistribution(SB3_Distribution):
     :param action_dim:  Dimension of the action space.
     """

-    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.CHOL, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
+    def __init__(self, action_dim: int, use_sde: bool = False, neural_strength: Strength = Strength.DIAG, cov_strength: Strength = Strength.DIAG, parameterization_type: ParametrizationType = ParametrizationType.NONE, enforce_positive_type: EnforcePositiveType = EnforcePositiveType.ABS, prob_squashing_type: ProbSquashingType = ProbSquashingType.NONE):
         super(UniversalGaussianDistribution, self).__init__()
         self.action_dim = action_dim
         self.par_strength = neural_strength
@@ -155,16 +157,19 @@ class UniversalGaussianDistribution(SB3_Distribution):
         if use_sde:
             raise Exception('SDE is not yet implemented')

+        assert (parameterization_type != ParametrizationType.NONE) == (
+            cov_strength == Strength.FULL), 'You should set an ParameterizationType iff the cov-strength is full'
+
     def new_dist_like_me(self, mean: th.Tensor, chol: th.Tensor):
         p = self.distribution
-        if isinstance(p, th.distributions.Normal):
+        if isinstance(p, Independent):
             if p.stddev.shape != chol.shape:
                 chol = th.diagonal(chol, dim1=1, dim2=2)
-            np = th.distributions.Normal(mean, chol)
-        elif isinstance(p, th.distributions.MultivariateNormal):
-            np = th.distributions.MultivariateNormal(mean, scale_tril=chol)
+            np = Independent(Normal(mean, chol), 1)
+        elif isinstance(p, MultivariateNormal):
+            np = MultivariateNormal(mean, scale_tril=chol)
         new = UniversalGaussianDistribution(self.action_dim, neural_strength=self.par_strength, cov_strength=self.cov_strength,
-                                            parameterization_type=self.par_strength, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
+                                            parameterization_type=self.par_type, enforce_positive_type=self.enforce_positive_type, prob_squashing_type=self.prob_squashing_type)
         new.distribution = np

         return new
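The shape guard above converts a full Cholesky factor back into a per-dimension std vector when the stored distribution is diagonal. A small sketch of what th.diagonal does here (illustration only):

    import torch as th

    chol = th.diag_embed(th.tensor([[1.0, 2.0, 3.0]]))  # (1, 3, 3) lower triangular
    std = th.diagonal(chol, dim1=1, dim2=2)             # (1, 3) std vector again
    print(std)  # tensor([[1., 2., 3.]])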
@@ -202,9 +207,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
         # TODO: latent_pi is for SDE, implement.

         if self.cov_strength in [Strength.NONE, Strength.SCALAR, Strength.DIAG]:
-            self.distribution = Normal(mean_actions, chol)
+            self.distribution = Independent(Normal(mean_actions, chol), 1)
         elif self.cov_strength in [Strength.FULL]:
-            self.distribution = MultivariateNormal(mean_actions, cholesky=chol)
+            self.distribution = MultivariateNormal(
+                mean_actions, scale_tril=chol)
         if self.distribution == None:
             raise Exception('Unable to create torch distribution')
         return self
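The keyword change is a real bug fix: torch's MultivariateNormal accepts covariance_matrix, precision_matrix, or scale_tril, but has no cholesky argument. scale_tril takes the lower-triangular Cholesky factor L, with covariance L @ L.T, as this sketch (illustration only) shows:

    import torch as th
    from torch.distributions import MultivariateNormal

    L = th.tensor([[1.0, 0.0],
                   [0.5, 2.0]])
    mvn = MultivariateNormal(th.zeros(2), scale_tril=L)
    print(mvn.covariance_matrix)  # equals L @ L.T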
@@ -218,10 +224,10 @@ class UniversalGaussianDistribution(SB3_Distribution):
         :return:
         """
         log_prob = self.distribution.log_prob(actions)
-        return sum_independent_dims(log_prob)
+        return log_prob

     def entropy(self) -> th.Tensor:
-        return sum_independent_dims(self.distribution.entropy())
+        return self.distribution.entropy()

     def sample(self) -> th.Tensor:
         # Reparametrization trick to pass gradients
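Dropping sum_independent_dims is safe because Independent performs exactly that reduction internally; a quick sanity sketch (not part of the commit):

    import torch as th
    from torch.distributions import Normal, Independent

    mean, std = th.zeros(8, 3), th.ones(8, 3)
    a = th.randn(8, 3)
    base, ind = Normal(mean, std), Independent(Normal(mean, std), 1)
    assert th.allclose(ind.log_prob(a), base.log_prob(a).sum(-1))
    assert th.allclose(ind.entropy(), base.entropy().sum(-1))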
@@ -6,7 +6,7 @@ from ..distributions import UniversalGaussianDistribution, AnyDistribution


 def get_mean_and_chol(p: AnyDistribution, expand=False):
-    if isinstance(p, th.distributions.Normal):
+    if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
         if expand:
             return p.mean, th.diag_embed(p.stddev)
         else:
@@ -32,7 +32,7 @@ def get_mean_and_sqrt(p: UniversalGaussianDistribution):


 def get_cov(p: AnyDistribution):
-    if isinstance(p, th.distributions.Normal):
+    if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
         return th.diag_embed(p.variance)
     elif isinstance(p, th.distributions.MultivariateNormal):
         return p.covariance_matrix
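Both branches of get_cov now agree for a diagonal Gaussian, whichever representation it arrives in; a quick check (illustration only):

    import torch as th
    from torch.distributions import Normal, Independent, MultivariateNormal

    std = th.tensor([0.5, 2.0])
    p_diag = Independent(Normal(th.zeros(2), std), 1)
    p_full = MultivariateNormal(th.zeros(2), scale_tril=th.diag(std))
    assert th.allclose(th.diag_embed(p_diag.variance), p_full.covariance_matrix)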
@@ -45,7 +45,7 @@ def get_cov(p: AnyDistribution):
 def has_diag_cov(p: AnyDistribution, numerical_check=True):
     if isinstance(p, SB3_Distribution):
         return has_diag_cov(p.distribution, numerical_check=numerical_check)
-    if isinstance(p, th.distributions.Normal):
+    if isinstance(p, th.distributions.Normal) or isinstance(p, th.distributions.Independent):
         return True
     if not numerical_check:
         return False
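For distributions that fall through these isinstance shortcuts, the numerical check presumably compares off-diagonal mass against zero; a hypothetical implementation of that idea (not the repository's code):

    import torch as th

    def _is_diag(cov: th.Tensor, atol: float = 1e-6) -> bool:
        # zero out the diagonal and test whether anything of size remains
        off_diag = cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1))
        return bool(th.all(off_diag.abs() <= atol))

    print(_is_diag(th.diag(th.tensor([1.0, 4.0]))))  # True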
@@ -67,11 +67,15 @@ def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):

 def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
     if isinstance(orig_p, UniversalGaussianDistribution):
-        return orig_p.new_list_like_me(mean, chol)
+        return orig_p.new_dist_like_me(mean, chol)
     elif isinstance(orig_p, th.distributions.Normal):
         if orig_p.stddev.shape != chol.shape:
             chol = th.diagonal(chol, dim1=1, dim2=2)
         return th.distributions.Normal(mean, chol)
+    elif isinstance(orig_p, th.distributions.Independent):
+        if orig_p.stddev.shape != chol.shape:
+            chol = th.diagonal(chol, dim1=1, dim2=2)
+        return th.distributions.Independent(th.distributions.Normal(mean, chol), 1)
     elif isinstance(orig_p, th.distributions.MultivariateNormal):
         return th.distributions.MultivariateNormal(mean, scale_tril=chol)
     elif isinstance(orig_p, SB3_Distribution):
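The new Independent branch (and the new_list_like_me typo fix) keep new_dist_like's contract: rebuild a distribution of the same family from stored mean/Cholesky statistics, as TRL does for its old-policy q-dist. A usage sketch with made-up tensors (import path assumed):

    import torch as th
    from metastable_baselines.misc.distTools import new_dist_like

    p = th.distributions.Independent(
        th.distributions.Normal(th.zeros(4, 2), th.ones(4, 2)), 1)
    q = new_dist_like(p, th.zeros(4, 2), 0.5 * th.ones(4, 2))
    print(type(q))  # Independent again, carrying the stored parameters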
@@ -79,6 +83,10 @@ def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
         if isinstance(p, th.distributions.Normal):
             p_out = orig_p.__class__(orig_p.action_dim)
             p_out.distribution = th.distributions.Normal(mean, chol)
+        elif isinstance(p, th.distributions.Independent):
+            p_out = orig_p.__class__(orig_p.action_dim)
+            p_out.distribution = th.distributions.Independent(
+                th.distributions.Normal(mean, chol), 1)
         elif isinstance(p, th.distributions.MultivariateNormal):
             p_out = orig_p.__class__(orig_p.action_dim)
             p_out.distribution = th.distributions.MultivariateNormal(
@@ -1,4 +1,5 @@
 from typing import Any, Dict, Optional, Type, Union, NamedTuple
+from more_itertools import distribute

 import numpy as np
 import torch as th
@@ -10,6 +11,8 @@ from stable_baselines3.common.vec_env import VecEnv
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.utils import obs_as_tensor

+from ..misc.distTools import get_mean_and_chol
+from ..distributions.distributions import Strength, UniversalGaussianDistribution

 # TRL requires the origina mean and covariance from the policy when the datapoint was created.
 # GaussianRolloutBuffer extends the RolloutBuffer by these two fields
@@ -120,6 +123,12 @@ class GaussianRolloutCollectorAuxclass():
     def _setup_model(self) -> None:
         super()._setup_model()

+        cov_shape = self.action_space.shape
+
+        if isinstance(self.policy.action_dist, UniversalGaussianDistribution):
+            if self.policy.action_dist.cov_strength == Strength.FULL:
+                cov_shape = cov_shape + cov_shape
+
         self.rollout_buffer = GaussianRolloutBuffer(
             self.n_steps,
             self.observation_space,
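action_space.shape is a tuple, so cov_shape + cov_shape is tuple concatenation: the buffer stores a (d,) std vector for diagonal covariances but a full (d, d) matrix in the FULL case. In plain Python:

    cov_shape = (3,)
    print(cov_shape + cov_shape)  # (3, 3)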
@@ -128,6 +137,7 @@ class GaussianRolloutCollectorAuxclass():
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            cov_shape=cov_shape,
         )

     def collect_rollouts(
@@ -1,2 +1,2 @@
-from ..trl_pg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
-from ..trl_pg.trl_pg import TRL_PG
+from ..ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
+from .ppo import PPO
@@ -95,6 +95,7 @@ class ActorCriticPolicy(BasePolicy):
         normalize_images: bool = True,
         optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
+        dist_kwargs: Optional[Dict[str, Any]] = None,
     ):

         if optimizer_kwargs is None:
@@ -130,15 +131,16 @@ class ActorCriticPolicy(BasePolicy):

         self.normalize_images = normalize_images
         self.log_std_init = log_std_init
-        dist_kwargs = None
         # Keyword arguments for gSDE distribution
         if use_sde:
-            dist_kwargs = {
+            add_dist_kwargs = {
                 "full_std": full_std,
                 "squash_output": squash_output,
                 "use_expln": use_expln,
                 "learn_features": False,
             }
+            for k in add_dist_kwargs:
+                dist_kwargs[k] = add_dist_kwargs[k]

         if sde_net_arch is not None:
             warnings.warn(
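The gSDE kwargs no longer overwrite a caller-supplied dist_kwargs; they are merged into it. The loop is equivalent to dict.update, sketched here with made-up values:

    dist_kwargs = {'neural_strength': 'DIAG'}                  # hypothetical caller value
    add_dist_kwargs = {'full_std': True, 'squash_output': False}
    dist_kwargs.update(add_dist_kwargs)                        # same effect as the loop
    print(dist_kwargs)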
@@ -26,7 +26,7 @@ from ..projections.kl_projection_layer import KLProjectionLayer
 from ..misc.rollout_buffer import GaussianRolloutCollectorAuxclass


-class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
+class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
     """
     Differential Trust Region Layer (TRL) for Policy Gradient (PG)

@@ -248,7 +248,12 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 q_dist = new_dist_like(
                     p_dist, rollout_data.means, rollout_data.stds)
                 proj_p = self.projection(p_dist, q_dist, self._global_steps)
+                if isinstance(p_dist, th.distributions.Normal):
+                    # Normal uses a weird mapping from dimensions into batch_shape
                     log_prob = proj_p.log_prob(actions).sum(dim=1)
+                else:
+                    # UniversalGaussianDistribution instead uses Independent (or MultivariateNormal), which has a more rational dim mapping
+                    log_prob = proj_p.log_prob(actions)
                 values = self.policy.value_net(latent_vf)
                 entropy = proj_p.entropy()

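The new else-branch relies on Independent and MultivariateNormal already reducing over the event dimension, so log_prob comes back per-sample with no .sum(dim=1) needed; a shape sketch (illustration only):

    import torch as th
    from torch.distributions import MultivariateNormal

    mvn = MultivariateNormal(th.zeros(32, 4), scale_tril=th.eye(4).expand(32, 4, 4))
    print(mvn.log_prob(th.zeros(32, 4)).shape)  # torch.Size([32])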
@@ -373,10 +378,10 @@ class TRL_PG(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         eval_env: Optional[GymEnv] = None,
         eval_freq: int = -1,
         n_eval_episodes: int = 5,
-        tb_log_name: str = "TRL_PG",
+        tb_log_name: str = "PPO",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
-    ) -> "TRL_PG":
+    ) -> "PPO":

         return super().learn(
             total_timesteps=total_timesteps,

replay.py (11 lines changed)
@@ -6,11 +6,10 @@ import os
 import time
 import datetime

-from stable_baselines3 import SAC, PPO, A2C
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

-from metastable_baselines.trl_pg import TRL_PG
+from metastable_baselines.ppo import PPO
 import columbus

@@ -26,14 +25,8 @@ def main(load_path, n_eval_episodes=0):
     use_sde = file_name.find('sde') != -1
     print(env_name, alg_name, alg_deriv, use_sde)
     env = gym.make(env_name)
-    if alg_name == 'ppo':
-        Model = PPO
-    elif alg_name == 'trl' and alg_deriv == 'pg':
-        Model = TRL_PG
-    else:
-        raise Exception('Algorithm not implemented for replay')

-    model = Model.load(load_path, env=env)
+    model = PPO.load(load_path, env=env)

     if n_eval_episodes:
         mean_reward, std_reward = evaluate_policy(

test.py (49 lines changed)
@@ -6,25 +6,28 @@ import os
 import time
 import datetime

-from stable_baselines3 import
 from stable_baselines3.common.evaluation import evaluate_policy
 from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, MultiInputActorCriticPolicy

 from metastable_baselines.ppo import PPO
+# from metastable_baselines.sac import SAC
 from metastable_baselines.ppo.policies import MlpPolicy
 from metastable_baselines.projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer
 import columbus

-#root_path = os.getcwd()
+from metastable_baselines.distributions import Strength, ParametrizationType, EnforcePositiveType, ProbSquashingType
+
 root_path = '.'


-def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
+def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=2_000_000, showRes=True, saveModel=True, n_eval_episodes=0):
     env = gym.make(env_name)
     use_sde = False
     ppo = PPO(
         MlpPolicy,
         env,
+        policy_kwargs={'dist_kwargs': {'neural_strength': Strength.DIAG, 'cov_strength': Strength.DIAG, 'parameterization_type':
+                       ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
         verbose=0,
         tensorboard_log=root_path+"/logs_tb/" +
         env_name+"/ppo"+(['', '_sde'][use_sde])+"/",
@@ -37,29 +40,29 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=10_000_000, showRes=Tr
         use_sde=use_sde,  # False
         clip_range=0.2,
     )
-    trl_frob = PPO(
-        MlpPolicy,
-        env,
-        projection=FrobeniusProjectionLayer(),
-        verbose=0,
-        tensorboard_log=root_path+"/logs_tb/"+env_name +
-        "/trl_frob"+(['', '_sde'][use_sde])+"/",
-        learning_rate=3e-4,
-        gamma=0.99,
-        gae_lambda=0.95,
-        normalize_advantage=True,
-        ent_coef=0.03,  # 0.1
-        vf_coef=0.5,
-        use_sde=use_sde,
-        clip_range=2,  # 0.2
-    )
+    # trl_frob = PPO(
+    #    MlpPolicy,
+    #    env,
+    #    projection=FrobeniusProjectionLayer(),
+    #    verbose=0,
+    #    tensorboard_log=root_path+"/logs_tb/"+env_name +
+    #    "/trl_frob"+(['', '_sde'][use_sde])+"/",
+    #    learning_rate=3e-4,
+    #    gamma=0.99,
+    #    gae_lambda=0.95,
+    #    normalize_advantage=True,
+    #    ent_coef=0.03,  # 0.1
+    #    vf_coef=0.5,
+    #    use_sde=use_sde,
+    #    clip_range=2,  # 0.2
+    # )

     print('PPO:')
     testModel(ppo, timesteps, showRes,
               saveModel, n_eval_episodes)
-    print('TRL_frob:')
-    testModel(trl_frob, timesteps, showRes,
-              saveModel, n_eval_episodes)
+    # print('TRL_frob:')
+    # testModel(trl_frob, timesteps, showRes,
+    #          saveModel, n_eval_episodes)


 def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=16):
@@ -95,7 +98,7 @@ def testModel(model, timesteps, showRes=False, saveModel=False, n_eval_episodes=
             env.render()
             episode_reward += reward
             if done:
-                #print("Reward:", episode_reward)
+                # print("Reward:", episode_reward)
                 episode_reward = 0.0
                 obs = env.reset()
     env.reset()