More objectives and fixes
This commit is contained in:
		
							parent
							
								
									33b5db2f57
								
							
						
					
					
						commit
						4c3ad983fc
					
				
							
								
								
									
										70
									
								
								nucon/rl.py
									
									
									
									
									
								
							
							
						
						
									
										70
									
								
								nucon/rl.py
									
									
									
									
									
								
							@ -3,7 +3,7 @@ from gymnasium import spaces
 | 
			
		||||
import numpy as np
 | 
			
		||||
import time
 | 
			
		||||
from typing import Dict, Any
 | 
			
		||||
from .core import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
 | 
			
		||||
from nucon import Nucon, BreakerStatus, PumpStatus, PumpDryStatus, PumpOverloadStatus
 | 
			
		||||
 | 
			
		||||
Objectives = {
 | 
			
		||||
    "null": lambda obs: 0,
 | 
			
		||||
@ -13,12 +13,16 @@ Objectives = {
 | 
			
		||||
 | 
			
		||||
Parameterized_Objectives = {
 | 
			
		||||
    "target_temperature": lambda goal_temp: lambda obs: -((obs["CORE_TEMP"] - goal_temp) ** 2),
 | 
			
		||||
    "target_gap": lambda goal_gap: lambda obs: -((obs["CORE_TEMP"] - obs["CORE_TEMP_MIN"] - goal_gap) ** 2),
 | 
			
		||||
    "temp_below": lambda max_temp: lambda obs: -(np.clip(obs["CORE_TEMP"] - max_temp, 0, np.inf) ** 2),
 | 
			
		||||
    "temp_above": lambda min_temp: lambda obs: -(np.clip(min_temp - obs["CORE_TEMP"], 0, np.inf) ** 2),
 | 
			
		||||
    "constant": lambda constant: lambda obs: constant,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class NuconEnv(gym.Env):
 | 
			
		||||
    metadata = {'render_modes': ['human']}
 | 
			
		||||
 | 
			
		||||
    def __init__(self, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
 | 
			
		||||
    def __init__(self, nucon=None, simulator=None, render_mode=None, seconds_per_step=5, objectives=['null'], terminators=['null'], objective_weights=None, terminate_above=0):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.render_mode = render_mode
 | 
			
		||||
@ -26,22 +30,30 @@ class NuconEnv(gym.Env):
 | 
			
		||||
        if objective_weights is None:
 | 
			
		||||
            objective_weights = [1.0 for objective in objectives]
 | 
			
		||||
        self.objective_weights = objective_weights
 | 
			
		||||
        self.terminate_at = terminate_at
 | 
			
		||||
        self.terminate_above = terminate_above
 | 
			
		||||
        self.simulator = simulator
 | 
			
		||||
 | 
			
		||||
        if nucon is None:
 | 
			
		||||
            if simulator:
 | 
			
		||||
                nucon = Nucon(port=simulator.port)
 | 
			
		||||
            else:
 | 
			
		||||
                nucon = Nucon()
 | 
			
		||||
        self.nucon = nucon
 | 
			
		||||
 | 
			
		||||
        # Define observation space
 | 
			
		||||
        obs_spaces = {'EPISODE_TIME': spaces.Box(low=0, high=np.inf, shape=(1,), dtype=np.float32)}
 | 
			
		||||
        for param in Nucon.get_all_readable():
 | 
			
		||||
        for param_id, param in self.nucon.get_all_readable().items():
 | 
			
		||||
            if param.param_type == float:
 | 
			
		||||
                obs_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
                obs_spaces[param_id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
            elif param.param_type == int:
 | 
			
		||||
                if param.min_val is not None and param.max_val is not None:
 | 
			
		||||
                    obs_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
 | 
			
		||||
                    obs_spaces[param_id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
 | 
			
		||||
                else:
 | 
			
		||||
                    obs_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
                    obs_spaces[param_id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
            elif param.param_type == bool:
 | 
			
		||||
                obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
 | 
			
		||||
                obs_spaces[param_id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
 | 
			
		||||
            elif issubclass(param.param_type, Enum):
 | 
			
		||||
                obs_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
 | 
			
		||||
                obs_spaces[param_id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"Unsupported observation parameter type: {param.param_type}")
 | 
			
		||||
 | 
			
		||||
@ -49,23 +61,26 @@ class NuconEnv(gym.Env):
 | 
			
		||||
 | 
			
		||||
        # Define action space
 | 
			
		||||
        action_spaces = {}
 | 
			
		||||
        for param in Nucon.get_all_writable():
 | 
			
		||||
        for param_id, param in self.nucon.get_all_writable().items():
 | 
			
		||||
            if param.param_type == float:
 | 
			
		||||
                action_spaces[param.id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
                action_spaces[param_id] = spaces.Box(low=param.min_val or -np.inf, high=param.max_val or np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
            elif param.param_type == int:
 | 
			
		||||
                if param.min_val is not None and param.max_val is not None:
 | 
			
		||||
                    action_spaces[param.id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
 | 
			
		||||
                    action_spaces[param_id] = spaces.Box(low=param.min_val, high=param.max_val, shape=(1,), dtype=np.float32)
 | 
			
		||||
                else:
 | 
			
		||||
                    action_spaces[param.id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
                    action_spaces[param_id] = spaces.Box(low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32)
 | 
			
		||||
            elif param.param_type == bool:
 | 
			
		||||
                action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
 | 
			
		||||
                action_spaces[param_id] = spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32)
 | 
			
		||||
            elif issubclass(param.param_type, Enum):
 | 
			
		||||
                action_spaces[param.id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
 | 
			
		||||
                action_spaces[param_id] = spaces.Box(low=0, high=1, shape=(len(param.param_type),), dtype=np.float32)
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError(f"Unsupported action parameter type: {param.param_type}")
 | 
			
		||||
 | 
			
		||||
        self.action_space = spaces.Dict(action_spaces)
 | 
			
		||||
 | 
			
		||||
        self.objectives = []
 | 
			
		||||
        self.terminators = []
 | 
			
		||||
 | 
			
		||||
        for objective in objectives:
 | 
			
		||||
            if objective in Objectives:
 | 
			
		||||
                self.objectives.append(Objectives[objective])
 | 
			
		||||
@ -84,11 +99,11 @@ class NuconEnv(gym.Env):
 | 
			
		||||
 | 
			
		||||
    def _get_obs(self):
 | 
			
		||||
        obs = {}
 | 
			
		||||
        for param in Nucon.get_all_readable():
 | 
			
		||||
            value = Nucon.get(param)
 | 
			
		||||
        for param_id, param in self.nucon.get_all_readable().items():
 | 
			
		||||
            value = self.nucon.get(param_id)
 | 
			
		||||
            if isinstance(value, Enum):
 | 
			
		||||
                value = value.value
 | 
			
		||||
            obs[param.id] = value
 | 
			
		||||
            obs[param_id] = value
 | 
			
		||||
        obs["EPISODE_TIME"] = self._total_steps * self.seconds_per_step
 | 
			
		||||
        return obs
 | 
			
		||||
 | 
			
		||||
@ -112,12 +127,12 @@ class NuconEnv(gym.Env):
 | 
			
		||||
    def step(self, action):
 | 
			
		||||
        # Apply the action to the Nucon system
 | 
			
		||||
        for param_id, value in action.items():
 | 
			
		||||
            param = next(p for p in Nucon if p.id == param_id)
 | 
			
		||||
            param = next(p for p in self.nucon if p.id == param_id)
 | 
			
		||||
            if issubclass(param.param_type, Enum):
 | 
			
		||||
                value = param.param_type(value)
 | 
			
		||||
            if param.min_val is not None and param.max_val is not None:
 | 
			
		||||
                value = np.clip(value, param.min_val, param.max_val)
 | 
			
		||||
            Nucon.set(param, value)
 | 
			
		||||
            self.nucon.set(param, value)
 | 
			
		||||
 | 
			
		||||
        observation = self._get_obs()
 | 
			
		||||
        terminated = np.sum([terminator(observation) for terminator in self.terminators]) > self.terminate_above
 | 
			
		||||
@ -126,7 +141,10 @@ class NuconEnv(gym.Env):
 | 
			
		||||
        reward = sum(obj for obj in info['objectives_weighted'].values())
 | 
			
		||||
 | 
			
		||||
        self._total_steps += 1
 | 
			
		||||
        time.sleep(self.seconds_per_step)
 | 
			
		||||
        if self.simulator:
 | 
			
		||||
            self.simulator.update(self.seconds_per_step)
 | 
			
		||||
        else:
 | 
			
		||||
            time.sleep(self.seconds_per_step)
 | 
			
		||||
        return observation, reward, terminated, truncated, info
 | 
			
		||||
 | 
			
		||||
    def render(self):
 | 
			
		||||
@ -148,6 +166,7 @@ class NuconEnv(gym.Env):
 | 
			
		||||
    def _unflatten_observation(self, flat_observation):
 | 
			
		||||
        return {k: v.reshape(1, -1) for k, v in self.observation_space.items()}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_nucon_envs():
 | 
			
		||||
    gym.register(
 | 
			
		||||
        id='Nucon-max_power-v0',
 | 
			
		||||
@ -155,9 +174,14 @@ def register_nucon_envs():
 | 
			
		||||
        kwargs={'seconds_per_step': 5, 'objectives': ['max_power']}
 | 
			
		||||
    )
 | 
			
		||||
    gym.register(
 | 
			
		||||
        id='Nucon-target_temperature_600-v0',
 | 
			
		||||
        id='Nucon-target_temperature_350-v0',
 | 
			
		||||
        entry_point='nucon.rl:NuconEnv',
 | 
			
		||||
        kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=600)]}
 | 
			
		||||
        kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['target_temperature'](goal_temp=350)]}
 | 
			
		||||
    )
 | 
			
		||||
    gym.register(
 | 
			
		||||
        id='Nucon-safe_max_power-v0',
 | 
			
		||||
        entry_point='nucon.rl:NuconEnv',
 | 
			
		||||
        kwargs={'seconds_per_step': 5, 'objectives': [Parameterized_Objectives['temp_above'](min_temp=310), Parameterized_Objectives['temp_below'](max_temp=365), 'max_power'], 'objective_weights': [1, 10, 1/100_000]}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
register_nucon_envs()
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user