Compare commits
3 Commits
f7d171399f
...
54bab221ef
Author | SHA1 | Date | |
---|---|---|---|
54bab221ef | |||
1a02568f3c | |||
0464fbabe8 |
@ -1,10 +1,3 @@
|
|||||||
import gymnasium
|
from fancy_rl.algos import PPO, TRPL #, VLEARN
|
||||||
try:
|
|
||||||
import fancy_gym
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
from fancy_rl.algos import PPO, TRPL, VLEARN
|
__all__ = ['PPO', 'TRPL']
|
||||||
from fancy_rl.projections import get_projection
|
|
||||||
|
|
||||||
__all__ = ["PPO", "TRPL", "VLEARN", "get_projection"]
|
|
@ -1,3 +1,5 @@
|
|||||||
from fancy_rl.algos.ppo import PPO
|
from fancy_rl.algos.ppo import PPO
|
||||||
from fancy_rl.algos.trpl import TRPL
|
from fancy_rl.algos.trpl import TRPL
|
||||||
from fancy_rl.algos.vlearn import VLEARN
|
#from fancy_rl.algos.vlearn import VLEARN
|
||||||
|
|
||||||
|
__all__ = ['PPO', 'TRPL']
|
@ -5,10 +5,9 @@ from torchrl.modules import ProbabilisticActor, ValueOperator
|
|||||||
from torchrl.collectors import SyncDataCollector
|
from torchrl.collectors import SyncDataCollector
|
||||||
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
|
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage
|
||||||
|
|
||||||
from fancy_rl.utils import get_env, get_actor, get_critic
|
from fancy_rl.objectives.vlearn import VLEARNLoss
|
||||||
from fancy_rl.modules.vlearn_loss import VLEARNLoss
|
from fancy_rl.projections import get_vlearn_projection
|
||||||
from fancy_rl.modules.projection import get_vlearn_projection
|
from fancy_rl.utils import get_squashed_normal
|
||||||
from fancy_rl.modules.squashed_normal import get_squashed_normal
|
|
||||||
|
|
||||||
class VLEARN:
|
class VLEARN:
|
||||||
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):
|
def __init__(self, env_id: str, device: str = "cpu", **kwargs: Any):
|
||||||
|
Loading…
Reference in New Issue
Block a user