# fancy_rl/fancy_rl/algos/algo.py

import torch
import gymnasium as gym
from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
from torchrl.record import VideoRecorder
from abc import ABC
import pdb
from tensordict import TensorDict
from torchrl.envs import GymWrapper, TransformedEnv
from torchrl.envs import BatchSizeTransform
from fancy_rl.loggers import TerminalLogger
class Algo(ABC):
    """Abstract base class for RL algorithms.

    Holds the hyperparameters shared by all algorithms, knows how to build
    (and batch-wrap) a torchrl environment from several kinds of env specs,
    and provides a generic ``predict`` helper for single-observation
    inference. Subclasses must implement ``train_step``, ``train``, and
    ``evaluate``.
    """

    def __init__(
        self,
        env_spec,
        loggers,
        optimizers,
        learning_rate,
        n_steps,
        batch_size,
        n_epochs,
        gamma,
        total_timesteps,
        eval_interval,
        eval_deterministic,
        entropy_coef,
        critic_coef,
        normalize_advantage,
        device=None,
        eval_episodes=10,
        env_spec_eval=None,
    ):
        """Store hyperparameters and resolve defaults.

        Args:
            env_spec: Training environment spec — a gym id string, a
                ``gym.Env`` instance, a torchrl ``GymEnv``, or a zero-arg
                callable returning one of those (see ``_wrap_env``).
            loggers: List of logger objects; ``None`` falls back to a single
                ``TerminalLogger``.
            optimizers: Optimizer(s) used by the subclass.
            learning_rate: Optimizer learning rate.
            n_steps: Rollout length per collection.
            batch_size: Minibatch size for updates.
            n_epochs: Update epochs per rollout.
            gamma: Discount factor.
            total_timesteps: Total environment steps to train for.
            eval_interval: How often (in epochs) to evaluate.
            eval_deterministic: Whether evaluation uses deterministic actions.
            entropy_coef: Entropy bonus coefficient.
            critic_coef: Critic loss coefficient.
            normalize_advantage: Whether to normalize advantages.
            device: Torch device string; defaults to CUDA when available,
                else CPU.
            eval_episodes: Episodes per evaluation run.
            env_spec_eval: Optional separate evaluation env spec; defaults
                to ``env_spec``.
        """
        self.env_spec = env_spec
        # No separate eval spec provided -> evaluate on the training env.
        self.env_spec_eval = env_spec_eval if env_spec_eval is not None else env_spec
        # Fix: compare against None with `is not` (identity), not `!=`.
        self.loggers = loggers if loggers is not None else [TerminalLogger(None, None)]
        self.optimizers = optimizers
        self.learning_rate = learning_rate
        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.gamma = gamma
        self.total_timesteps = total_timesteps
        self.eval_interval = eval_interval
        self.eval_deterministic = eval_deterministic
        self.entropy_coef = entropy_coef
        self.critic_coef = critic_coef
        self.normalize_advantage = normalize_advantage
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.eval_episodes = eval_episodes

    def make_env(self, eval=False):
        """Build, wrap, and reset an environment.

        NOTE(review): the parameter name ``eval`` shadows the builtin; it is
        kept unchanged for backward compatibility with keyword callers.

        Args:
            eval: If True, build from ``env_spec_eval``; otherwise from
                ``env_spec``.

        Returns:
            A reset, batch-wrapped torchrl environment.
        """
        env_spec = self.env_spec_eval if eval else self.env_spec
        env = self._wrap_env(env_spec)
        env.reset()
        return env

    def _wrap_env(self, env_spec):
        """Normalize any supported env spec into a batched torchrl env.

        Accepts a gym id string, a ``gym.Env``, a ``GymEnv``, or a zero-arg
        callable producing one of those (resolved recursively).

        Raises:
            ValueError: If ``env_spec`` is none of the supported kinds.
        """
        if isinstance(env_spec, str):
            env = GymEnv(env_spec, device=self.device)
        elif isinstance(env_spec, gym.Env):
            env = GymWrapper(env_spec, device=self.device)
        elif isinstance(env_spec, GymEnv):
            env = env_spec
        elif callable(env_spec):
            base_env = env_spec()
            return self._wrap_env(base_env)
        else:
            raise ValueError(
                f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                f"got {type(env_spec)}"
            )
        # Unbatched envs report an empty torch.Size (falsy): give them a
        # singleton batch dimension so downstream code can assume batching.
        if not env.batch_size:
            env = TransformedEnv(
                env,
                BatchSizeTransform(batch_size=torch.Size([1]))
            )
        return env

    def train_step(self, batch):
        """Run one optimization step on ``batch``. Must be overridden."""
        raise NotImplementedError("train_step method must be implemented in subclass.")

    def train(self):
        """Run the full training loop. Must be overridden."""
        raise NotImplementedError("train method must be implemented in subclass.")

    def evaluate(self, epoch):
        """Evaluate the current policy at ``epoch``. Must be overridden."""
        raise NotImplementedError("evaluate method must be implemented in subclass.")

    def predict(
        self,
        observation,
        state=None,
        deterministic=False
    ):
        """Compute an action for a single (unbatched) observation.

        The observation is converted to a tensor, given a leading batch dim,
        and passed through ``self.prob_actor`` (expected to be set by the
        subclass — not defined on this base class).

        NOTE(review): ``deterministic`` is currently ignored here — presumably
        handled inside ``prob_actor``; confirm against subclasses.

        Returns:
            Tuple of (action as a numpy array with the batch dim removed,
            None) — the second element is a state placeholder; no recurrent
            policies are supported.
        """
        with torch.no_grad():
            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
            td = TensorDict({"observation": obs_tensor})
            action_td = self.prob_actor(td)
            action = action_td["action"]
            # We're not using recurrent policies, so we'll always return None for the state
            next_state = None
        return action.squeeze(0).cpu().numpy(), next_state