Compare commits

..

3 Commits

Author SHA1 Message Date
54bab221ef Disable vlearn for now... 2024-08-28 11:55:43 +02:00
1a02568f3c Rename frob_projection file 2024-08-28 11:55:30 +02:00
0464fbabe8 Disable vlearn test for now 2024-08-28 11:55:04 +02:00
5 changed files with 8 additions and 14 deletions

View File

@ -1,10 +1,3 @@
import gymnasium
try:
import fancy_gym
except ImportError:
pass
from fancy_rl.algos import PPO, TRPL #, VLEARN
from fancy_rl.algos import PPO, TRPL, VLEARN
from fancy_rl.projections import get_projection
__all__ = ["PPO", "TRPL", "VLEARN", "get_projection"]
__all__ = ['PPO', 'TRPL']

View File

@ -1,3 +1,5 @@
from fancy_rl.algos.ppo import PPO
from fancy_rl.algos.trpl import TRPL
from fancy_rl.algos.vlearn import VLEARN
#from fancy_rl.algos.vlearn import VLEARN
__all__ = ['PPO', 'TRPL']

View File

@ -5,10 +5,9 @@ from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from fancy_rl.utils import get_env, get_actor, get_critic
from fancy_rl.modules.vlearn_loss import VLEARNLoss
from fancy_rl.modules.projection import get_vlearn_projection
from fancy_rl.modules.squashed_normal import get_squashed_normal
from fancy_rl.objectives.vlearn import VLEARNLoss
from fancy_rl.projections import get_vlearn_projection
from fancy_rl.utils import get_squashed_normal
class VLEARN:
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):