Various fixes

Dominik Moritz Roth 2025-01-22 13:46:00 +01:00
parent fe7c7b3db0
commit 04be117a95
4 changed files with 94 additions and 36 deletions

View File

@@ -1,9 +1,10 @@
 import torch
 import gymnasium as gym
-from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
+import numpy as np
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
 from torchrl.envs import BatchSizeTransform
@@ -56,27 +57,31 @@ class Algo(ABC):
         return env
     def _wrap_env(self, env_spec):
+        # If given an existing env, ensure it's properly batched
+        if isinstance(env_spec, (GymEnv, GymWrapper)):
+            if not env_spec.batch_size:
+                raise ValueError("Environment must be batched")
+            return env_spec
+        # Handle callable without wrapping the recursive call
+        if callable(env_spec):
+            return self._wrap_env(env_spec())
+        # Create new batched environment using SerialEnv
         if isinstance(env_spec, str):
-            env = GymEnv(env_spec, device=self.device)
+            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
         elif isinstance(env_spec, gym.Env):
-            env = GymWrapper(env_spec, device=self.device)
-        elif isinstance(env_spec, GymEnv):
-            env = env_spec
-        elif callable(env_spec):
-            base_env = env_spec()
-            return self._wrap_env(base_env)
+            wrapped_env = GymWrapper(env_spec, device=self.device)
+            if wrapped_env.batch_size:
+                env = wrapped_env
+            else:
+                env = SerialEnv(1, lambda: wrapped_env)
         else:
             raise ValueError(
                 f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                 f"got {type(env_spec)}"
             )
-        if not env.batch_size:
-            env = TransformedEnv(
-                env,
-                BatchSizeTransform(batch_size=torch.Size([1]))
-            )
         return env
     def train_step(self, batch):
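
The hunk above swaps the BatchSizeTransform fallback for SerialEnv-based batching. A minimal sketch of the difference, assuming torchrl and gymnasium are installed; CartPole-v1 is only an illustrative env id:

from torchrl.envs import GymEnv, SerialEnv

plain = GymEnv("CartPole-v1")
batched = SerialEnv(1, lambda: GymEnv("CartPole-v1"))

print(plain.batch_size)    # torch.Size([]) -> would fail the "Environment must be batched" check
print(batched.batch_size)  # torch.Size([1])
print(batched.reset()["observation"].shape)  # torch.Size([1, 4]): leading batch dimension
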
@@ -90,18 +95,21 @@ class Algo(ABC):
     def predict(
         self,
-        observation,
+        tensordict,
         state=None,
         deterministic=False
     ):
         with torch.no_grad():
-            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor})
+            # If numpy array, convert to TensorDict
+            if isinstance(tensordict, np.ndarray):
+                tensordict = TensorDict(
+                    {"observation": torch.from_numpy(tensordict).float()},
+                    batch_size=[]
+                )
-            action_td = self.prob_actor(td)
-            action = action_td["action"]
+            # Move to device
+            tensordict = tensordict.to(self.device)
-            # We're not using recurrent policies, so we'll always return None for the state
-            next_state = None
-            return action.squeeze(0).cpu().numpy(), next_state
+            # Get action from policy
+            action_td = self.prob_actor(tensordict)
+            return action_td
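
With this change, predict() accepts either a raw numpy observation or a TensorDict and returns the actor's output TensorDict instead of an (action, state) tuple. A hedged usage sketch, assuming a concrete Algo subclass instance algo whose prob_actor writes an "action" entry; the observation shape is illustrative:

import numpy as np

obs = np.zeros(4, dtype=np.float32)           # e.g. one CartPole-v1 observation
action_td = algo.predict(obs)                 # numpy input is wrapped into a TensorDict internally
action = action_td["action"].cpu().numpy()    # callers now unpack the action themselves
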

View File

@@ -43,13 +43,13 @@ class PPO(OnPolicy):
         env = self.make_env()
         # Get spaces from specs for parallel env
-        obs_space = env.observation_spec
-        act_space = env.action_spec
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec
-        self.discrete = isinstance(act_space, DiscreteTensorSpec)
+        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)
-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
         if self.discrete:
             distribution_class = torch.distributions.Categorical

View File

@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
 class ProjectedActor(TensorDictModule):
     def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
         self.raw_actor = raw_actor
         self.old_actor = old_actor
         self.projection = projection
+        self.discrete = raw_actor.discrete
+        self.full_covariance = raw_actor.full_covariance
     class CombinedModule(nn.Module):
         def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
             self.projection = projection
         def forward(self, tensordict):
+            # Convert the tuple outputs to TensorDict
             raw_params = self.raw_actor(tensordict)
+            if isinstance(raw_params, tuple):
+                raw_params = TensorDict({
+                    "loc": raw_params[0],
+                    "scale": raw_params[1]
+                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
             old_params = self.old_actor(tensordict)
-            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            if isinstance(old_params, tuple):
+                old_params = TensorDict({
+                    "loc": old_params[0],
+                    "scale": old_params[1]
+                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+            # Now combine them
+            combined_params = TensorDict({
+                **raw_params,
+                **{f"old_{key}": value for key, value in old_params.items()}
+            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size
             projected_params = self.projection(combined_params)
             return projected_params
+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
 class TRPL(OnPolicy):
     def __init__(
         self,
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space
-        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        # Get spaces from specs for parallel env
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec
-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
-        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
         # Handle projection_class
         if isinstance(projection_class, str):
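
Both ProjectedActor and the Actor module in the next file now expose get_dist(), which turns the network outputs into a torch.distributions object. A brief sketch of how it might be used for log-probabilities, assuming actor is one of these modules and td is a TensorDict with a batched "observation" entry:

dist = actor.get_dist(td)         # Categorical, Normal, or MultivariateNormal
action = dist.sample()
log_prob = dist.log_prob(action)  # e.g. for a policy-gradient ratio term
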

View File

@@ -4,6 +4,7 @@ from torchrl.modules import MLP
 from torchrl.data.tensor_specs import DiscreteTensorSpec
 from tensordict.nn.distributions import NormalParamExtractor
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
 class Actor(TensorDictModule):
     def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
             out_keys=out_keys
         )
+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
 class FullCovarianceNormalParamExtractor(nn.Module):
     def __init__(self, action_dim):
         super().__init__()
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):
 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
         obs_space = obs_space["observation"]
+        obs_space_shape = obs_space.shape[1:]
         critic_module = MLP(
-            in_features=obs_space.shape[-1],
+            in_features=obs_space_shape[0],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
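
The Critic now derives its input width from the observation spec with the leading batch dimension stripped. A small illustration of why, assuming the SerialEnv-batched envs produced by _wrap_env above (CartPole-v1 again only as an example):

from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(1, lambda: GymEnv("CartPole-v1"))
obs_spec = env.observation_spec["observation"]
print(obs_spec.shape)      # torch.Size([1, 4]): [batch, obs_dim]
print(obs_spec.shape[1:])  # torch.Size([4]) -> in_features = 4
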