Various fixes

Dominik Moritz Roth 2025-01-22 13:46:00 +01:00
parent fe7c7b3db0
commit 04be117a95
4 changed files with 94 additions and 36 deletions

View File

@@ -1,9 +1,10 @@
 import torch
 import gymnasium as gym
-from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
+import numpy as np
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
 from torchrl.envs import BatchSizeTransform
@@ -56,27 +57,31 @@ class Algo(ABC):
         return env

     def _wrap_env(self, env_spec):
+        # If given an existing env, ensure it's properly batched
+        if isinstance(env_spec, (GymEnv, GymWrapper)):
+            if not env_spec.batch_size:
+                raise ValueError("Environment must be batched")
+            return env_spec
+
+        # Handle callable without wrapping the recursive call
+        if callable(env_spec):
+            return self._wrap_env(env_spec())
+
+        # Create new batched environment using SerialEnv
         if isinstance(env_spec, str):
-            env = GymEnv(env_spec, device=self.device)
+            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
         elif isinstance(env_spec, gym.Env):
-            env = GymWrapper(env_spec, device=self.device)
-        elif isinstance(env_spec, GymEnv):
-            env = env_spec
-        elif callable(env_spec):
-            base_env = env_spec()
-            return self._wrap_env(base_env)
+            wrapped_env = GymWrapper(env_spec, device=self.device)
+            if wrapped_env.batch_size:
+                env = wrapped_env
+            else:
+                env = SerialEnv(1, lambda: wrapped_env)
         else:
             raise ValueError(
                 f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                 f"got {type(env_spec)}"
             )
-        if not env.batch_size:
-            env = TransformedEnv(
-                env,
-                BatchSizeTransform(batch_size=torch.Size([1]))
-            )
         return env

     def train_step(self, batch):
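
For context on the SerialEnv change: wrapping a single env factory in SerialEnv(1, ...) produces a genuinely batched environment whose tensordicts carry a leading batch dimension, replacing the previous BatchSizeTransform fallback. A minimal sketch of the difference (the env id is just an example):

```python
import torch
from torchrl.envs import GymEnv, SerialEnv

# A bare GymEnv is unbatched: its batch_size is torch.Size([])
plain = GymEnv("Pendulum-v1")
assert plain.batch_size == torch.Size([])

# SerialEnv(1, factory) adds a real leading batch dimension of 1
batched = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
assert batched.batch_size == torch.Size([1])

# Reset tensordicts now have a leading batch axis
td = batched.reset()
print(td["observation"].shape)  # torch.Size([1, 3]) for Pendulum
```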
@@ -90,18 +95,21 @@ class Algo(ABC):
     def predict(
             self,
-            observation,
+            tensordict,
             state=None,
             deterministic=False
     ):
         with torch.no_grad():
-            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor})
-
-            action_td = self.prob_actor(td)
-            action = action_td["action"]
-
-            # We're not using recurrent policies, so we'll always return None for the state
-            next_state = None
-
-            return action.squeeze(0).cpu().numpy(), next_state
+            # If numpy array, convert to TensorDict
+            if isinstance(tensordict, np.ndarray):
+                tensordict = TensorDict(
+                    {"observation": torch.from_numpy(tensordict).float()},
+                    batch_size=[]
+                )
+
+            # Move to device
+            tensordict = tensordict.to(self.device)
+
+            # Get action from policy
+            action_td = self.prob_actor(tensordict)
+            return action_td
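
The predict signature change means callers now receive the policy's output tensordict instead of an (action, state) tuple. A self-contained sketch of the numpy-to-TensorDict normalization the new branch performs (prob_actor is the repo's policy module and is not needed for the conversion itself):

```python
import numpy as np
import torch
from tensordict import TensorDict

def to_policy_input(obs, device="cpu"):
    # Mirrors the conversion branch in Algo.predict: raw numpy
    # observations become a TensorDict with an empty batch size
    if isinstance(obs, np.ndarray):
        obs = TensorDict(
            {"observation": torch.from_numpy(obs).float()},
            batch_size=[],
        )
    return obs.to(device)

td = to_policy_input(np.zeros(3, dtype=np.float32))
print(td["observation"].shape)  # torch.Size([3])
```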

View File

@@ -43,13 +43,13 @@ class PPO(OnPolicy):
         env = self.make_env()

         # Get spaces from specs for parallel env
-        obs_space = env.observation_spec
-        act_space = env.action_spec
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec

-        self.discrete = isinstance(act_space, DiscreteTensorSpec)
+        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)

-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         if self.discrete:
             distribution_class = torch.distributions.Categorical
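
Storing the torchrl specs (observation_spec / action_spec) on self instead of in locals keeps them available to other methods, and the discrete check operates on tensor specs rather than Gym spaces. A small illustration of the spec types involved (constructor names vary between torchrl versions, so treat these as assumptions):

```python
from torchrl.data.tensor_specs import BoundedTensorSpec, DiscreteTensorSpec

disc = DiscreteTensorSpec(n=2)  # e.g. CartPole-style actions
cont = BoundedTensorSpec(low=-1.0, high=1.0, shape=(1,))  # e.g. Pendulum-style actions

assert isinstance(disc, DiscreteTensorSpec)
assert not isinstance(cont, DiscreteTensorSpec)
```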

View File

@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
+

 class ProjectedActor(TensorDictModule):
     def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
         self.raw_actor = raw_actor
         self.old_actor = old_actor
         self.projection = projection
+        self.discrete = raw_actor.discrete
+        self.full_covariance = raw_actor.full_covariance

     class CombinedModule(nn.Module):
         def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
             self.projection = projection

         def forward(self, tensordict):
+            # Convert the tuple outputs to TensorDict
             raw_params = self.raw_actor(tensordict)
+            if isinstance(raw_params, tuple):
+                raw_params = TensorDict({
+                    "loc": raw_params[0],
+                    "scale": raw_params[1]
+                }, batch_size=[raw_params[0].shape[0]])  # use the first dimension of the tensor as batch size
+
             old_params = self.old_actor(tensordict)
-            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            if isinstance(old_params, tuple):
+                old_params = TensorDict({
+                    "loc": old_params[0],
+                    "scale": old_params[1]
+                }, batch_size=[old_params[0].shape[0]])  # use the first dimension of the tensor as batch size
+
+            # Now combine them
+            combined_params = TensorDict({
+                **raw_params,
+                **{f"old_{key}": value for key, value in old_params.items()}
+            }, batch_size=[raw_params["loc"].shape[0]])  # use the first dimension of the loc tensor as batch size
+
             projected_params = self.projection(combined_params)
             return projected_params

+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+
 class TRPL(OnPolicy):
     def __init__(
         self,
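
The tuple branches guard against actor modules that return (loc, scale) instead of a TensorDict; after normalization, the old policy's parameters are merged in under old_-prefixed keys, which is the layout the projection consumes. A self-contained sketch of that merge (shapes are illustrative):

```python
import torch
from tensordict import TensorDict

batch = 4
raw_params = TensorDict(
    {"loc": torch.zeros(batch, 2), "scale": torch.ones(batch, 2)},
    batch_size=[batch],
)
old_params = TensorDict(
    {"loc": torch.zeros(batch, 2), "scale": torch.ones(batch, 2)},
    batch_size=[batch],
)

# Merge the new params with "old_"-prefixed copies of the old policy's
# params, taking the batch size from loc's leading dimension
combined = TensorDict(
    {**raw_params, **{f"old_{k}": v for k, v in old_params.items()}},
    batch_size=[raw_params["loc"].shape[0]],
)
print(sorted(combined.keys()))  # ['loc', 'old_loc', 'old_scale', 'scale']
```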
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space

-        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        # Get spaces from specs for parallel env
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec
+
+        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"

-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
-        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         # Handle projection_class
         if isinstance(projection_class, str):

View File

@@ -4,6 +4,7 @@ from torchrl.modules import MLP
 from torchrl.data.tensor_specs import DiscreteTensorSpec
 from tensordict.nn.distributions import NormalParamExtractor
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal

 class Actor(TensorDictModule):
     def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
             out_keys=out_keys
         )

+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+
 class FullCovarianceNormalParamExtractor(nn.Module):
     def __init__(self, action_dim):
         super().__init__()
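
get_dist dispatches on the output keys the actor produces: logits for discrete policies, loc/scale_tril for a full-covariance Gaussian, loc/scale for a diagonal one. The same dispatch in isolation, with dummy tensors standing in for network outputs:

```python
import torch
from torch.distributions import Categorical, MultivariateNormal, Normal

loc = torch.zeros(4, 2)
scale = torch.ones(4, 2)
scale_tril = torch.eye(2).expand(4, 2, 2)  # lower-triangular factor per sample
logits = torch.zeros(4, 3)

dist_discrete = Categorical(logits=logits)                      # discrete actions
dist_full = MultivariateNormal(loc=loc, scale_tril=scale_tril)  # full covariance
dist_diag = Normal(loc=loc, scale=scale)                        # diagonal Gaussian
```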
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):

 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        obs_space = obs_space["observation"]
+        obs_space_shape = obs_space.shape[1:]
+
         critic_module = MLP(
-            in_features=obs_space.shape[-1],
+            in_features=obs_space_shape[0],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
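
On the Critic change: with environments now always batched, observation_spec is a composite spec whose "observation" entry carries the batch dimension in front, so the feature size comes from shape[1:] rather than shape[-1]. A sketch of the shapes involved (the env id and layer sizes are just examples):

```python
import torch
from torch import nn
from torchrl.envs import GymEnv, SerialEnv
from torchrl.modules import MLP

env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))

# The per-key spec shape starts with the batch dim: torch.Size([1, 3])
obs_spec = env.observation_spec["observation"]
obs_shape = obs_spec.shape[1:]  # strip the batch dim -> torch.Size([3])

critic = MLP(
    in_features=obs_shape[0],
    out_features=1,
    num_cells=[64, 64],
    activation_class=nn.Tanh,
)
print(critic(torch.zeros(1, *obs_shape)).shape)  # torch.Size([1, 1])
```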