Ensure 'actions' is on correct device for entropy calculation

This commit is contained in:
Dominik Moritz Roth 2023-05-21 18:15:11 +02:00
parent 35df8f44da
commit c21682ad84

View File

@ -114,7 +114,6 @@ class PCA_Distribution(SB3_Distribution):
# Premature optimization is the root of all evil
self._build_conditioner()
# *Optimizes it anyways*
print('[i] PCA-Distribution initialized')
def proba_distribution_net(self, latent_dim: int):
mu_net = nn.Linear(latent_dim, self.action_dim)
@ -129,7 +128,7 @@ class PCA_Distribution(SB3_Distribution):
return self
def log_prob(self, actions: th.Tensor) -> th.Tensor:
return sum_independent_dims(self.distribution.log_prob(actions))
return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
def entropy(self) -> th.Tensor:
return sum_independent_dims(self.distribution.entropy())
@ -146,7 +145,7 @@ class PCA_Distribution(SB3_Distribution):
return self.sample(traj=trajectory)
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
rho_std *= f_sigma
eta = self._get_rigged(pi_mean, pi_std,
@ -187,8 +186,9 @@ class PCA_Distribution(SB3_Distribution):
return pi_mean, pi_std
traj = self._pad_and_cut_trajectory(trajectory)
# Numpy is fun
y_np = np.append(np.swapaxes(traj, -1, -2),
np.expand_dims(pi_mean, -1), -1)
np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
with th.no_grad():
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
@ -300,7 +300,7 @@ class StdNet(nn.Module):
return self._ensure_positive_func(self.param)
elif self.par_strength == Par_Strength.CONT_SCALAR:
cont = self.net(x)
diag_chol = th.ones(self.action_dim) * cont * self.std_init
diag_chol = th.ones(self.action_dim, device=cont.device) * cont * self.std_init
return self._ensure_positive_func(diag_chol)
elif self.par_strength == Par_Strength.CONT_HYBRID:
cont = self.net(x)