Implemented Policies with Contextual Covariance
This commit is contained in:
		
							parent
							
								
									41e4170b2f
								
							
						
					
					
						commit
						fae19509bc
					
				@ -1 +1,2 @@
 | 
				
			|||||||
#TODO: License or such
 | 
					#TODO: License or such
 | 
				
			||||||
 | 
					from .distributions import *
 | 
				
			||||||
 | 
				
			|||||||
@ -2,10 +2,10 @@ import torch as th
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from stable_baselines3.common.distributions import Distribution as SB3_Distribution
 | 
					from stable_baselines3.common.distributions import Distribution as SB3_Distribution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from metastable_baselines.distributions.distributions import UniversalGaussianDistribution
 | 
					from ..distributions import UniversalGaussianDistribution, AnyDistribution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_mean_and_chol(p, expand=False):
 | 
					def get_mean_and_chol(p: AnyDistribution, expand=False):
 | 
				
			||||||
    if isinstance(p, th.distributions.Normal):
 | 
					    if isinstance(p, th.distributions.Normal):
 | 
				
			||||||
        if expand:
 | 
					        if expand:
 | 
				
			||||||
            return p.mean, th.diag_embed(p.stddev)
 | 
					            return p.mean, th.diag_embed(p.stddev)
 | 
				
			||||||
@ -19,7 +19,7 @@ def get_mean_and_chol(p, expand=False):
 | 
				
			|||||||
        raise Exception('Dist-Type not implemented')
 | 
					        raise Exception('Dist-Type not implemented')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_mean_and_sqrt(p):
 | 
					def get_mean_and_sqrt(p: UniversalGaussianDistribution):
 | 
				
			||||||
    raise Exception('Not yet implemented...')
 | 
					    raise Exception('Not yet implemented...')
 | 
				
			||||||
    if isinstance(p, th.distributions.Normal):
 | 
					    if isinstance(p, th.distributions.Normal):
 | 
				
			||||||
        return p.mean, p.stddev
 | 
					        return p.mean, p.stddev
 | 
				
			||||||
@ -31,7 +31,7 @@ def get_mean_and_sqrt(p):
 | 
				
			|||||||
        raise Exception('Dist-Type not implemented')
 | 
					        raise Exception('Dist-Type not implemented')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_cov(p):
 | 
					def get_cov(p: AnyDistribution):
 | 
				
			||||||
    if isinstance(p, th.distributions.Normal):
 | 
					    if isinstance(p, th.distributions.Normal):
 | 
				
			||||||
        return th.diag_embed(p.variance)
 | 
					        return th.diag_embed(p.variance)
 | 
				
			||||||
    elif isinstance(p, th.distributions.MultivariateNormal):
 | 
					    elif isinstance(p, th.distributions.MultivariateNormal):
 | 
				
			||||||
@ -42,7 +42,7 @@ def get_cov(p):
 | 
				
			|||||||
        raise Exception('Dist-Type not implemented')
 | 
					        raise Exception('Dist-Type not implemented')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def has_diag_cov(p, numerical_check=True):
 | 
					def has_diag_cov(p: AnyDistribution, numerical_check=True):
 | 
				
			||||||
    if isinstance(p, SB3_Distribution):
 | 
					    if isinstance(p, SB3_Distribution):
 | 
				
			||||||
        return has_diag_cov(p.distribution, numerical_check=numerical_check)
 | 
					        return has_diag_cov(p.distribution, numerical_check=numerical_check)
 | 
				
			||||||
    if isinstance(p, th.distributions.Normal):
 | 
					    if isinstance(p, th.distributions.Normal):
 | 
				
			||||||
@ -54,18 +54,18 @@ def has_diag_cov(p, numerical_check=True):
 | 
				
			|||||||
    return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
 | 
					    return th.equal(cov - th.diag_embed(th.diagonal(cov, dim1=-2, dim2=-1), th.zeros_like(cov)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def is_contextual(p):
 | 
					def is_contextual(p: AnyDistribution):
 | 
				
			||||||
    # TODO: Implement for UniveralGaussianDist
 | 
					    # TODO: Implement for UniveralGaussianDist
 | 
				
			||||||
    return False
 | 
					    return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_diag_cov_vec(p, check_diag=True, numerical_check=True):
 | 
					def get_diag_cov_vec(p: AnyDistribution, check_diag=True, numerical_check=True):
 | 
				
			||||||
    if check_diag and not has_diag_cov(p):
 | 
					    if check_diag and not has_diag_cov(p):
 | 
				
			||||||
        raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
 | 
					        raise Exception('Cannot reduce cov-mat to diag-vec: Is not diagonal')
 | 
				
			||||||
    return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
 | 
					    return th.diagonal(get_cov(p), dim1=-2, dim2=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def new_dist_like(orig_p, mean, chol):
 | 
					def new_dist_like(orig_p: AnyDistribution, mean: th.Tensor, chol: th.Tensor):
 | 
				
			||||||
    if isinstance(orig_p, UniversalGaussianDistribution):
 | 
					    if isinstance(orig_p, UniversalGaussianDistribution):
 | 
				
			||||||
        return orig_p.new_list_like_me(mean, chol)
 | 
					        return orig_p.new_list_like_me(mean, chol)
 | 
				
			||||||
    elif isinstance(orig_p, th.distributions.Normal):
 | 
					    elif isinstance(orig_p, th.distributions.Normal):
 | 
				
			||||||
 | 
				
			|||||||
@ -1,20 +0,0 @@
 | 
				
			|||||||
import torch as th
 | 
					 | 
				
			||||||
from torch import nn
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class FakeModule(nn.Module):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    A torch.nn Module, that drops the input and returns a tensor given at initialization.
 | 
					 | 
				
			||||||
    Gradients can pass through this Module and affect the given tensor.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    # In order to reduce the code required to allow suppor for contextual covariance and parametric covariance, we just channel the parametric covariance through such a FakeModule
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def __init__(self, tensor):
 | 
					 | 
				
			||||||
        super().__init__()
 | 
					 | 
				
			||||||
        self.tensor = tensor
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def forward(self, x):
 | 
					 | 
				
			||||||
        return self.tensor
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def string(self):
 | 
					 | 
				
			||||||
        return '<FakeModule: '+str(self.tensor)+'>'
 | 
					 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user