From 0e4eedae5e02231d8d781ddd917090b4536a4247 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 10 Aug 2022 11:55:08 +0200 Subject: [PATCH] Fixed gradient throught spherical-chol --- metastable_baselines/distributions/distributions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/metastable_baselines/distributions/distributions.py b/metastable_baselines/distributions/distributions.py index b66cea8..51af0c7 100644 --- a/metastable_baselines/distributions/distributions.py +++ b/metastable_baselines/distributions/distributions.py @@ -514,10 +514,12 @@ class CholNet(nn.Module): n = sphe_chol.shape[-1] L = th.zeros_like(sphe_chol) for i in range(n): - t = 1 + #t = 1 + t = th.Tensor([1])[0] #s = '' for j in range(i+1): - maybe_cos = 1 + #maybe_cos = 1 + maybe_cos = th.Tensor([1])[0] #s_maybe_cos = '' if i != j and j < n-1 and i < n: if batch: