Smashing bugs: don't confuse chol with chol_net

Dominik Moritz Roth 2022-07-19 10:07:20 +02:00
parent b7de99b1fc
commit 5f32435751


@@ -1,6 +1,8 @@
 import warnings
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
+import math
 import gym
 import torch as th
 from torch import nn
@@ -18,6 +20,8 @@ from stable_baselines3.common.torch_layers import (
 )
 from stable_baselines3.common.type_aliases import Schedule
+from ..distributions import UniversalGaussianDistribution
 # CAP the standard deviation of the actor
 LOG_STD_MAX = 2
 LOG_STD_MIN = -20
@@ -64,6 +68,7 @@ class Actor(BasePolicy):
         use_expln: bool = False,
         clip_mean: float = 2.0,
         normalize_images: bool = True,
+        dist_kwargs={},
     ):
         super().__init__(
             observation_space,
@@ -86,7 +91,8 @@ class Actor(BasePolicy):
         self.clip_mean = clip_mean
         if sde_net_arch is not None:
-            warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
+            warnings.warn(
+                "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)

         action_dim = get_action_dim(self.action_space)
         latent_pi_net = create_mlp(features_dim, -1, net_arch, activation_fn)
@@ -94,20 +100,29 @@ class Actor(BasePolicy):
         last_layer_dim = net_arch[-1] if len(net_arch) > 0 else features_dim

         if self.use_sde:
+            # TODO: Port to UGD
             self.action_dist = StateDependentNoiseDistribution(
                 action_dim, full_std=full_std, use_expln=use_expln, learn_features=True, squash_output=True
             )
-            self.mu, self.log_std = self.action_dist.proba_distribution_net(
+            self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
                 latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, log_std_init=log_std_init
             )
             # Avoid numerical issues by limiting the mean of the Gaussian
             # to be in [-clip_mean, clip_mean]
             if clip_mean > 0.0:
-                self.mu = nn.Sequential(self.mu, nn.Hardtanh(min_val=-clip_mean, max_val=clip_mean))
+                self.mu = nn.Sequential(self.mu, nn.Hardtanh(
+                    min_val=-clip_mean, max_val=clip_mean))
         else:
-            self.action_dist = SquashedDiagGaussianDistribution(action_dim)
-            self.mu = nn.Linear(last_layer_dim, action_dim)
-            self.log_std = nn.Linear(last_layer_dim, action_dim)
+            self.action_dist = UniversalGaussianDistribution(
+                action_dim, **dist_kwargs)
+            self.mu_net, self.chol_net = self.action_dist.proba_distribution_net(
+                latent_dim=last_layer_dim, latent_sde_dim=last_layer_dim, std_init=math.exp(
+                    self.log_std_init)
+            )
+            #self.action_dist = SquashedDiagGaussianDistribution(action_dim)
+            #self.mu = nn.Linear(last_layer_dim, action_dim)
+            #self.log_std = nn.Linear(last_layer_dim, action_dim)

     def _get_constructor_parameters(self) -> Dict[str, Any]:
         data = super()._get_constructor_parameters()
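The naming convention behind the commit title: attributes with a `_net` suffix are `nn.Module` heads, while the bare names hold the tensors those heads produce. A minimal sketch of the distinction (the dimensions and `Linear` heads here are illustrative assumptions, not taken from the diff):

# Illustrative sketch only: *_net names are modules, bare names are tensors.
import torch as th
from torch import nn

latent_dim, action_dim = 64, 6
mu_net = nn.Linear(latent_dim, action_dim)    # module: maps latent -> mean
chol_net = nn.Linear(latent_dim, action_dim)  # module: maps latent -> Cholesky params

latent_pi = th.zeros(1, latent_dim)
mu = mu_net(latent_pi)      # tensor: mean of the Gaussian
chol = chol_net(latent_pi)  # tensor: scale / Cholesky parameters, not the module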
@@ -138,8 +153,9 @@ class Actor(BasePolicy):
         :return:
         """
         msg = "get_std() is only available when using gSDE"
-        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
-        return self.action_dist.get_std(self.log_std)
+        assert isinstance(self.action_dist,
+                          StateDependentNoiseDistribution), msg
+        return self.chol

     def reset_noise(self, batch_size: int = 1) -> None:
         """
@@ -148,8 +164,9 @@ class Actor(BasePolicy):
         :param batch_size:
         """
         msg = "reset_noise() is only available when using gSDE"
-        assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg
-        self.action_dist.sample_weights(self.log_std, batch_size=batch_size)
+        assert isinstance(self.action_dist,
+                          StateDependentNoiseDistribution), msg
+        self.action_dist.sample_weights(self.chol, batch_size=batch_size)

     def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
@@ -161,25 +178,25 @@ class Actor(BasePolicy):
         """
         features = self.extract_features(obs)
         latent_pi = self.latent_pi(features)
-        mean_actions = self.mu(latent_pi)
+        mean_actions = self.mu_net(latent_pi)

         if self.use_sde:
-            return mean_actions, self.log_std, dict(latent_sde=latent_pi)
+            return mean_actions, self.chol, dict(latent_sde=latent_pi)
         # Unstructured exploration (Original implementation)
-        log_std = self.log_std(latent_pi)
+        chol = self.chol_net(latent_pi)
         # Original Implementation to cap the standard deviation
-        log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
-        return mean_actions, log_std, {}
+        self.chol = th.clamp(chol, LOG_STD_MIN, LOG_STD_MAX)
+        return mean_actions, self.chol, {}

     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
-        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
+        mean_actions, chol, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
-        return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
+        return self.action_dist.actions_from_params(mean_actions, chol, deterministic=deterministic, **kwargs)

     def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
-        mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
+        mean_actions, chol, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
-        return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
+        return self.action_dist.log_prob_from_params(mean_actions, chol, **kwargs)

     def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)
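For orientation, a hedged usage sketch of the actor's sampling path after this change. It is not part of the commit: `actor` (an instance of the Actor above) and `obs_dim` are assumptions made for illustration; the method names and argument order come from the diff itself.

# Hedged sketch: how the SAC training code consumes (mean_actions, chol).
import torch as th

obs = th.zeros(1, obs_dim)  # placeholder observation batch; obs_dim is assumed
mean_actions, chol, extra = actor.get_action_dist_params(obs)
action = actor.action_dist.actions_from_params(
    mean_actions, chol, deterministic=False, **extra)
actions, log_prob = actor.action_log_prob(obs)  # used by the SAC update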
@@ -236,6 +253,7 @@ class SACPolicy(BasePolicy):
         optimizer_kwargs: Optional[Dict[str, Any]] = None,
         n_critics: int = 2,
         share_features_extractor: bool = True,
+        dist_kwargs={},
     ):
         super().__init__(
             observation_space,
@@ -267,7 +285,8 @@ class SACPolicy(BasePolicy):
         self.actor_kwargs = self.net_args.copy()

         if sde_net_arch is not None:
-            warnings.warn("sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)
+            warnings.warn(
+                "sde_net_arch is deprecated and will be removed in SB3 v2.4.0.", DeprecationWarning)

         sde_kwargs = {
             "use_sde": use_sde,
@@ -275,6 +294,7 @@ class SACPolicy(BasePolicy):
             "use_expln": use_expln,
             "clip_mean": clip_mean,
         }
         self.actor_kwargs.update(sde_kwargs)
         self.critic_kwargs = self.net_args.copy()
         self.critic_kwargs.update(
@@ -289,17 +309,22 @@ class SACPolicy(BasePolicy):
         self.critic, self.critic_target = None, None
         self.share_features_extractor = share_features_extractor
+        self.dist_kwargs = dist_kwargs

         self._build(lr_schedule)

     def _build(self, lr_schedule: Schedule) -> None:
         self.actor = self.make_actor()
-        self.actor.optimizer = self.optimizer_class(self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.actor.optimizer = self.optimizer_class(
+            self.actor.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

         if self.share_features_extractor:
-            self.critic = self.make_critic(features_extractor=self.actor.features_extractor)
+            self.critic = self.make_critic(
+                features_extractor=self.actor.features_extractor)
             # Do not optimize the shared features extractor with the critic loss
             # otherwise, there are gradient computation issues
-            critic_parameters = [param for name, param in self.critic.named_parameters() if "features_extractor" not in name]
+            critic_parameters = [param for name, param in self.critic.named_parameters(
+            ) if "features_extractor" not in name]
         else:
             # Create a separate features extractor for the critic
             # this requires more memory and computation
@@ -310,7 +335,8 @@ class SACPolicy(BasePolicy):
         self.critic_target = self.make_critic(features_extractor=None)
         self.critic_target.load_state_dict(self.critic.state_dict())

-        self.critic.optimizer = self.optimizer_class(critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)
+        self.critic.optimizer = self.optimizer_class(
+            critic_parameters, lr=lr_schedule(1), **self.optimizer_kwargs)

         # Target networks should always be in eval mode
         self.critic_target.set_training_mode(False)
@@ -327,7 +353,8 @@ class SACPolicy(BasePolicy):
             use_expln=self.actor_kwargs["use_expln"],
             clip_mean=self.actor_kwargs["clip_mean"],
             n_critics=self.critic_kwargs["n_critics"],
-            lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
+            # dummy lr schedule, not needed for loading policy alone
+            lr_schedule=self._dummy_schedule,
             optimizer_class=self.optimizer_class,
             optimizer_kwargs=self.optimizer_kwargs,
             features_extractor_class=self.features_extractor_class,
@@ -345,11 +372,13 @@ class SACPolicy(BasePolicy):
         self.actor.reset_noise(batch_size=batch_size)

     def make_actor(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> Actor:
-        actor_kwargs = self._update_features_extractor(self.actor_kwargs, features_extractor)
-        return Actor(**actor_kwargs).to(self.device)
+        actor_kwargs = self._update_features_extractor(
+            self.actor_kwargs, features_extractor)
+        return Actor(**actor_kwargs, dist_kwargs=self.dist_kwargs).to(self.device)

     def make_critic(self, features_extractor: Optional[BaseFeaturesExtractor] = None) -> ContinuousCritic:
-        critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
+        critic_kwargs = self._update_features_extractor(
+            self.critic_kwargs, features_extractor)
         return ContinuousCritic(**critic_kwargs).to(self.device)

     def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
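With this plumbing, the distribution can be configured from the policy constructor: `dist_kwargs` flows SACPolicy -> make_actor -> Actor -> UniversalGaussianDistribution. A hedged construction sketch; the import path of this SACPolicy within the fork, the gym environment, and the keys accepted by UniversalGaussianDistribution are assumptions not shown in this diff.

# Hedged sketch: constructing the modified SACPolicy with dist_kwargs forwarded
# to the actor's action distribution. SACPolicy is the class patched above.
import gym

env = gym.make("Pendulum-v1")
policy = SACPolicy(
    observation_space=env.observation_space,
    action_space=env.action_space,
    lr_schedule=lambda progress: 3e-4,  # constant learning-rate schedule
    dist_kwargs={},                     # passed through to UniversalGaussianDistribution
)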