diff --git a/metastable_baselines2/trpl/trpl.py b/metastable_baselines2/trpl/trpl.py index 7f91a87..ea16a50 100644 --- a/metastable_baselines2/trpl/trpl.py +++ b/metastable_baselines2/trpl/trpl.py @@ -13,6 +13,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedul from stable_baselines3.common.utils import explained_variance, get_schedule_fn from metastable_projections import BaseProjectionLayer, FrobeniusProjectionLayer, WassersteinProjectionLayer, KLProjectionLayer +import metastable_projections SelfTRPL = TypeVar("SelfTRPL", bound="TRPL")