From 54bab221ef282e9d04fa68e1a7ab3482f67d9f95 Mon Sep 17 00:00:00 2001 From: Dominik Roth Date: Wed, 28 Aug 2024 11:55:43 +0200 Subject: [PATCH] Disable vlearn for now... --- fancy_rl/__init__.py | 11 ++--------- fancy_rl/algos/__init__.py | 4 +++- fancy_rl/algos/vlearn.py | 7 +++---- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/fancy_rl/__init__.py b/fancy_rl/__init__.py index 011f43e..fde13b4 100644 --- a/fancy_rl/__init__.py +++ b/fancy_rl/__init__.py @@ -1,10 +1,3 @@ -import gymnasium -try: - import fancy_gym -except ImportError: - pass +from fancy_rl.algos import PPO, TRPL #, VLEARN -from fancy_rl.algos import PPO, TRPL, VLEARN -from fancy_rl.projections import get_projection - -__all__ = ["PPO", "TRPL", "VLEARN", "get_projection"] \ No newline at end of file +__all__ = ['PPO', 'TRPL'] \ No newline at end of file diff --git a/fancy_rl/algos/__init__.py b/fancy_rl/algos/__init__.py index 040bf9c..7ca1c62 100644 --- a/fancy_rl/algos/__init__.py +++ b/fancy_rl/algos/__init__.py @@ -1,3 +1,5 @@ from fancy_rl.algos.ppo import PPO from fancy_rl.algos.trpl import TRPL -from fancy_rl.algos.vlearn import VLEARN \ No newline at end of file +#from fancy_rl.algos.vlearn import VLEARN + +__all__ = ['PPO', 'TRPL'] \ No newline at end of file diff --git a/fancy_rl/algos/vlearn.py b/fancy_rl/algos/vlearn.py index 384ec0e..acc8d43 100644 --- a/fancy_rl/algos/vlearn.py +++ b/fancy_rl/algos/vlearn.py @@ -5,10 +5,9 @@ from torchrl.modules import ProbabilisticActor, ValueOperator from torchrl.collectors import SyncDataCollector from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage -from fancy_rl.utils import get_env, get_actor, get_critic -from fancy_rl.modules.vlearn_loss import VLEARNLoss -from fancy_rl.modules.projection import get_vlearn_projection -from fancy_rl.modules.squashed_normal import get_squashed_normal +from fancy_rl.objectives.vlearn import VLEARNLoss +from fancy_rl.projections import get_vlearn_projection +from fancy_rl.utils import get_squashed_normal class VLEARN: def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):