diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py index 7194093..2af3175 100644 --- a/metastable_baselines2/common/on_policy_algorithm.py +++ b/metastable_baselines2/common/on_policy_algorithm.py @@ -92,10 +92,6 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): assert not rollout_buffer_class and not rollout_buffer_kwargs - if 'dist_kwargs' not in policy_kwargs: - policy_kwargs['dist_kwargs'] = {} - policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs) - super().__init__( policy=policy, env=env, @@ -122,6 +118,10 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): _init_setup_model=_init_setup_model ) + if 'dist_kwargs' not in self.policy_kwargs: + self.policy_kwargs['dist_kwargs'] = {} + self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs) + self.rollout_buffer_class = None self.rollout_buffer_kwargs = {}