diff --git a/metastable_baselines2/common/on_policy_algorithm.py b/metastable_baselines2/common/on_policy_algorithm.py index df3d239..7194093 100644 --- a/metastable_baselines2/common/on_policy_algorithm.py +++ b/metastable_baselines2/common/on_policy_algorithm.py @@ -92,6 +92,8 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm): assert not rollout_buffer_class and not rollout_buffer_kwargs + if 'dist_kwargs' not in policy_kwargs: + policy_kwargs['dist_kwargs'] = {} policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs) super().__init__(