Why does KL double free?
This commit is contained in:
parent
75d73049b4
commit
12e422aec7
@ -307,7 +307,10 @@ class ActorCriticPolicy(BasePolicy):
|
|||||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_pi)
|
||||||
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
elif isinstance(self.action_dist, UniversalGaussianDistribution):
|
||||||
if self.sqrt_induced_gaussian:
|
if self.sqrt_induced_gaussian:
|
||||||
cov_sqrt = self.chol_net(latent_pi)
|
chol_sqrt_cov = self.chol_net(latent_pi)
|
||||||
|
if len(chol_sqrt_cov.shape) == 2:
|
||||||
|
chol_sqrt_cov = th.diag_embed(chol_sqrt_cov)
|
||||||
|
cov_sqrt = th.bmm(chol_sqrt_cov.mT, chol_sqrt_cov)
|
||||||
dist = self.action_dist.proba_distribution_from_sqrt(
|
dist = self.action_dist.proba_distribution_from_sqrt(
|
||||||
mean_actions, cov_sqrt, latent_pi)
|
mean_actions, cov_sqrt, latent_pi)
|
||||||
mean, chol = get_mean_and_chol(dist, expand=False)
|
mean, chol = get_mean_and_chol(dist, expand=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user