Compare commits

..

No commits in common. "54bab221ef282e9d04fa68e1a7ab3482f67d9f95" and "f7d171399f9e8ac56a372962ccf481b3db30c1e7" have entirely different histories.

5 changed files with 14 additions and 8 deletions

View File

@ -1,3 +1,10 @@
from fancy_rl.algos import PPO, TRPL #, VLEARN
import gymnasium
try:
import fancy_gym
except ImportError:
pass
__all__ = ['PPO', 'TRPL']
from fancy_rl.algos import PPO, TRPL, VLEARN
from fancy_rl.projections import get_projection
__all__ = ["PPO", "TRPL", "VLEARN", "get_projection"]

View File

@ -1,5 +1,3 @@
from fancy_rl.algos.ppo import PPO
from fancy_rl.algos.trpl import TRPL
#from fancy_rl.algos.vlearn import VLEARN
__all__ = ['PPO', 'TRPL']
from fancy_rl.algos.vlearn import VLEARN

View File

@ -5,9 +5,10 @@ from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
from fancy_rl.objectives.vlearn import VLEARNLoss
from fancy_rl.projections import get_vlearn_projection
from fancy_rl.utils import get_squashed_normal
from fancy_rl.utils import get_env, get_actor, get_critic
from fancy_rl.modules.vlearn_loss import VLEARNLoss
from fancy_rl.modules.projection import get_vlearn_projection
from fancy_rl.modules.squashed_normal import get_squashed_normal
class VLEARN:
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):