Fix problem when env is provided just as id
This commit is contained in:
parent
1c3d3cf6cf
commit
39c21ab6b9
@ -92,10 +92,6 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
|||||||
|
|
||||||
assert not rollout_buffer_class and not rollout_buffer_kwargs
|
assert not rollout_buffer_class and not rollout_buffer_kwargs
|
||||||
|
|
||||||
if 'dist_kwargs' not in policy_kwargs:
|
|
||||||
policy_kwargs['dist_kwargs'] = {}
|
|
||||||
policy_kwargs['dist_kwargs']['n_envs'] = len(env.envs)
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
policy=policy,
|
policy=policy,
|
||||||
env=env,
|
env=env,
|
||||||
@ -122,6 +118,10 @@ class BetterOnPolicyAlgorithm(OnPolicyAlgorithm):
|
|||||||
_init_setup_model=_init_setup_model
|
_init_setup_model=_init_setup_model
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if 'dist_kwargs' not in self.policy_kwargs:
|
||||||
|
self.policy_kwargs['dist_kwargs'] = {}
|
||||||
|
self.policy_kwargs['dist_kwargs']['n_envs'] = len(self.env.envs)
|
||||||
|
|
||||||
self.rollout_buffer_class = None
|
self.rollout_buffer_class = None
|
||||||
self.rollout_buffer_kwargs = {}
|
self.rollout_buffer_kwargs = {}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user