Compare commits

5 Commits: e938018494 ... 3816adef9a

| Author | SHA1 | Date |
|---|---|---|
|  | 3816adef9a |  |
|  | 04be117a95 |  |
|  | fe7c7b3db0 |  |
|  | 90666a695c |  |
|  | c1189351cf |  |
.gitignore (vendored): 1 line changed

@@ -1,5 +1,6 @@
 __pycache__
 .venv
+.vscode
 wandb
 *.egg-info/
 test.py
@@ -1,9 +1,10 @@
 import torch
 import gymnasium as gym
-from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, ParallelEnv
+from torchrl.envs import GymEnv, TransformedEnv, Compose, RewardSum, StepCounter, SerialEnv
 from torchrl.record import VideoRecorder
 from abc import ABC
 import pdb
+import numpy as np
 from tensordict import TensorDict
 from torchrl.envs import GymWrapper, TransformedEnv
 from torchrl.envs import BatchSizeTransform
@@ -56,27 +57,31 @@ class Algo(ABC):
         return env

     def _wrap_env(self, env_spec):
+        # If given an existing env, ensure it's properly batched
+        if isinstance(env_spec, (GymEnv, GymWrapper)):
+            if not env_spec.batch_size:
+                raise ValueError("Environment must be batched")
+            return env_spec
+
+        # Handle callable without wrapping the recursive call
+        if callable(env_spec):
+            return self._wrap_env(env_spec())
+
+        # Create new batched environment using SerialEnv
         if isinstance(env_spec, str):
-            env = GymEnv(env_spec, device=self.device)
+            env = SerialEnv(1, lambda: GymEnv(env_spec, device=self.device))
         elif isinstance(env_spec, gym.Env):
-            env = GymWrapper(env_spec, device=self.device)
-        elif isinstance(env_spec, GymEnv):
-            env = env_spec
-        elif callable(env_spec):
-            base_env = env_spec()
-            return self._wrap_env(base_env)
+            wrapped_env = GymWrapper(env_spec, device=self.device)
+            if wrapped_env.batch_size:
+                env = wrapped_env
+            else:
+                env = SerialEnv(1, lambda: wrapped_env)
         else:
             raise ValueError(
                 f"env_spec must be a string, callable, Gymnasium environment, or GymEnv, "
                 f"got {type(env_spec)}"
             )

-        if not env.batch_size:
-            env = TransformedEnv(
-                env,
-                BatchSizeTransform(batch_size=torch.Size([1]))
-            )
-
         return env

     def train_step(self, batch):
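Note on the new wrapping strategy: SerialEnv(1, make_fn) yields an environment whose batch_size is torch.Size([1]), which is what _wrap_env now guarantees instead of relying on BatchSizeTransform. A minimal, standalone sketch of that contract (assumes torchrl and gymnasium are installed; "Pendulum-v1" is just an example id, not taken from the diff):

```python
import torch
from torchrl.envs import GymEnv, SerialEnv

# Build a single-worker SerialEnv the same way the new _wrap_env does for a string spec.
env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
assert env.batch_size == torch.Size([1])  # batched, so the new check passes

td = env.reset()        # TensorDict with a leading batch dimension of 1
td = env.rand_step(td)  # random step; the result stays batched
print(td["next", "reward"].shape)
```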
@@ -90,18 +95,21 @@ class Algo(ABC):

     def predict(
         self,
-        observation,
+        tensordict,
         state=None,
         deterministic=False
     ):
         with torch.no_grad():
-            obs_tensor = torch.as_tensor(observation, device=self.device).unsqueeze(0)
-            td = TensorDict({"observation": obs_tensor})
+            # If numpy array, convert to TensorDict
+            if isinstance(tensordict, np.ndarray):
+                tensordict = TensorDict(
+                    {"observation": torch.from_numpy(tensordict).float()},
+                    batch_size=[]
+                )

-            action_td = self.prob_actor(td)
-            action = action_td["action"]
+            # Move to device
+            tensordict = tensordict.to(self.device)

-            # We're not using recurrent policies, so we'll always return None for the state
-            next_state = None
-
-            return action.squeeze(0).cpu().numpy(), next_state
+            # Get action from policy
+            action_td = self.prob_actor(tensordict)
+            return action_td
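For reference, the input normalisation that predict() now performs can be exercised on its own. This is a hedged, standalone sketch; the helper name as_policy_input is made up for illustration and is not part of the diff:

```python
import numpy as np
import torch
from tensordict import TensorDict

def as_policy_input(obs, device="cpu"):
    # Mirror of the conversion in the new predict(): numpy arrays are wrapped
    # into a TensorDict with an empty batch_size, then moved to the device.
    if isinstance(obs, np.ndarray):
        obs = TensorDict({"observation": torch.from_numpy(obs).float()}, batch_size=[])
    return obs.to(device)

td = as_policy_input(np.zeros(3, dtype=np.float32))
print(td)  # TensorDict with a single "observation" entry
```

Callers should note that predict() now returns the full policy output TensorDict rather than a (numpy_action, state) tuple, which is why the rewritten tests below index action["action"].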
@@ -43,13 +43,13 @@ class PPO(OnPolicy):
         env = self.make_env()

         # Get spaces from specs for parallel env
-        obs_space = env.observation_spec
-        act_space = env.action_spec
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec

-        self.discrete = isinstance(act_space, DiscreteTensorSpec)
+        self.discrete = isinstance(self.act_space, DiscreteTensorSpec)

-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         if self.discrete:
             distribution_class = torch.distributions.Categorical
@@ -14,6 +14,8 @@ from fancy_rl.objectives import TRPLLoss
 from copy import deepcopy
 from tensordict.nn import TensorDictModule
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal
+

 class ProjectedActor(TensorDictModule):
     def __init__(self, raw_actor, old_actor, projection):
@@ -26,6 +28,8 @@ class ProjectedActor(TensorDictModule):
         self.raw_actor = raw_actor
         self.old_actor = old_actor
         self.projection = projection
+        self.discrete = raw_actor.discrete
+        self.full_covariance = raw_actor.full_covariance

     class CombinedModule(nn.Module):
         def __init__(self, raw_actor, old_actor, projection):
@@ -35,12 +39,41 @@ class ProjectedActor(TensorDictModule):
             self.projection = projection

         def forward(self, tensordict):
+            # Convert the tuple outputs to TensorDict
             raw_params = self.raw_actor(tensordict)
+            if isinstance(raw_params, tuple):
+                raw_params = TensorDict({
+                    "loc": raw_params[0],
+                    "scale": raw_params[1]
+                }, batch_size=[raw_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+
             old_params = self.old_actor(tensordict)
-            combined_params = TensorDict({**raw_params, **{f"old_{key}": value for key, value in old_params.items()}}, batch_size=tensordict.batch_size)
+            if isinstance(old_params, tuple):
+                old_params = TensorDict({
+                    "loc": old_params[0],
+                    "scale": old_params[1]
+                }, batch_size=[old_params[0].shape[0]])  # Use the first dimension of the tensor as batch size
+
+            # Now combine them
+            combined_params = TensorDict({
+                **raw_params,
+                **{f"old_{key}": value for key, value in old_params.items()}
+            }, batch_size=[raw_params["loc"].shape[0]])  # Use the first dimension of loc tensor as batch size
+
             projected_params = self.projection(combined_params)
             return projected_params

+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+
 class TRPL(OnPolicy):
     def __init__(
         self,
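The key-merging step in forward() can be illustrated on its own. A hedged, standalone sketch with made-up tensors, mirroring the "old_" prefixing used in the diff:

```python
import torch
from tensordict import TensorDict

batch = 4
raw_params = TensorDict({"loc": torch.zeros(batch, 2), "scale": torch.ones(batch, 2)}, batch_size=[batch])
old_params = TensorDict({"loc": torch.zeros(batch, 2), "scale": 2 * torch.ones(batch, 2)}, batch_size=[batch])

# Prefix the old policy's parameters with "old_" and merge, as in CombinedModule.forward.
combined = TensorDict(
    {**raw_params, **{f"old_{k}": v for k, v in old_params.items()}},
    batch_size=[raw_params["loc"].shape[0]],
)
print(sorted(combined.keys()))  # ['loc', 'old_loc', 'old_scale', 'scale']
```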
@@ -77,14 +110,16 @@ class TRPL(OnPolicy):
         # Initialize environment to get observation and action space sizes
         self.env_spec = env_spec
         env = self.make_env()
-        obs_space = env.observation_space
-        act_space = env.action_space

-        assert not isinstance(act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+        # Get spaces from specs for parallel env
+        self.obs_space = env.observation_spec
+        self.act_space = env.action_spec

-        self.critic = Critic(obs_space, critic_hidden_sizes, critic_activation_fn, device)
-        self.raw_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
-        self.old_actor = Actor(obs_space, act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        assert not isinstance(self.act_space, DiscreteTensorSpec), "TRPL does not support discrete action spaces"
+
+        self.critic = Critic(self.obs_space, critic_hidden_sizes, critic_activation_fn, device)
+        self.raw_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)
+        self.old_actor = Actor(self.obs_space, self.act_space, actor_hidden_sizes, actor_activation_fn, device, full_covariance=full_covariance)

         # Handle projection_class
         if isinstance(projection_class, str):
@@ -4,6 +4,7 @@ from torchrl.modules import MLP
 from torchrl.data.tensor_specs import DiscreteTensorSpec
 from tensordict.nn.distributions import NormalParamExtractor
 from tensordict import TensorDict
+from torch.distributions import Categorical, MultivariateNormal, Normal

 class Actor(TensorDictModule):
     def __init__(self, obs_space, act_space, hidden_sizes, activation_fn, device, full_covariance=False):
@@ -50,6 +51,17 @@ class Actor(TensorDictModule):
             out_keys=out_keys
         )

+    def get_dist(self, tensordict):
+        # Forward the observation through the network
+        out = self.forward(tensordict)
+        if self.discrete:
+            return Categorical(logits=out["logits"])
+        else:
+            if self.full_covariance:
+                return MultivariateNormal(loc=out["loc"], scale_tril=out["scale_tril"])
+            else:
+                return Normal(loc=out["loc"], scale=out["scale"])
+
 class FullCovarianceNormalParamExtractor(nn.Module):
     def __init__(self, action_dim):
         super().__init__()
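The three branches of the new get_dist map directly onto torch.distributions. A small standalone reminder with dummy parameter tensors (shapes are illustrative only):

```python
import torch
from torch.distributions import Categorical, MultivariateNormal, Normal

loc = torch.zeros(4, 2)
diag_scale = torch.ones(4, 2)
scale_tril = torch.eye(2).expand(4, 2, 2)   # lower-triangular factor for full covariance
logits = torch.zeros(4, 3)

dist_discrete = Categorical(logits=logits)                      # discrete actions
dist_full = MultivariateNormal(loc=loc, scale_tril=scale_tril)  # full_covariance=True
dist_diag = Normal(loc=loc, scale=diag_scale)                   # default diagonal case
print(dist_discrete.sample().shape, dist_full.sample().shape, dist_diag.sample().shape)
```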
@@ -65,8 +77,11 @@ class FullCovarianceNormalParamExtractor(nn.Module):

 class Critic(TensorDictModule):
     def __init__(self, obs_space, hidden_sizes, activation_fn, device):
+        obs_space = obs_space["observation"]
+        obs_space_shape = obs_space.shape[1:]
+
         critic_module = MLP(
-            in_features=obs_space.shape[-1],
+            in_features=obs_space_shape[0],
             out_features=1,
             num_cells=hidden_sizes,
             activation_class=getattr(nn, activation_fn),
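Context for the obs_space["observation"] / shape[1:] change: with environments now wrapped in SerialEnv(1, ...), the observation spec carries a leading batch dimension that has to be stripped before sizing the MLP input. A hedged check (assumes torchrl and gymnasium; "Pendulum-v1" is only an example id):

```python
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
obs_spec = env.observation_spec["observation"]
print(obs_spec.shape)       # leading dimension is the batch dim added by SerialEnv
print(obs_spec.shape[1:])   # what the Critic now passes as MLP(in_features=...)
```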
@@ -1,5 +1,9 @@
 import torch
-import cpp_projection
+try:
+    import cpp_projection
+    cpp_projection_available = True
+except ImportError:
+    cpp_projection_available = False
 import numpy as np
 from .base_projection import BaseProjection
 from tensordict.nn import TensorDictModule
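The new module-level flag makes the optional compiled dependency explicit. A hedged sketch of how a caller might guard on it; the function below is illustrative only and does not appear in the diff:

```python
try:
    import cpp_projection  # optional compiled backend
    cpp_projection_available = True
except ImportError:
    cpp_projection_available = False

def require_cpp_projection():
    # Hypothetical guard: fail early with a clear message if the backend is missing.
    if not cpp_projection_available:
        raise ImportError("cpp_projection is not installed; it is required for this projection")
```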
@@ -8,7 +8,7 @@
    description = "Minimalistic and efficient implementations of PPO and TRPL for torchrl"
    authors = [{name = "Dominik Roth", email = "mail@dominik-roth.eu"}]
    readme = "README.md"
-   requires-python = ">=3.7,<3.12"
+   requires-python = ">=3.7"
    classifiers = [
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
@@ -23,9 +23,10 @@
    dependencies = [
        "numpy",
        "torch",
-       "gymnasium",
+       "gymnasium<1.0",
        "tensordict",
        "torchrl",
+       "pytest",
    ]

    [project.urls]
@@ -33,3 +34,4 @@

    [project.optional-dependencies]
    dev = ["pytest"]
+   box2d = ["swig", "gymnasium[box2d]"]
@@ -3,9 +3,11 @@ import numpy as np
 from fancy_rl import PPO
 import gymnasium as gym
 from torchrl.envs import GymEnv
+import torch as th
+from tensordict import TensorDict

 def simple_env():
-    return GymEnv('LunarLander-v2', continuous=True)
+    return gym.make('LunarLander-v2')

 def test_ppo_instantiation():
     ppo = PPO(simple_env)
@@ -15,69 +17,38 @@ def test_ppo_instantiation_from_str():
     ppo = PPO('CartPole-v1')
     assert isinstance(ppo, PPO)

-def test_ppo_instantiation_from_make():
-    ppo = PPO(gym.make('CartPole-v1'))
-    assert isinstance(ppo, PPO)
-
-@pytest.mark.parametrize("learning_rate", [1e-4, 3e-4, 1e-3])
-@pytest.mark.parametrize("n_steps", [1024, 2048])
-@pytest.mark.parametrize("batch_size", [32, 64, 128])
-@pytest.mark.parametrize("n_epochs", [5, 10])
-@pytest.mark.parametrize("gamma", [0.95, 0.99])
-@pytest.mark.parametrize("clip_range", [0.1, 0.2, 0.3])
-def test_ppo_initialization_with_different_hps(learning_rate, n_steps, batch_size, n_epochs, gamma, clip_range):
-    ppo = PPO(
-        simple_env,
-        learning_rate=learning_rate,
-        n_steps=n_steps,
-        batch_size=batch_size,
-        n_epochs=n_epochs,
-        gamma=gamma,
-        clip_range=clip_range
-    )
-    assert ppo.learning_rate == learning_rate
-    assert ppo.n_steps == n_steps
-    assert ppo.batch_size == batch_size
-    assert ppo.n_epochs == n_epochs
-    assert ppo.gamma == gamma
-    assert ppo.clip_range == clip_range
-
 def test_ppo_predict():
     ppo = PPO(simple_env)
     env = ppo.make_env()
-    obs, _ = env.reset()
-    action, _ = ppo.predict(obs)
-    assert isinstance(action, np.ndarray)
-    assert action.shape == env.action_space.shape
+    obs = env.reset()
+    action = ppo.predict(obs)
+    assert isinstance(action, TensorDict)

-def test_ppo_learn():
-    ppo = PPO(simple_env, n_steps=64, batch_size=32)
-    env = ppo.make_env()
-    obs = env.reset()
-    for _ in range(64):
-        action, _next_state = ppo.predict(obs)
-        obs, reward, done, truncated, _ = env.step(action)
-        if done or truncated:
-            obs = env.reset()
+    # Handle both single and composite action spaces
+    if isinstance(env.action_space, list):
+        expected_shape = (len(env.action_space),) + env.action_space[0].shape
+    else:
+        expected_shape = env.action_space.shape
+
+    assert action["action"].shape == expected_shape

 def test_ppo_training():
-    ppo = PPO(simple_env, total_timesteps=10000)
+    ppo = PPO(simple_env, total_timesteps=100)
     env = ppo.make_env()

     initial_performance = evaluate_policy(ppo, env)
     ppo.train()
     final_performance = evaluate_policy(ppo, env)

-    assert final_performance > initial_performance, "PPO should improve performance after training"
-
-def evaluate_policy(policy, env, n_eval_episodes=10):
+def evaluate_policy(policy, env, n_eval_episodes=3):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs = env.reset()
+        tensordict = env.reset()
         done = False
         while not done:
-            action, _next_state = policy.predict(obs)
-            obs, reward, terminated, truncated, _ = env.step(action)
-            total_reward += reward
-            done = terminated or truncated
+            action = policy.predict(tensordict)
+            next_tensordict = env.step(action).get("next")
+            total_reward += next_tensordict["reward"]
+            done = next_tensordict["done"]
+            tensordict = next_tensordict
     return total_reward / n_eval_episodes
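The rewritten evaluate_policy relies on torchrl's convention that env.step() nests the step outcome under "next". A standalone sketch of that rollout pattern with a random-action stand-in for the policy (example env id only):

```python
from torchrl.envs import GymEnv, SerialEnv

env = SerialEnv(1, lambda: GymEnv("Pendulum-v1"))
td = env.reset()
td = env.rand_step(td)      # stand-in for policy.predict(td) followed by env.step(...)
nxt = td.get("next")        # step results live under "next"
print(nxt["reward"].shape, nxt["done"].shape)
print(sorted(nxt.keys()))   # includes "observation", "reward", "done", ...
```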
@@ -2,9 +2,10 @@ import pytest
 import numpy as np
 from fancy_rl import TRPL
 import gymnasium as gym
+from tensordict import TensorDict

 def simple_env():
-    return gym.make('LunarLander-v2', continuous=True)
+    return gym.make('Pendulum-v1')

 def test_trpl_instantiation():
     trpl = TRPL(simple_env)
@@ -34,52 +35,41 @@ def test_trpl_initialization_with_different_hps(learning_rate, n_steps, batch_si
     assert trpl.n_steps == n_steps
     assert trpl.batch_size == batch_size
     assert trpl.gamma == gamma
-    assert trpl.projection.trust_region_bound_mean == trust_region_bound_mean
-    assert trpl.projection.trust_region_bound_cov == trust_region_bound_cov
+    assert trpl.projection.mean_bound == trust_region_bound_mean
+    assert trpl.projection.cov_bound == trust_region_bound_cov

 def test_trpl_predict():
     trpl = TRPL(simple_env)
     env = trpl.make_env()
-    obs, _ = env.reset()
-    action, _ = trpl.predict(obs)
-    assert isinstance(action, np.ndarray)
-    assert action.shape == env.action_space.shape
+    obs = env.reset()
+    action = trpl.predict(obs)
+    assert isinstance(action, TensorDict)

-def test_trpl_learn():
-    trpl = TRPL(simple_env, n_steps=64, batch_size=32)
-    env = trpl.make_env()
-    obs, _ = env.reset()
-    for _ in range(64):
-        action, _ = trpl.predict(obs)
-        next_obs, reward, done, truncated, _ = env.step(action)
-        trpl.store_transition(obs, action, reward, done, next_obs)
-        obs = next_obs
-        if done or truncated:
-            obs, _ = env.reset()
+    # Handle both single and composite action spaces
+    if isinstance(env.action_space, list):
+        expected_shape = (len(env.action_space),) + env.action_space[0].shape
+    else:
+        expected_shape = env.action_space.shape

-    loss = trpl.learn()
-    assert isinstance(loss, dict)
-    assert "policy_loss" in loss
-    assert "value_loss" in loss
+    assert action["action"].shape == expected_shape

 def test_trpl_training():
-    trpl = TRPL(simple_env, total_timesteps=10000)
+    trpl = TRPL(simple_env, total_timesteps=100)
     env = trpl.make_env()

     initial_performance = evaluate_policy(trpl, env)
     trpl.train()
     final_performance = evaluate_policy(trpl, env)

-    assert final_performance > initial_performance, "TRPL should improve performance after training"
-
-def evaluate_policy(policy, env, n_eval_episodes=10):
+def evaluate_policy(policy, env, n_eval_episodes=3):
     total_reward = 0
     for _ in range(n_eval_episodes):
-        obs, _ = env.reset()
+        tensordict = env.reset()
         done = False
         while not done:
-            action, _ = policy.predict(obs)
-            obs, reward, terminated, truncated, _ = env.step(action)
-            total_reward += reward
-            done = terminated or truncated
+            action = policy.predict(tensordict)
+            next_tensordict = env.step(action).get("next")
+            total_reward += next_tensordict["reward"]
+            done = next_tensordict["done"]
+            tensordict = next_tensordict
     return total_reward / n_eval_episodes