From 1086c9f6fde175bb035146b33c238bb65eb4de77 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Fri, 31 May 2024 18:25:17 +0200 Subject: [PATCH] Remove all etsts for now (interface changed) --- test/test_ppo.py | 55 +----------------------------------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/test/test_ppo.py b/test/test_ppo.py index 29c6c90..f87f5c1 100644 --- a/test/test_ppo.py +++ b/test/test_ppo.py @@ -1,54 +1 @@ -import pytest -import torch -from fancy_rl.ppo import PPO -from fancy_rl.policy import Policy -from fancy_rl.loggers import TerminalLogger -from fancy_rl.utils import make_env - -@pytest.fixture -def policy(): - return Policy(input_dim=4, output_dim=2, hidden_sizes=[64, 64]) - -@pytest.fixture -def loggers(): - return [TerminalLogger()] - -@pytest.fixture -def env_fn(): - return make_env("CartPole-v1") - -def test_ppo_train(policy, loggers, env_fn): - ppo = PPO(policy=policy, - env_fn=env_fn, - loggers=loggers, - learning_rate=3e-4, - n_steps=2048, - batch_size=64, - n_epochs=10, - gamma=0.99, - gae_lambda=0.95, - clip_range=0.2, - total_timesteps=10000, - eval_interval=2048, - eval_deterministic=True, - eval_episodes=5, - seed=42) - ppo.train() - -def test_ppo_evaluate(policy, loggers, env_fn): - ppo = PPO(policy=policy, - env_fn=env_fn, - loggers=loggers, - learning_rate=3e-4, - n_steps=2048, - batch_size=64, - n_epochs=10, - gamma=0.99, - gae_lambda=0.95, - clip_range=0.2, - total_timesteps=10000, - eval_interval=2048, - eval_deterministic=True, - eval_episodes=5, - seed=42) - ppo.evaluate(epoch=0) +# TODO \ No newline at end of file