Ensure 'actions' is on correct device for entropy calculation
This commit is contained in:
parent
35df8f44da
commit
c21682ad84
@ -114,7 +114,6 @@ class PCA_Distribution(SB3_Distribution):
|
||||
# Premature optimization is the root of all evil
|
||||
self._build_conditioner()
|
||||
# *Optimizes it anyways*
|
||||
print('[i] PCA-Distribution initialized')
|
||||
|
||||
def proba_distribution_net(self, latent_dim: int):
|
||||
mu_net = nn.Linear(latent_dim, self.action_dim)
|
||||
@ -129,7 +128,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return self
|
||||
|
||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.log_prob(actions))
|
||||
return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
|
||||
|
||||
def entropy(self) -> th.Tensor:
|
||||
return sum_independent_dims(self.distribution.entropy())
|
||||
@ -146,7 +145,7 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return self.sample(traj=trajectory)
|
||||
|
||||
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
||||
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
|
||||
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
|
||||
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
||||
rho_std *= f_sigma
|
||||
eta = self._get_rigged(pi_mean, pi_std,
|
||||
@ -187,8 +186,9 @@ class PCA_Distribution(SB3_Distribution):
|
||||
return pi_mean, pi_std
|
||||
|
||||
traj = self._pad_and_cut_trajectory(trajectory)
|
||||
# Numpy is fun
|
||||
y_np = np.append(np.swapaxes(traj, -1, -2),
|
||||
np.expand_dims(pi_mean, -1), -1)
|
||||
np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
||||
|
||||
with th.no_grad():
|
||||
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
||||
@ -300,7 +300,7 @@ class StdNet(nn.Module):
|
||||
return self._ensure_positive_func(self.param)
|
||||
elif self.par_strength == Par_Strength.CONT_SCALAR:
|
||||
cont = self.net(x)
|
||||
diag_chol = th.ones(self.action_dim) * cont * self.std_init
|
||||
diag_chol = th.ones(self.action_dim, device=cont.device) * cont * self.std_init
|
||||
return self._ensure_positive_func(diag_chol)
|
||||
elif self.par_strength == Par_Strength.CONT_HYBRID:
|
||||
cont = self.net(x)
|
||||
|
Loading…
Reference in New Issue
Block a user