diff --git a/fancy_rl/__init__.py b/fancy_rl/__init__.py
index c0e7081..323df6b 100644
--- a/fancy_rl/__init__.py
+++ b/fancy_rl/__init__.py
@@ -4,6 +4,6 @@ try:
 except ImportError:
     pass
 
-from fancy_rl.ppo import PPO
+from fancy_rl.algos import PPO
 
 __all__ = ["PPO"]
\ No newline at end of file
diff --git a/fancy_rl/algos/__init__.py b/fancy_rl/algos/__init__.py
new file mode 100644
index 0000000..2ae3d06
--- /dev/null
+++ b/fancy_rl/algos/__init__.py
@@ -0,0 +1 @@
+from fancy_rl.algos.ppo import PPO
\ No newline at end of file
diff --git a/fancy_rl/on_policy.py b/fancy_rl/algos/on_policy.py
similarity index 100%
rename from fancy_rl/on_policy.py
rename to fancy_rl/algos/on_policy.py
diff --git a/fancy_rl/ppo.py b/fancy_rl/algos/ppo.py
similarity index 98%
rename from fancy_rl/ppo.py
rename to fancy_rl/algos/ppo.py
index 9aeeed1..6481835 100644
--- a/fancy_rl/ppo.py
+++ b/fancy_rl/algos/ppo.py
@@ -2,7 +2,7 @@ import torch
 from torchrl.modules import ActorValueOperator, ProbabilisticActor
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value.advantages import GAE
-from fancy_rl.on_policy import OnPolicy
+from fancy_rl.algos.on_policy import OnPolicy
 from fancy_rl.policy import Actor, Critic, SharedModule
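
A minimal sketch of how the import surface changes after this refactor, derived from the hunks above (assuming no other module re-exports PPO):

    # Public import path is unchanged: fancy_rl/__init__.py still exports PPO,
    # now sourced from the algos subpackage instead of fancy_rl.ppo.
    from fancy_rl import PPO

    # Internal paths moved; code importing the old module location must update:
    from fancy_rl.algos import PPO        # re-exported by algos/__init__.py
    from fancy_rl.algos.ppo import PPO    # fully qualified module path
    # from fancy_rl.ppo import PPO        # old path, broken by this change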