import pdb
from abc import ABC

import gymnasium as gym
import torch
from tensordict import TensorDict
from torchrl.envs import (
    BatchSizeTransform,
    Compose,
    GymEnv,
    GymWrapper,
    ParallelEnv,
    RewardSum,
    StepCounter,
    TransformedEnv,
)
from torchrl.record import VideoRecorder

from fancy_rl.loggers import TerminalLogger
class Algo(ABC):
    """Abstract base class for on-policy RL algorithms.

    Stores shared hyperparameters and provides environment construction
    (`make_env` / `_wrap_env`) and a SB3-style `predict` helper. Subclasses
    must implement `train_step`, `train`, and `evaluate`.
    """

    def __init__(
        self,
        env_spec,
        loggers,
        optimizers,
        learning_rate,
        n_steps,
        batch_size,
        n_epochs,
        gamma,
        total_timesteps,
        eval_interval,
        eval_deterministic,
        entropy_coef,
        critic_coef,
        normalize_advantage,
        device=None,
        eval_episodes=10,
        env_spec_eval=None,
    ):
        """Store hyperparameters and resolve defaults.

        Args:
            env_spec: Env id string, ``gym.Env`` instance, ``GymEnv``, or a
                zero-argument factory returning any of those.
            loggers: List of logger objects; ``None`` falls back to a single
                ``TerminalLogger``.
            optimizers: Optimizer(s) used by the subclass.
            learning_rate: Optimizer learning rate.
            n_steps: Rollout length per update.
            batch_size: Minibatch size.
            n_epochs: Optimization epochs per rollout.
            gamma: Discount factor.
            total_timesteps: Total environment steps to train for.
            eval_interval: How often (in epochs) to run evaluation.
            eval_deterministic: Whether evaluation uses deterministic actions.
            entropy_coef: Entropy bonus coefficient.
            critic_coef: Critic loss coefficient.
            normalize_advantage: Whether to normalize advantages.
            device: Torch device string; ``None`` auto-selects CUDA if available.
            eval_episodes: Number of episodes per evaluation.
            env_spec_eval: Optional separate spec for evaluation envs;
                ``None`` reuses ``env_spec``.
        """
        self.env_spec = env_spec
        # Evaluation falls back to the training spec when none is given.
        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
        # Identity comparison (`is not None`) instead of `!= None` (PEP 8).
        self.loggers = loggers if loggers is not None else [TerminalLogger(None, None)]
        self.optimizers = optimizers
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.gamma = gamma
        self.total_timesteps = total_timesteps
        self.eval_interval = eval_interval
        self.eval_deterministic = eval_deterministic
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        # Prefer CUDA when available and no device was requested explicitly.
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.eval_episodes = eval_episodes

    def make_env(self, eval=False):
        """Build and reset a training (or evaluation) environment.

        NOTE(review): the parameter name ``eval`` shadows the builtin; it is
        kept unchanged for backward compatibility with existing callers.
        """
        env_spec = self.env_spec_eval if eval else self.env_spec
        env = self._wrap_env(env_spec)
        # Reset once so the env is immediately usable by collectors.
        env.reset()
        return env

    def _wrap_env(self, env_spec):
        """Normalize ``env_spec`` into a batched torchrl environment.

        Accepts an env id string, a Gymnasium env, an existing ``GymEnv``,
        or a zero-argument factory. Raises ``ValueError`` otherwise.
        """
        if isinstance(env_spec, str):
            env = GymEnv(env_spec, device=self.device)
        elif isinstance(env_spec, gym.Env):
            env = GymWrapper(env_spec, device=self.device)
        elif isinstance(env_spec, GymEnv):
            env = env_spec
        elif callable(env_spec):
            # Factory: build the env, then wrap whatever it returns.
            return self._wrap_env(env_spec())
        else:
            raise ValueError(
                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                f"got {type(env_spec)}"
            )

        # Collectors expect a batch-locked env; give unbatched envs a
        # singleton batch dimension.
        if not env.batch_size:
            env = TransformedEnv(
                env,
                BatchSizeTransform(batch_size=torch.Size([1])),
            )

        return env

    def train_step(self, batch):
        """Run one optimization step on ``batch``. Must be overridden."""
        raise NotImplementedError("train_step method must be implemented in subclass.")

    def train(self):
        """Run the full training loop. Must be overridden."""
        raise NotImplementedError("train method must be implemented in subclass.")

    def evaluate(self, epoch):
        """Evaluate the current policy at ``epoch``. Must be overridden."""
        raise NotImplementedError("evaluate method must be implemented in subclass.")

    def predict(
        self,
        observation,
        state=None,
        deterministic=False,
    ):
        """Return a numpy action for a single observation.

        NOTE(review): ``deterministic`` is currently ignored — the actor is
        always sampled the same way; confirm intended behavior with subclass
        authors before relying on it.
        """
        with torch.no_grad():
            # Add a leading batch dimension to match the batched env.
            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
            td = TensorDict({"observation": obs_tensor})

            action_td = self.prob_actor(td)
            action = action_td["action"]

            # We're not using recurrent policies, so we'll always return None
            # for the state.
            next_state = None

            return action.squeeze(0).cpu().numpy(), next_state