Make use of new utils to work with any space
This commit is contained in:
		
							parent
							
								
									59060c7533
								
							
						
					
					
						commit
						50733bb1a4
					
				| @ -2,21 +2,22 @@ import torch.nn as nn | |||||||
| from tensordict.nn import TensorDictModule | from tensordict.nn import TensorDictModule | ||||||
| from torchrl.modules import MLP | from torchrl.modules import MLP | ||||||
| from tensordict.nn.distributions import NormalParamExtractor | from tensordict.nn.distributions import NormalParamExtractor | ||||||
|  | from fancy_rl.utils import is_discrete_space, get_space_shape | ||||||
| 
 | 
 | ||||||
| class SharedModule(TensorDictModule): | class SharedModule(TensorDictModule): | ||||||
|     def __init__(self, obs_space, hidden_sizes, activation_fn, device): |     def __init__(self, obs_space, hidden_sizes, activation_fn, device): | ||||||
|         if hidden_sizes: |         if hidden_sizes: | ||||||
|             shared_module = MLP( |             shared_module = MLP( | ||||||
|                 in_features=obs_space.shape[-1], |                 in_features=get_space_shape(obs_space)[-1], | ||||||
|                 out_features=hidden_sizes[-1], |                 out_features=hidden_sizes[-1], | ||||||
|                 num_cells=hidden_sizes, |                 num_cells=hidden_sizes[:-1], | ||||||
|                 activation_class=getattr(nn, activation_fn), |                 activation_class=getattr(nn, activation_fn), | ||||||
|                 device=device |                 device=device | ||||||
|             ) |             ) | ||||||
|             out_features = hidden_sizes[-1] |             out_features = hidden_sizes[-1] | ||||||
|         else: |         else: | ||||||
|             shared_module = nn.Identity() |             shared_module = nn.Identity() | ||||||
|             out_features = obs_space.shape[-1] |             out_features = get_space_shape(obs_space)[-1] | ||||||
|          |          | ||||||
|         super().__init__( |         super().__init__( | ||||||
|             module=shared_module, |             module=shared_module, | ||||||
| @ -27,20 +28,26 @@ class SharedModule(TensorDictModule): | |||||||
| 
 | 
 | ||||||
| class Actor(TensorDictModule): | class Actor(TensorDictModule): | ||||||
|     def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device): |     def __init__(self, shared_module, act_space, hidden_sizes, activation_fn, device): | ||||||
|  |         act_space_shape = get_space_shape(act_space) | ||||||
|  |         if is_discrete_space(act_space): | ||||||
|  |             out_features = act_space_shape[-1] | ||||||
|  |         else: | ||||||
|  |             out_features = act_space_shape[-1] * 2 | ||||||
|  | 
 | ||||||
|         actor_module = nn.Sequential( |         actor_module = nn.Sequential( | ||||||
|             MLP( |             MLP( | ||||||
|                 in_features=shared_module.out_features, |                 in_features=shared_module.out_features, | ||||||
|                 out_features=act_space.shape[-1] * 2, |                 out_features=out_features, | ||||||
|                 num_cells=hidden_sizes, |                 num_cells=hidden_sizes, | ||||||
|                 activation_class=getattr(nn, activation_fn), |                 activation_class=getattr(nn, activation_fn), | ||||||
|                 device=device |                 device=device | ||||||
|             ), |             ), | ||||||
|             NormalParamExtractor(), |             NormalParamExtractor() if not is_discrete_space(act_space) else nn.Identity(), | ||||||
|         ).to(device) |         ).to(device) | ||||||
|         super().__init__( |         super().__init__( | ||||||
|             module=actor_module, |             module=actor_module, | ||||||
|             in_keys=["shared"], |             in_keys=["shared"], | ||||||
|             out_keys=["loc", "scale"], |             out_keys=["loc", "scale"] if not is_discrete_space(act_space) else ["action_logits"], | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| class Critic(TensorDictModule): | class Critic(TensorDictModule): | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user