Updating to new sb3 version

parent 1c1a909d27
commit 865efe4221
@@ -1,12 +1,5 @@
-from sbBrix.ppo import PPO
-from sbBrix.sac import SAC
+from metastable_baselines2.ppo import PPO
+from metastable_baselines2.sac import SAC
 
-try:
-    import priorConditionedAnnealing as pca
-except ModuleNotFoundError:
-    class pca():
-        def PCA_Distribution(*args, **kwargs):
-            raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
-
 __all__ = [
     "PPO",
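For reference, a minimal usage sketch of the renamed package follows. It is not part of the commit; it assumes the exported PPO class keeps the stable-baselines3 constructor API, and the environment id is only an example.

# Usage sketch (assumption: PPO keeps the stable-baselines3 constructor signature).
import gymnasium as gym

from metastable_baselines2 import PPO

env = gym.make("Pendulum-v1")  # example environment, not part of the commit
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000)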
@@ -1,6 +1,5 @@
 from stable_baselines3.common.distributions import *
-from metastable_baselines2.pca import PCA_Distribution
-
+from metastable_baselines2.common.pca import PCA_Distribution
 
 def _patched_make_proba_distribution(
     action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
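The body of _patched_make_proba_distribution is outside this hunk. Purely as an assumption based on the signature and the PCA_Distribution import, the dispatch it suggests might look like the sketch below; the PCA_Distribution construction in particular is hypothetical.

# Hypothetical sketch only -- not the actual body of the patched function.
from typing import Any, Dict, Optional

from gymnasium import spaces
from stable_baselines3.common.distributions import Distribution, make_proba_distribution

from metastable_baselines2.common.pca import PCA_Distribution


def sketch_make_proba_distribution(
    action_space: spaces.Space, use_sde: bool = False, use_pca: bool = False, dist_kwargs: Optional[Dict[str, Any]] = None
) -> Distribution:
    if use_pca:
        # Assumed constructor: the real PCA_Distribution signature may differ.
        return PCA_Distribution(int(action_space.shape[0]), **(dist_kwargs or {}))
    # Fall back to the stock stable-baselines3 factory otherwise.
    return make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)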
@@ -1,5 +1,29 @@
-from stable_baselines3.common.off_policy_algorithm import *
-
+import io
+import pathlib
+import sys
+import time
+import warnings
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+
+import numpy as np
+import torch as th
+from gymnasium import spaces
+
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
+from stable_baselines3.common.policies import BasePolicy
+from stable_baselines3.common.save_util import load_from_pkl, save_to_pkl
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit
+from stable_baselines3.common.utils import safe_mean, should_collect_more_steps
+from stable_baselines3.common.vec_env import VecEnv
+from stable_baselines3.her.her_replay_buffer import HerReplayBuffer
+
+from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
+
+SelfOffPolicyAlgorithm = TypeVar("SelfOffPolicyAlgorithm", bound="BetterOffPolicyAlgorithm")
 
 class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
     """
@@ -52,6 +76,8 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
     :param supported_action_spaces: The action spaces supported by the algorithm.
     """
 
+    actor: th.nn.Module
+
     def __init__(
         self,
         policy: Union[str, Type[BasePolicy]],
@@ -68,7 +94,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
         replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = {},
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         verbose: int = 0,
@@ -217,6 +243,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
             if truncate_last_traj:
                 self.replay_buffer.truncate_last_trajectory()
 
+        # Update saved replay buffer device to match current setting, see GH#1561
+        self.replay_buffer.device = self.device
+
     def _setup_learn(
         self,
         total_timesteps: int,
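The two added lines keep a reloaded replay buffer on the same device as the model. A hedged sketch of the effect, with placeholder checkpoint names and the standard stable-baselines3 save/load helpers:

# Sketch only: checkpoint names are placeholders, not from the diff.
from metastable_baselines2 import SAC

model = SAC.load("sac_pendulum", device="cpu")       # model now lives on the CPU
model.load_replay_buffer("sac_pendulum_buffer.pkl")  # buffer may have been saved from a GPU run
assert model.replay_buffer.device == model.device    # kept in sync by the added lines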
@@ -248,10 +277,21 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
                 "You should use `reset_num_timesteps=False` or `optimize_memory_usage=False`"
                 "to avoid that issue."
             )
+            assert replay_buffer is not None  # for mypy
             # Go to the previous index
             pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size
             replay_buffer.dones[pos] = True
 
+        assert self.env is not None, "You must set the environment before calling _setup_learn()"
+
+        # Vectorize action noise if needed
+        if (
+            self.action_noise is not None
+            and self.env.num_envs > 1
+            and not isinstance(self.action_noise, VectorizedActionNoise)
+        ):
+            self.action_noise = VectorizedActionNoise(self.action_noise, self.env.num_envs)
+
         return super()._setup_learn(
             total_timesteps,
             callback,
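The added guard wraps a single ActionNoise in a VectorizedActionNoise when several environments are used. A standalone illustration of that wrapper, with example values that are not from the diff:

# Standalone illustration of VectorizedActionNoise (example values only).
import numpy as np
from stable_baselines3.common.noise import NormalActionNoise, VectorizedActionNoise

n_envs, action_dim = 4, 2
base_noise = NormalActionNoise(mean=np.zeros(action_dim), sigma=0.1 * np.ones(action_dim))

# One independent copy of the base noise per environment, sampled as a batch.
vec_noise = VectorizedActionNoise(base_noise, n_envs=n_envs)
print(vec_noise().shape)  # (4, 2)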
@@ -279,6 +319,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
 
         callback.on_training_start(locals(), globals())
 
+        assert self.env is not None, "You must set the environment before calling learn()"
+        assert isinstance(self.train_freq, TrainFreq)  # check done in _setup_learn()
+
         while self.num_timesteps < total_timesteps:
             rollout = self.collect_rollouts(
                 self.env,
@@ -290,7 +333,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
                 log_interval=log_interval,
             )
 
-            if rollout.continue_training is False:
+            if not rollout.continue_training:
                 break
 
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
@@ -341,6 +384,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
             # Note: when using continuous actions,
             # we assume that the policy uses tanh to scale the action
             # We use non-deterministic action in the case of SAC, for TD3, it does not matter
+            assert self._last_obs is not None, "self._last_obs was not set"
             unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
 
         # Rescale the action from [low, high] to [-1, 1]
@@ -364,6 +408,9 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
         """
         Write log.
         """
+        assert self.ep_info_buffer is not None
+        assert self.ep_success_buffer is not None
+
         time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
         fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
         self.logger.record("time/episodes", self._episode_num, exclude="tensorboard")
@@ -520,7 +567,7 @@ class BetterOffPolicyAlgorithm(OffPolicyAlgorithm):
             # Give access to local variables
             callback.update_locals(locals())
             # Only stop training if return value is False, not when it is None.
-            if callback.on_step() is False:
+            if not callback.on_step():
                 return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training=False)
 
             # Retrieve reward and episode length if using Monitor wrapper
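With the new "if not callback.on_step():" form, a False return from the callback stops collection. A minimal custom callback that uses this, shown purely for illustration and not part of the commit:

# Illustrative callback: returning False from _on_step() ends collection early.
from stable_baselines3.common.callbacks import BaseCallback


class StopAfterNSteps(BaseCallback):
    """Hypothetical helper, not part of this commit."""

    def __init__(self, max_steps: int, verbose: int = 0):
        super().__init__(verbose)
        self.max_steps = max_steps

    def _on_step(self) -> bool:
        # False propagates through callback.on_step() and stops collect_rollouts().
        return self.num_timesteps < self.max_steps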
@@ -1,5 +1,22 @@
-from stable_baselines3.common.on_policy_algorithm import *
-
+import sys
+import time
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
+
+import numpy as np
+import torch as th
+from gymnasium import spaces
+
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
+from stable_baselines3.common.utils import obs_as_tensor, safe_mean
+from stable_baselines3.common.vec_env import VecEnv
+
+from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
+
+SelfOnPolicyAlgorithm = TypeVar("SelfOnPolicyAlgorithm", bound="BetterOnPolicyAlgorithm")
 
 class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
     """
@@ -21,6 +38,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         instead of action noise exploration (default: False)
     :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
         Default: -1 (only sample at the beginning of the rollout)
+    :param use_pca: Whether to use PCA.
+    :param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
+    :param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
     :param tensorboard_log: the log location for tensorboard (if None, no logging)
@@ -50,6 +70,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         use_sde: bool,
         sde_sample_freq: int,
         use_pca: bool,
+        rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
+        rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
         monitor_wrapper: bool = True,
@@ -62,6 +84,9 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
     ):
         assert not (use_sde and use_pca)
         self.use_pca = use_pca
 
+        assert not rollout_buffer_class and not rollout_buffer_kwargs
+
         super().__init__(
             policy=policy,
             env=env,
@@ -77,6 +102,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             device=device,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,
+            #rollout_buffer_class = rollout_buffer_class,
+            #rollout_buffer_kwargs = rollout_buffer_kwargs,
             # support_multi_env=True,
             seed=seed,
             stats_window_size=stats_window_size,
@@ -86,13 +113,20 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             _init_setup_model=_init_setup_model
         )
 
+        self.rollout_buffer_class = None
+        self.rollout_buffer_kwargs = {}
+
     def _setup_model(self) -> None:
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)
 
-        buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
+        if self.rollout_buffer_class is None:
+            if isinstance(self.observation_space, spaces.Dict):
+                self.rollout_buffer_class = DictRolloutBuffer
+            else:
+                self.rollout_buffer_class = RolloutBuffer
 
-        self.rollout_buffer = buffer_cls(
+        self.rollout_buffer = self.rollout_buffer_class(
             self.n_steps,
             self.observation_space,
             self.action_space,
@@ -100,6 +134,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             gamma=self.gamma,
             gae_lambda=self.gae_lambda,
             n_envs=self.n_envs,
+            **self.rollout_buffer_kwargs,
         )
         self.policy = self.policy_class(  # pytype:disable=not-instantiable
             self.observation_space,
@@ -158,7 +193,14 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
             clipped_actions = actions
             # Clip the actions to avoid out of bound error
             if isinstance(self.action_space, spaces.Box):
-                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
+                if self.policy.squash_output:
+                    # Unscale the actions to match env bounds
+                    # if they were previously squashed (scaled in [-1, 1])
+                    clipped_actions = self.policy.unscale_action(clipped_actions)
+                else:
+                    # Otherwise, clip the actions to avoid out of bound error
+                    # as we are sampling from an unbounded Gaussian distribution
+                    clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
 
             new_obs, rewards, dones, infos = env.step(clipped_actions)
 
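For a squashed policy the network output lies in [-1, 1] and is mapped back to the Box bounds instead of being clipped. A small numeric illustration of that affine mapping with example bounds (this mirrors what BasePolicy.unscale_action computes in stable-baselines3):

# Numeric illustration of unscaling vs. clipping (example bounds, not from the diff).
import numpy as np
from gymnasium import spaces

box = spaces.Box(low=np.array([-2.0]), high=np.array([4.0]), dtype=np.float64)
squashed = np.array([0.5])  # network output in [-1, 1]

# Affine map from [-1, 1] to [low, high].
unscaled = box.low + 0.5 * (squashed + 1.0) * (box.high - box.low)
print(unscaled)  # [2.5]

# Plain clipping, used when the Gaussian is not squashed.
print(np.clip(np.array([5.0]), box.low, box.high))  # [4.]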
@@ -166,7 +208,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
             # Give access to local variables
             callback.update_locals(locals())
-            if callback.on_step() is False:
+            if not callback.on_step():
                 return False
 
             self._update_info_buffer(infos)
@@ -199,6 +241,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
         rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)
 
+        callback.update_locals(locals())
+
         callback.on_rollout_end()
 
         return True
@@ -234,7 +278,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
         while self.num_timesteps < total_timesteps:
            continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
 
-            if continue_training is False:
+            if not continue_training:
                 break
 
             iteration += 1
@@ -242,6 +286,7 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
 
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
+                assert self.ep_info_buffer is not None
                 time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
                 fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
                 self.logger.record("time/iterations", iteration, exclude="tensorboard")
metastable_baselines2/common/pca.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+try:
+    from priorConditionedAnnealing import PCA_Distribution
+except ModuleNotFoundError:
+    def PCA_Distribution(*args, **kwargs):
+        raise Exception('PCA is not installed; cannot initialize PCA_Distribution.')
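The new module gives callers a stable import path whether or not the optional dependency is installed. A short sketch of the failure mode when priorConditionedAnnealing is missing, relying only on the five lines above:

# When priorConditionedAnnealing is not installed, the stub fails loudly at call time.
from metastable_baselines2.common.pca import PCA_Distribution

try:
    PCA_Distribution()  # constructor arguments omitted in this sketch
except Exception as err:
    print(err)  # PCA is not installed; cannot initialize PCA_Distribution.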
@@ -34,11 +34,12 @@ from stable_baselines3.common.torch_layers import (
 
 from stable_baselines3.common.policies import ContinuousCritic
 
-from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
+from stable_baselines3.common.type_aliases import Schedule
 from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor
 
 from .distributions import make_proba_distribution
-from metastable_baselines2.pca import PCA_Distribution
+from metastable_baselines2.common.pca import PCA_Distribution
+
 
 SelfBaseModel = TypeVar("SelfBaseModel", bound="BaseModel")
 
@@ -896,7 +897,7 @@ class Actor(BasePolicy):
         else:
             self.action_dist.base_noise.reset()
 
-    def get_action_dist_params(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
+    def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]:
         """
         Get the parameters for the action distribution.
 
@@ -917,17 +918,17 @@ class Actor(BasePolicy):
         log_std = th.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)
         return mean_actions, log_std, {}
 
-    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # Note: the action is squashed
         return self.action_dist.actions_from_params(mean_actions, log_std, deterministic=deterministic, **kwargs)
 
-    def action_log_prob(self, obs: PyTorchObs) -> Tuple[th.Tensor, th.Tensor]:
+    def action_log_prob(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
         mean_actions, log_std, kwargs = self.get_action_dist_params(obs)
         # return action and associated log prob
         return self.action_dist.log_prob_from_params(mean_actions, log_std, **kwargs)
 
-    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self(observation, deterministic)
 
 
@@ -1106,10 +1107,10 @@ class SACPolicy(BasePolicy):
         critic_kwargs = self._update_features_extractor(self.critic_kwargs, features_extractor)
         return ContinuousCritic(**critic_kwargs).to(self.device)
 
-    def forward(self, obs: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def forward(self, obs: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self._predict(obs, deterministic=deterministic)
 
-    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
+    def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
         return self.actor(observation, deterministic)
 
     def set_training_mode(self, mode: bool) -> None:
@@ -113,7 +113,7 @@ class SAC(BetterOffPolicyAlgorithm):
         use_pca: bool = False,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
+        policy_kwargs: Optional[Dict[str, Any]] = {},
         verbose: int = 0,
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
@@ -156,7 +156,7 @@ class SAC(BetterOffPolicyAlgorithm):
         # Inverse of the reward scale
         self.ent_coef = ent_coef
         self.target_update_interval = target_update_interval
-        self.ent_coef_optimizer: Optional[th.optim.Adam] = None
+        self.ent_coef_optimizer = None
 
         if _init_setup_model:
             self._setup_model()