Ensure 'actions' is on the correct device for entropy calculation
This commit is contained in:
parent
35df8f44da
commit
c21682ad84
@ -114,7 +114,6 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
# Premature optimization is the root of all evil
|
# Premature optimization is the root of all evil
|
||||||
self._build_conditioner()
|
self._build_conditioner()
|
||||||
# *Optimizes it anyways*
|
# *Optimizes it anyways*
|
||||||
print('[i] PCA-Distribution initialized')
|
|
||||||
|
|
||||||
def proba_distribution_net(self, latent_dim: int):
|
def proba_distribution_net(self, latent_dim: int):
|
||||||
mu_net = nn.Linear(latent_dim, self.action_dim)
|
mu_net = nn.Linear(latent_dim, self.action_dim)
|
||||||
@ -129,7 +128,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
def log_prob(self, actions: th.Tensor) -> th.Tensor:
|
||||||
return sum_independent_dims(self.distribution.log_prob(actions))
|
return sum_independent_dims(self.distribution.log_prob(actions.to(self.distribution.mean.device)))
|
||||||
|
|
||||||
def entropy(self) -> th.Tensor:
|
def entropy(self) -> th.Tensor:
|
||||||
return sum_independent_dims(self.distribution.entropy())
|
return sum_independent_dims(self.distribution.entropy())
|
||||||
@ -146,7 +145,7 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return self.sample(traj=trajectory)
|
return self.sample(traj=trajectory)
|
||||||
|
|
||||||
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
def sample(self, traj: th.Tensor, f_sigma: int = 1, epsilon=None) -> th.Tensor:
|
||||||
pi_mean, pi_std = self.distribution.mean, self.distribution.scale
|
pi_mean, pi_std = self.distribution.mean.cpu(), self.distribution.scale.cpu()
|
||||||
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
rho_mean, rho_std = self._conditioning_engine(traj, pi_mean, pi_std)
|
||||||
rho_std *= f_sigma
|
rho_std *= f_sigma
|
||||||
eta = self._get_rigged(pi_mean, pi_std,
|
eta = self._get_rigged(pi_mean, pi_std,
|
||||||
@ -187,8 +186,9 @@ class PCA_Distribution(SB3_Distribution):
|
|||||||
return pi_mean, pi_std
|
return pi_mean, pi_std
|
||||||
|
|
||||||
traj = self._pad_and_cut_trajectory(trajectory)
|
traj = self._pad_and_cut_trajectory(trajectory)
|
||||||
|
# Numpy is fun
|
||||||
y_np = np.append(np.swapaxes(traj, -1, -2),
|
y_np = np.append(np.swapaxes(traj, -1, -2),
|
||||||
np.expand_dims(pi_mean, -1), -1)
|
np.repeat(np.expand_dims(pi_mean, -1), traj.shape[0], 0), -1)
|
||||||
|
|
||||||
with th.no_grad():
|
with th.no_grad():
|
||||||
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
conditioners = th.Tensor(self._adapt_conditioner(pi_std))
|
||||||
@ -300,7 +300,7 @@ class StdNet(nn.Module):
|
|||||||
return self._ensure_positive_func(self.param)
|
return self._ensure_positive_func(self.param)
|
||||||
elif self.par_strength == Par_Strength.CONT_SCALAR:
|
elif self.par_strength == Par_Strength.CONT_SCALAR:
|
||||||
cont = self.net(x)
|
cont = self.net(x)
|
||||||
diag_chol = th.ones(self.action_dim) * cont * self.std_init
|
diag_chol = th.ones(self.action_dim, device=cont.device) * cont * self.std_init
|
||||||
return self._ensure_positive_func(diag_chol)
|
return self._ensure_positive_func(diag_chol)
|
||||||
elif self.par_strength == Par_Strength.CONT_HYBRID:
|
elif self.par_strength == Par_Strength.CONT_HYBRID:
|
||||||
cont = self.net(x)
|
cont = self.net(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user