Extend policy impl

parent ca1ee980ef
commit e927afcc30
@@ -3,31 +3,63 @@ from tensordict.nn import TensorDictModule
 from torchrl.modules import MLP
 from tensordict.nn.distributions import NormalParamExtractor
 from fancy_rl.utils import is_discrete_space, get_space_shape
+from tensordict import TensorDict
 
 class Actor(TensorDictModule):
-    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device):
+    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
+        self.discrete = is_discrete_space(act_space)
         act_space_shape = get_space_shape(act_space)
-        if is_discrete_space(act_space):
-            out_features = act_space_shape[-1]
-        else:
-            out_features = act_space_shape[-1] * 2
 
-        actor_module = nn.Sequential(
-            MLP(
-                in_features=get_space_shape(obs_space)[-1],
-                out_features=out_features,
-                num_cells=hidden_sizes,
-                activation_class=getattr(nn, activation_fn),
-                device=device
-            ),
-            NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(),
+        if self.discrete and full_covariance:
+            raise ValueError("Full covariance is not applicable for discrete action spaces.")
+
+        self.full_covariance = full_covariance
+
+        if self.discrete:
+            out_features = act_space_shape[-1]
+            out_keys = ["action_logits"]
+        else:
+            if full_covariance:
+                out_features = act_space_shape[-1] + (act_space_shape[-1] * (act_space_shape[-1] + 1)) // 2
+                out_keys = ["loc", "scale_tril"]
+            else:
+                out_features = act_space_shape[-1] * 2
+                out_keys = ["loc", "scale"]
+
+        actor_module = MLP(
+            in_features=get_space_shape(obs_space)[-1],
+            out_features=out_features,
+            num_cells=hidden_sizes,
+            activation_class=getattr(nn, activation_fn),
+            device=device
         ).to(device)
+
+        if not self.discrete:
+            if full_covariance:
+                param_extractor = FullCovarianceNormalParamExtractor(act_space_shape[-1])
+            else:
+                param_extractor = NormalParamExtractor()
+            actor_module = nn.Sequential(actor_module, param_extractor)
+
         super().__init__(
             module=actor_module,
             in_keys=["observation"],
-            out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"],
+            out_keys=out_keys
         )
 
+class FullCovarianceNormalParamExtractor(nn.Module):
+    def __init__(self, action_dim):
+        super().__init__()
+        self.action_dim = action_dim
+
+    def forward(self, x):
+        loc = x[:, :self.action_dim]
+        scale_tril = torch.zeros(x.shape[0], self.action_dim, self.action_dim, device=x.device)
+        tril_indices = torch.tril_indices(row=self.action_dim, col=self.action_dim, offset=0)
+        scale_tril[:, tril_indices[0], tril_indices[1]] = x[:, self.action_dim:]
+        scale_tril.diagonal(dim1=-2, dim2=-1).exp_()
+        return TensorDict({"loc": loc, "scale_tril": scale_tril}, batch_size=x.shape[0])
+
 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
         critic_module = MLP(
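A quick sanity check on the head sizing in this diff: for an action dimension d, the full-covariance head emits d mean entries plus d(d+1)/2 lower-triangular entries, i.e. out_features = d + d(d+1)/2. The standalone sketch below (not part of the commit; batch size and d = 3 are made up) repeats the reconstruction done in FullCovarianceNormalParamExtractor.forward and confirms the result parameterizes a valid torch.distributions.MultivariateNormal:

import torch

d = 3
full_cov_features = d + (d * (d + 1)) // 2   # 3 + 6 = 9, matching out_features above
diag_features = d * 2                        # 6 for the diagonal ("loc", "scale") case

# Dummy network output, pushed through the same steps as the extractor's forward:
x = torch.randn(4, full_cov_features)
loc = x[:, :d]
scale_tril = torch.zeros(x.shape[0], d, d)
tril = torch.tril_indices(row=d, col=d, offset=0)
scale_tril[:, tril[0], tril[1]] = x[:, d:]
scale_tril.diagonal(dim1=-2, dim2=-1).exp_() # positive diagonal => valid Cholesky factor

dist = torch.distributions.MultivariateNormal(loc, scale_tril=scale_tril)
print(dist.sample().shape)                   # torch.Size([4, 3])

Exponentiating only the diagonal is what makes scale_tril a proper Cholesky factor: the covariance it induces is scale_tril @ scale_tril.T, which is positive definite whenever the diagonal is strictly positive.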
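Downstream, the new out_keys pair naturally with tensordict's probabilistic modules: "loc" and "scale_tril" match the keyword arguments of torch.distributions.MultivariateNormal. A hedged sketch of that wiring (everything here, including DummyFullCovHead and the dimensions, is illustrative and not part of fancy_rl; the dummy head returns a (loc, scale_tril) tuple, which TensorDictModule maps positionally onto its out_keys):

import torch
from torch.distributions import MultivariateNormal
from tensordict import TensorDict
from tensordict.nn import (
    InteractionType, ProbabilisticTensorDictModule,
    TensorDictModule, TensorDictSequential, set_interaction_type,
)

class DummyFullCovHead(torch.nn.Module):
    """Stand-in for the Actor above: emits a mean and a Cholesky factor."""
    def forward(self, obs):
        d = 2
        loc = obs[:, :d]
        scale_tril = torch.eye(d).expand(obs.shape[0], d, d)  # identity factor
        return loc, scale_tril

policy_head = TensorDictModule(
    DummyFullCovHead(), in_keys=["observation"], out_keys=["loc", "scale_tril"]
)
dist_module = ProbabilisticTensorDictModule(
    in_keys=["loc", "scale_tril"],       # forwarded by name to the distribution
    out_keys=["action"],
    distribution_class=MultivariateNormal,
    return_log_prob=True,
)
policy = TensorDictSequential(policy_head, dist_module)

td = TensorDict({"observation": torch.randn(4, 8)}, batch_size=[4])
with set_interaction_type(InteractionType.RANDOM):
    td = policy(td)
print(td["action"].shape)                # torch.Size([4, 2])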
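The discrete branch is unchanged in substance: it still emits a single "action_logits" vector, which can parameterize a categorical distribution directly. A minimal check (batch size and action count made up):

import torch
from torch.distributions import Categorical

n_actions = 5
logits = torch.randn(4, n_actions)   # what the "action_logits" head would emit
dist = Categorical(logits=logits)
actions = dist.sample()              # torch.Size([4])
print(dist.log_prob(actions).shape)  # torch.Size([4])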