Implemented TRLRolloutBuffer
This commit is contained in:
parent
60c954c8c1
commit
b8488c531b
@ -1,5 +1,5 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Dict, Optional, Type, Union
|
from typing import Any, Dict, Optional, Type, Union, NamedTuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch as th
|
import torch as th
|
||||||
@ -14,9 +14,10 @@ from stable_baselines3.common.vec_env import VecEnv
|
|||||||
from stable_baselines3.common.buffers import RolloutBuffer
|
from stable_baselines3.common.buffers import RolloutBuffer
|
||||||
from stable_baselines3.common.callbacks import BaseCallback
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
from stable_baselines3.common.utils import obs_as_tensor
|
from stable_baselines3.common.utils import obs_as_tensor
|
||||||
|
from stable_baselines3.common.vec_env import VecNormalize
|
||||||
|
|
||||||
from ..projections.base_projection_layer import BaseProjectionLayer
|
from ..projections.base_projection_layer import BaseProjectionLayer
|
||||||
from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
# from ..projections.frob_projection_layer import FrobeniusProjectionLayer
|
||||||
|
|
||||||
|
|
||||||
class TRL_PG(OnPolicyAlgorithm):
|
class TRL_PG(OnPolicyAlgorithm):
|
||||||
@ -56,7 +57,8 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
Default: -1 (only sample at the beginning of the rollout)
|
Default: -1 (only sample at the beginning of the rollout)
|
||||||
:param target_kl: Limit the KL divergence between updates,
|
:param target_kl: Limit the KL divergence between updates,
|
||||||
because the clipping is not enough to prevent large update
|
because the clipping is not enough to prevent large update
|
||||||
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
# 213 (cf https://github.com/hill-a/stable-baselines/issues/213)
|
||||||
|
see issue
|
||||||
By default, there is no limit on the kl div.
|
By default, there is no limit on the kl div.
|
||||||
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
:param tensorboard_log: the log location for tensorboard (if None, no logging)
|
||||||
:param create_eval_env: Whether to create a second environment that will be
|
:param create_eval_env: Whether to create a second environment that will be
|
||||||
@ -103,7 +105,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
device: Union[th.device, str] = "auto",
|
device: Union[th.device, str] = "auto",
|
||||||
|
|
||||||
# Different from PPO:
|
# Different from PPO:
|
||||||
projection: BaseProjectionLayer = BaseProjectionLayer,
|
projection: BaseProjectionLayer = BaseProjectionLayer(),
|
||||||
|
|
||||||
_init_setup_model: bool = True,
|
_init_setup_model: bool = True,
|
||||||
):
|
):
|
||||||
@ -129,9 +131,9 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
_init_setup_model=False,
|
_init_setup_model=False,
|
||||||
supported_action_spaces=(
|
supported_action_spaces=(
|
||||||
spaces.Box,
|
spaces.Box,
|
||||||
spaces.Discrete,
|
# spaces.Discrete,
|
||||||
spaces.MultiDiscrete,
|
# spaces.MultiDiscrete,
|
||||||
spaces.MultiBinary,
|
# spaces.MultiBinary,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -185,6 +187,17 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
|
|
||||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||||
|
|
||||||
|
# Changed from PPO: We need a bigger RolloutBuffer
|
||||||
|
self.rollout_buffer = TRLRolloutBuffer(
|
||||||
|
self.n_steps,
|
||||||
|
self.observation_space,
|
||||||
|
self.action_space,
|
||||||
|
device=self.device,
|
||||||
|
gamma=self.gamma,
|
||||||
|
gae_lambda=self.gae_lambda,
|
||||||
|
n_envs=self.n_envs,
|
||||||
|
)
|
||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> None:
|
||||||
"""
|
"""
|
||||||
Update policy using the currently gathered rollout buffer.
|
Update policy using the currently gathered rollout buffer.
|
||||||
@ -237,24 +250,36 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
# log_prob == new_pogpacs (i think)
|
# log_prob == new_pogpacs (i think)
|
||||||
|
|
||||||
# src of evaluate_actions:
|
# src of evaluate_actions:
|
||||||
# features = self.extract_features(obs)
|
# pol = self.policy
|
||||||
# latent_pi, latent_vf = self.mlp_extractor(features)
|
# features = pol.extract_features(rollout_data.observations)
|
||||||
# distribution = self._get_action_dist_from_latent(latent_pi)
|
# latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||||
|
# distribution = pol._get_action_dist_from_latent(latent_pi)
|
||||||
# log_prob = distribution.log_prob(actions)
|
# log_prob = distribution.log_prob(actions)
|
||||||
# values = self.value_net(latent_vf)
|
# values = pol.value_net(latent_vf)
|
||||||
# return values, log_prob, distribution.entropy()
|
# return values, log_prob, distribution.entropy()
|
||||||
|
# entropy = distribution.entropy()
|
||||||
|
|
||||||
# here we go:
|
# here we go:
|
||||||
pol = self.policy
|
pol = self.policy
|
||||||
features = pol.extract_features(rollout_data.observations)
|
features = pol.extract_features(rollout_data.observations)
|
||||||
latent_pi, latent_vf = pol.mlp_extractor(features)
|
latent_pi, latent_vf = pol.mlp_extractor(features)
|
||||||
p = pol._get_action_dist_from_latent(latent_pi)
|
p = pol._get_action_dist_from_latent(latent_pi)
|
||||||
b_q = rollout_data.mean, rollout_data.std
|
p_dist = p.distribution
|
||||||
proj_p = self.projection(pol, p, b_q, self._global_step)
|
# q_means = rollout_data.means
|
||||||
log_prob = proj_p.log_prob(actions)
|
# if len(rollout_data.stds.shape) == 1: # only diag
|
||||||
# or log_prob = pol.log_probability(proj_p, actions)
|
# q_stds = th.diag(rollout_data.stds)
|
||||||
values = self.value_net(latent_vf)
|
# else:
|
||||||
entropy = proj_p.entropy() # or not...
|
# q_stds = rollout_data.stds
|
||||||
|
# q_dist = th.distributions.MultivariateNormal(
|
||||||
|
# q_means, q_stds)
|
||||||
|
q_dist = th.distributions.Normal(
|
||||||
|
rollout_data.means, rollout_data.stds)
|
||||||
|
proj_p = self.projection(p_dist, q_dist, self._global_steps)
|
||||||
|
log_prob = proj_p.log_prob(actions).sum(dim=1)
|
||||||
|
values = self.policy.value_net(latent_vf)
|
||||||
|
entropy = proj_p.entropy()
|
||||||
|
|
||||||
|
# log_prob = p.log_prob(actions)
|
||||||
|
|
||||||
values = values.flatten()
|
values = values.flatten()
|
||||||
# Normalize advantage
|
# Normalize advantage
|
||||||
@ -304,8 +329,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
|
|
||||||
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
|
# Difference to PPO: Added trust_region_loss; policy_loss includes entropy_loss + trust_region_loss
|
||||||
trust_region_loss = self.projection.get_trust_region_loss(
|
trust_region_loss = self.projection.get_trust_region_loss(
|
||||||
pol, p, proj_p)
|
p, proj_p)
|
||||||
# NOTE to future-self: policy has a different interface then in orig TRL-impl.
|
|
||||||
|
|
||||||
trust_region_losses.append(trust_region_loss.item())
|
trust_region_losses.append(trust_region_loss.item())
|
||||||
|
|
||||||
@ -434,10 +458,7 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
# Convert to pytorch tensor or to TensorDict
|
# Convert to pytorch tensor or to TensorDict
|
||||||
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
obs_tensor = obs_as_tensor(self._last_obs, self.device)
|
||||||
actions, values, log_probs = self.policy(obs_tensor)
|
actions, values, log_probs = self.policy(obs_tensor)
|
||||||
dist = self.policy.get_distribution(obs_tensor)
|
dist = self.policy.get_distribution(obs_tensor).distribution
|
||||||
# TODO: Enforce this requirement somwhere else...
|
|
||||||
assert isinstance(
|
|
||||||
dist, th.distributions.Normal), 'TRL is only implemented for Policys in a continuous action-space that is gauss-parametarized!'
|
|
||||||
mean, std = dist.mean, dist.stddev
|
mean, std = dist.mean, dist.stddev
|
||||||
actions = actions.cpu().numpy()
|
actions = actions.cpu().numpy()
|
||||||
|
|
||||||
@ -495,3 +516,97 @@ class TRL_PG(OnPolicyAlgorithm):
|
|||||||
callback.on_rollout_end()
|
callback.on_rollout_end()
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class TRLRolloutBufferSamples(NamedTuple):
|
||||||
|
observations: th.Tensor
|
||||||
|
actions: th.Tensor
|
||||||
|
old_values: th.Tensor
|
||||||
|
old_log_prob: th.Tensor
|
||||||
|
advantages: th.Tensor
|
||||||
|
returns: th.Tensor
|
||||||
|
means: th.Tensor
|
||||||
|
stds: th.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class TRLRolloutBuffer(RolloutBuffer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
buffer_size: int,
|
||||||
|
observation_space: spaces.Space,
|
||||||
|
action_space: spaces.Space,
|
||||||
|
device: Union[th.device, str] = "cpu",
|
||||||
|
gae_lambda: float = 1,
|
||||||
|
gamma: float = 0.99,
|
||||||
|
n_envs: int = 1,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(buffer_size, observation_space, action_space,
|
||||||
|
device, n_envs=n_envs, gae_lambda=gae_lambda, gamma=gamma)
|
||||||
|
self.means, self.stds = None, None
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
self.means = np.zeros(
|
||||||
|
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||||
|
self.stds = np.zeros(
|
||||||
|
# (self.buffer_size, self.n_envs) + self.action_space.shape + self.action_space.shape, dtype=np.float32)
|
||||||
|
(self.buffer_size, self.n_envs) + self.action_space.shape, dtype=np.float32)
|
||||||
|
super().reset()
|
||||||
|
|
||||||
|
def add(
|
||||||
|
self,
|
||||||
|
obs: np.ndarray,
|
||||||
|
action: np.ndarray,
|
||||||
|
reward: np.ndarray,
|
||||||
|
episode_start: np.ndarray,
|
||||||
|
value: th.Tensor,
|
||||||
|
log_prob: th.Tensor,
|
||||||
|
mean: th.Tensor,
|
||||||
|
std: th.Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
:param obs: Observation
|
||||||
|
:param action: Action
|
||||||
|
:param reward:
|
||||||
|
:param episode_start: Start of episode signal.
|
||||||
|
:param value: estimated value of the current state
|
||||||
|
following the current policy.
|
||||||
|
:param log_prob: log probability of the action
|
||||||
|
following the current policy.
|
||||||
|
:param mean: Foo
|
||||||
|
:param std: Bar
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(log_prob.shape) == 0:
|
||||||
|
# Reshape 0-d tensor to avoid error
|
||||||
|
log_prob = log_prob.reshape(-1, 1)
|
||||||
|
|
||||||
|
# Reshape needed when using multiple envs with discrete observations
|
||||||
|
# as numpy cannot broadcast (n_discrete,) to (n_discrete, 1)
|
||||||
|
if isinstance(self.observation_space, spaces.Discrete):
|
||||||
|
obs = obs.reshape((self.n_envs,) + self.obs_shape)
|
||||||
|
|
||||||
|
self.observations[self.pos] = np.array(obs).copy()
|
||||||
|
self.actions[self.pos] = np.array(action).copy()
|
||||||
|
self.rewards[self.pos] = np.array(reward).copy()
|
||||||
|
self.episode_starts[self.pos] = np.array(episode_start).copy()
|
||||||
|
self.values[self.pos] = value.clone().cpu().numpy().flatten()
|
||||||
|
self.log_probs[self.pos] = log_prob.clone().cpu().numpy()
|
||||||
|
self.means[self.pos] = mean.clone().cpu().numpy()
|
||||||
|
self.stds[self.pos] = std.clone().cpu().numpy()
|
||||||
|
self.pos += 1
|
||||||
|
if self.pos == self.buffer_size:
|
||||||
|
self.full = True
|
||||||
|
|
||||||
|
def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> TRLRolloutBufferSamples:
|
||||||
|
data = (
|
||||||
|
self.observations[batch_inds],
|
||||||
|
self.actions[batch_inds],
|
||||||
|
self.values[batch_inds].flatten(),
|
||||||
|
self.log_probs[batch_inds].flatten(),
|
||||||
|
self.advantages[batch_inds].flatten(),
|
||||||
|
self.returns[batch_inds].flatten(),
|
||||||
|
self.means[batch_inds].reshape((len(batch_inds), -1)),
|
||||||
|
self.stds[batch_inds].reshape((len(batch_inds), -1)),
|
||||||
|
)
|
||||||
|
return TRLRolloutBufferSamples(*tuple(map(self.to_torch, data)))
|
||||||
|
Loading…
Reference in New Issue
Block a user