Upgrading to SB3 1.7 (probably broke some stuff...)
This commit is contained in:

parent c47d5741ca
commit f3e03916c8
@@ -107,7 +107,6 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         sde_sample_freq: int = -1,
         target_kl: Optional[float] = None,
         tensorboard_log: Optional[str] = None,
-        create_eval_env: bool = False,
         policy_kwargs: Optional[Dict[str, Any]] = {},
         verbose: int = 0,
         seed: Optional[int] = None,
@@ -141,7 +140,6 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
                 projection, WassersteinProjectionLayer)},
             verbose=verbose,
             device=device,
-            create_eval_env=create_eval_env,
             seed=seed,
             _init_setup_model=False,
             supported_action_spaces=(
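Note: SB3 1.7 drops the deprecated create_eval_env constructor argument, which is why it disappears from both the PPO signature above and the super().__init__() call. An evaluation environment now has to be built by the caller instead. A minimal sketch, assuming a stock gym environment such as Pendulum-v1 rather than this repo's own env configs:

import gym
from stable_baselines3.common.monitor import Monitor

# Build the evaluation env yourself instead of passing create_eval_env=True.
# Monitor wrapping keeps episode returns/lengths available for evaluation logging.
eval_env = Monitor(gym.make("Pendulum-v1"))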
@@ -428,11 +426,7 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
         total_timesteps: int,
         callback: MaybeCallback = None,
         log_interval: int = 1,
-        eval_env: Optional[GymEnv] = None,
-        eval_freq: int = -1,
-        n_eval_episodes: int = 5,
         tb_log_name: str = "PPO",
-        eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
     ) -> "PPO":
 
@@ -440,10 +434,6 @@ class PPO(GaussianRolloutCollectorAuxclass, OnPolicyAlgorithm):
             total_timesteps=total_timesteps,
             callback=callback,
             log_interval=log_interval,
-            eval_env=eval_env,
-            eval_freq=eval_freq,
-            n_eval_episodes=n_eval_episodes,
             tb_log_name=tb_log_name,
-            eval_log_path=eval_log_path,
             reset_num_timesteps=reset_num_timesteps,
         )
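With the eval_env / eval_freq / n_eval_episodes / eval_log_path arguments gone from learn() in SB3 1.7, periodic evaluation is wired up through a callback instead. A minimal sketch of the replacement pattern, assuming the eval_env from the note above and a model built with this repo's PPO class; the paths and step counts are placeholders:

from stable_baselines3.common.callbacks import EvalCallback

# EvalCallback takes over what the removed learn() arguments used to do.
eval_callback = EvalCallback(
    eval_env,
    eval_freq=10_000,        # evaluate every 10k environment steps
    n_eval_episodes=5,       # average over 5 episodes per evaluation
    log_path="./eval_logs",  # hypothetical output directory
    deterministic=True,
)
model.learn(total_timesteps=1_000_000, callback=eval_callback)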
@@ -121,7 +121,6 @@ class SAC(OffPolicyAlgorithm):
         sde_sample_freq: int = -1,
         use_sde_at_warmup: bool = False,
         tensorboard_log: Optional[str] = None,
-        create_eval_env: bool = False,
         action_coef: float = 0.0,
         policy_kwargs: Optional[Dict[str, Any]] = None,
         verbose: int = 0,
@@ -153,7 +152,6 @@ class SAC(OffPolicyAlgorithm):
             tensorboard_log=tensorboard_log,
             verbose=verbose,
             device=device,
-            create_eval_env=create_eval_env,
             seed=seed,
             use_sde=use_sde,
             sde_sample_freq=sde_sample_freq,