Various fixes

parent fe7c7b3db0
commit 04be117a95
@@ -1,9 +1,10 @@
import torch
import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
import numpy as np
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
@@ -56,27 +57,31 @@ class Algo(ABC):
        return env

    def _wrap_env(self, env_spec):
        # If given an existing env, ensure it's properly batched
        if isinstance(env_spec, (GymEnv, GymWrapper)):
            if not env_spec.batch_size:
                raise ValueError("Environment must be batched")
            return env_spec

        # Handle callable without wrapping the recursive call
        if callable(env_spec):
            return self._wrap_env(env_spec())

        # Create new batched environment using SerialEnv
        if isinstance(env_spec, str):
            env = GymEnv(env_spec, device=self.device)
            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
        elif isinstance(env_spec, gym.Env):
            env = GymWrapper(env_spec, device=self.device)
        elif isinstance(env_spec, GymEnv):
            env = env_spec
        elif callable(env_spec):
            base_env = env_spec()
            return self._wrap_env(base_env)
            wrapped_env = GymWrapper(env_spec, device=self.device)
            if wrapped_env.batch_size:
                env = wrapped_env
            else:
                env = SerialEnv(1, lambda: wrapped_env)
        else:
            raise ValueError(
                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                f"got {type(env_spec)}"
            )

        if not env.batch_size:
            env = TransformedEnv(
                env,
                BatchSizeTransform(batch_size=torch.Size([1]))
            )

        return env

    def train_step(self, batch):
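Note on the batching logic in `_wrap_env`: torchrl environments expose a `batch_size`, and a bare `GymEnv` is batchless. The sketch below is not part of the commit; it is a minimal illustration (assuming gymnasium's Pendulum-v1 is available) of the two ways a leading batch dimension of 1 is obtained.

```python
import torch
from torchrl.envs import BatchSizeTransform, GymEnv, SerialEnv, TransformedEnv

# A bare GymEnv reports an empty batch_size.
single = GymEnv("Pendulum-v1")
print(single.batch_size)    # torch.Size([])

# Option 1: run the env inside a SerialEnv with a single worker.
serial = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
print(serial.batch_size)    # torch.Size([1])

# Option 2: keep the env and add a leading batch dimension via
# BatchSizeTransform, as done here for envs that arrive unbatched.
reshaped = TransformedEnv(
    GymEnv("Pendulum-v1"),
    BatchSizeTransform(batch_size=torch.Size([1])),
)
print(reshaped.batch_size)  # torch.Size([1])
```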
@@ -90,18 +95,21 @@ class Algo(ABC):

    def predict(
        self,
        observation,
        tensordict,
        state=None,
        deterministic=False
    ):
        with torch.no_grad():
            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
            td = TensorDict({"observation": obs_tensor})
            # If numpy array, convert to TensorDict
            if isinstance(tensordict, np.ndarray):
                tensordict = TensorDict(
                    {"observation": torch.from_numpy(tensordict).float()},
                    batch_size=[]
                )

            action_td = self.prob_actor(td)
            action = action_td["action"]
            # Move to device
            tensordict = tensordict.to(self.device)

            # We're not using recurrent policies, so we'll always return None for the state
            next_state = None

            return action.squeeze(0).cpu().numpy(), next_state
            # Get action from policy
            action_td = self.prob_actor(tensordict)
            return action_td
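For context on the reworked `predict`: it now takes a TensorDict (or a raw NumPy observation, which gets converted) and returns the actor's output TensorDict instead of a NumPy action. A minimal sketch of the conversion step only, with a made-up 3-dimensional observation:

```python
import numpy as np
import torch
from tensordict import TensorDict

obs = np.zeros(3, dtype=np.float32)  # placeholder observation
td = TensorDict({"observation": torch.from_numpy(obs).float()}, batch_size=[])
print(td["observation"].shape)       # torch.Size([3])
print(td.batch_size)                 # torch.Size([])
```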
@@ -43,13 +43,13 @@ class PPO(OnPolicy):
        env = self.make_env()

        # Get spaces from specs for parallel env
        obs_space = env.observation_spec
        act_space = env.action_spec
        self.obs_space = env.observation_spec
        self.act_space = env.action_spec

        self.discrete = isinstance(act_space, DiscreteTensorSpec)
        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)

        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

        if self.discrete:
            distribution_class = torch.distributions.Categorical
@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
from copy import deepcopy
from tensordict.nn import TensorDictModule
from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal


class ProjectedActor(TensorDictModule):
    def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
        self.raw_actor = raw_actor
        self.old_actor = old_actor
        self.projection = projection
        self.discrete = raw_actor.discrete
        self.full_covariance = raw_actor.full_covariance

    class CombinedModule(nn.Module):
        def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
            self.projection = projection

        def forward(self, tensordict):
            # Convert the tuple outputs to TensorDict
            raw_params = self.raw_actor(tensordict)
            if isinstance(raw_params, tuple):
                raw_params = TensorDict({
                    "loc": raw_params[0],
                    "scale": raw_params[1]
                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size

            old_params = self.old_actor(tensordict)
            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
            if isinstance(old_params, tuple):
                old_params = TensorDict({
                    "loc": old_params[0],
                    "scale": old_params[1]
                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size

            # Now combine them
            combined_params = TensorDict({
                **raw_params,
                **{f"old_{key}": value for key, value in old_params.items()}
            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size

            projected_params = self.projection(combined_params)
            return projected_params

    def get_dist(self, tensordict):
        # Forward the observation through the network
        out = self.forward(tensordict)
        if self.discrete:
            return Categorical(logits=out["logits"])
        else:
            if self.full_covariance:
                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
            else:
                return Normal(loc=out["loc"], scale=out["scale"])


class TRPL(OnPolicy):
    def __init__(
        self,
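The combination step in the new `forward` merges the current actor's parameters with the frozen old actor's parameters under `old_`-prefixed keys, so the projection sees both in one TensorDict. A standalone sketch of that merge with toy shapes (not from the commit):

```python
import torch
from tensordict import TensorDict

batch = 4
raw_params = TensorDict({"loc": torch.zeros(batch, 2), "scale": torch.ones(batch, 2)}, batch_size=[batch])
old_params = TensorDict({"loc": torch.ones(batch, 2), "scale": torch.ones(batch, 2)}, batch_size=[batch])

combined = TensorDict(
    {**raw_params, **{f"old_{key}": value for key, value in old_params.items()}},
    batch_size=[raw_params["loc"].shape[0]],
)
print(sorted(combined.keys()))  # ['loc', 'old_loc', 'old_scale', 'scale']
```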
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
        # Initialize environment to get observation and action space sizes
        self.env_spec = env_spec
        env = self.make_env()
        obs_space = env.observation_space
        act_space = env.action_space

        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
        # Get spaces from specs for parallel env
        self.obs_space = env.observation_spec
        self.act_space = env.action_spec

        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"

        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

        # Handle projection_class
        if isinstance(projection_class, str):
@@ -4,6 +4,7 @@ from torchrl.modules import MLP
from torchrl.data.tensor_specs import DiscreteTensorSpec
from tensordict.nn.distributions import NormalParamExtractor
from tensordict import TensorDict
from torch.distributions import Categorical, MultivariateNormal, Normal

class Actor(TensorDictModule):
    def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
            out_keys=out_keys
        )

    def get_dist(self, tensordict):
        # Forward the observation through the network
        out = self.forward(tensordict)
        if self.discrete:
            return Categorical(logits=out["logits"])
        else:
            if self.full_covariance:
                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
            else:
                return Normal(loc=out["loc"], scale=out["scale"])

class FullCovarianceNormalParamExtractor(nn.Module):
    def __init__(self, action_dim):
        super().__init__()
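The `get_dist` helper added to `Actor` (and mirrored on `ProjectedActor`) maps network outputs to one of three torch distributions. A small sketch of the three cases with dummy tensors (shapes are arbitrary, not from the commit):

```python
import torch
from torch.distributions import Categorical, MultivariateNormal, Normal

logits = torch.zeros(4, 3)                       # discrete case
loc = torch.zeros(4, 2)
scale = torch.ones(4, 2)                         # diagonal standard deviations
scale_tril = torch.diag_embed(torch.ones(4, 2))  # lower-triangular Cholesky factor

dist_discrete = Categorical(logits=logits)
dist_diag = Normal(loc=loc, scale=scale)
dist_full = MultivariateNormal(loc=loc, scale_tril=scale_tril)
print(dist_discrete.sample().shape, dist_diag.sample().shape, dist_full.sample().shape)
```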
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):

class Critic(TensorDictModule):
    def __init__(self, obs_space, hidden_sizes, activation_fn, device):
        obs_space = obs_space["observation"]
        obs_space_shape = obs_space.shape[1:]

        critic_module = MLP(
            in_features=obs_space.shape[-1],
            in_features=obs_space_shape[0],
            out_features=1,
            num_cells=hidden_sizes,
            activation_class=getattr(nn, activation_fn),
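One more note on the `Critic` change: with a batched environment the observation spec gains a leading batch dimension, so the per-step feature size now comes from `shape[1:]`. A small sketch (assumes Pendulum-v1, whose observation is 3-dimensional):

```python
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
obs_spec = env.observation_spec["observation"]
print(obs_spec.shape)         # torch.Size([1, 3]): leading batch dim of 1
print(obs_spec.shape[1:])     # torch.Size([3]): per-step observation shape
print(obs_spec.shape[1:][0])  # 3, used as in_features for the critic MLP
```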