Finalized factoring out projections
This commit is contained in:
parent
0aeea4e2e5
commit
4532135812
@ -1,31 +0,0 @@
|
|||||||
import torch as th
|
|
||||||
from torch.distributions.multivariate_normal import _batch_mahalanobis
|
|
||||||
|
|
||||||
|
|
||||||
def mahalanobis_alt(u, v, std):
    """
    Squared Mahalanobis-style distance between ``u`` and ``v``.

    Solves the lower-triangular system ``std @ x = u - v`` and returns the
    sum of the squared entries of ``x`` over the last two dimensions.

    Adapted from Fabian's code (public version).

    Args:
        u: tensor; presumably shaped (..., n, k) -- TODO confirm with callers.
        v: tensor that can be subtracted from ``u``.
        std: triangular factor; only its lower triangle is used by the solve.
            Presumably a Cholesky factor of a covariance -- TODO confirm.

    Returns:
        Tensor holding the squared solution summed over dims -2 and -1.
    """
    delta = u - v
    # th.triangular_solve is deprecated in favour of
    # th.linalg.solve_triangular, which takes the triangular matrix FIRST.
    # Keep the old call as a fallback for PyTorch builds that predate it.
    linalg = getattr(th, "linalg", None)
    if hasattr(linalg, "solve_triangular"):
        whitened = linalg.solve_triangular(std, delta, upper=False)
    else:
        whitened = th.triangular_solve(delta, std, upper=False)[0]
    return whitened.pow(2).sum([-2, -1])
|
|
||||||
|
|
||||||
|
|
||||||
def mahalanobis(u, v, chol):
    """
    Mahalanobis term for the difference ``u - v`` under factor ``chol``.

    Thin wrapper that forwards ``chol`` and the difference ``u - v`` to
    PyTorch's private ``_batch_mahalanobis`` helper and returns its result
    unchanged.

    Args:
        u: tensor; presumably shaped (..., n) -- TODO confirm with callers.
        v: tensor that can be subtracted from ``u``.
        chol: presumably a lower-triangular Cholesky factor shaped
            (..., n, n) -- TODO confirm.
    """
    return _batch_mahalanobis(chol, u - v)
|
|
||||||
|
|
||||||
|
|
||||||
def frob_sq(diff, is_spd=False):
    """
    Squared Frobenius norm of ``diff``.

    Args:
        diff: tensor whose first dimension is kept; in the default path every
            other dimension is reduced.
        is_spd: when true, delegate to ``_frob_sq_spd``, which is (probably)
            more performant for such input.

    Returns:
        Tensor of summed squared entries (default path), or whatever
        ``_frob_sq_spd`` returns when ``is_spd`` is set.
    """
    if is_spd:
        return _frob_sq_spd(diff)
    # Sum the squares directly: this avoids the deprecated th.norm and the
    # sqrt-then-square round trip, which only loses precision.
    reduce_dims = tuple(range(1, diff.dim()))
    return diff.pow(2).sum(dim=reduce_dims)
|
|
||||||
|
|
||||||
|
|
||||||
def _frob_sq_spd(diff):
    """
    Return ``trace(diff @ diff)`` over the last two dimensions.

    For symmetric ``diff`` this equals the squared Frobenius norm. The trace
    is evaluated through the identity ``trace(A @ A) = sum_ij A_ij * A_ji``,
    so no matrix product has to be materialised.

    Args:
        diff: tensor whose last two dimensions form square matrices.

    Returns:
        Tensor with the last two dimensions reduced away.
    """
    # Elementwise product with the transpose, then reduce: O(n^2) per matrix
    # instead of the O(n^3) matmul whose off-diagonal entries were discarded.
    return (diff * diff.transpose(-2, -1)).sum(dim=(-2, -1))
|
|
||||||
|
|
||||||
|
|
||||||
def _batch_trace(x):
    """
    Trace of ``x`` taken over its last two dimensions.

    Args:
        x: tensor with at least two dimensions.

    Returns:
        Tensor of the summed diagonals, with the last two dimensions of
        ``x`` reduced away.
    """
    diag = x.diagonal(dim1=-2, dim2=-1)
    return diag.sum(dim=-1)
|
|
@ -336,8 +336,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
|
|||||||
|
|
||||||
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
|
policy_loss = surrogate_loss + self.ent_coef * entropy_loss + \
|
||||||
trust_region_loss + self.action_coef * action_loss
|
trust_region_loss + self.action_coef * action_loss
|
||||||
import pdb
|
|
||||||
pdb.set_trace()
|
|
||||||
pg_losses.append(policy_loss.item())
|
pg_losses.append(policy_loss.item())
|
||||||
|
|
||||||
loss = policy_loss + self.vf_coef * value_loss
|
loss = policy_loss + self.vf_coef * value_loss
|
||||||
|
2
test.py
2
test.py
@ -28,7 +28,7 @@ def main(env_name='ColumbusCandyland_Aux10-v0', timesteps=1_000_000, showRes=Tru
|
|||||||
MlpPolicyPPO,
|
MlpPolicyPPO,
|
||||||
env,
|
env,
|
||||||
# KLProjectionLayer(trust_region_coeff=0.01),
|
# KLProjectionLayer(trust_region_coeff=0.01),
|
||||||
projection=KLProjectionLayer(trust_region_coeff=0.01),
|
projection=WassersteinProjectionLayer(trust_region_coeff=0.01),
|
||||||
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
policy_kwargs={'dist_kwargs': {'neural_strength': Strength.NONE, 'cov_strength': Strength.DIAG, 'parameterization_type':
|
||||||
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
ParametrizationType.NONE, 'enforce_positive_type': EnforcePositiveType.ABS, 'prob_squashing_type': ProbSquashingType.NONE}},
|
||||||
verbose=0,
|
verbose=0,
|
||||||
|
Loading…
Reference in New Issue
Block a user