Compare commits
No commits in common. "7fcc8098521916b05cc867ddd282eadf75690044" and "2ec68ff2e552497abece7c02fbe6f4dd302d228f" have entirely different histories.
7fcc809852
...
2ec68ff2e5
91
README.md
91
README.md
@ -72,10 +72,7 @@ Parameter properties:
|
|||||||
- `nucon.<PARAMETER>.value`: Get or set the current value of the parameter. Assigning a new value will write it to the game.
|
- `nucon.<PARAMETER>.value`: Get or set the current value of the parameter. Assigning a new value will write it to the game.
|
||||||
- `nucon.<PARAMETER>.param_type`: Get the type of the parameter
|
- `nucon.<PARAMETER>.param_type`: Get the type of the parameter
|
||||||
- `nucon.<PARAMETER>.is_writable`: Check if the parameter is writable
|
- `nucon.<PARAMETER>.is_writable`: Check if the parameter is writable
|
||||||
- `nucon.<PARAMETER>.is_readable`: `False` for write-only parameters (e.g. VALVE_OPEN, CORE_SCRAM_BUTTON). Reading raises `AttributeError`.
|
|
||||||
- `nucon.<PARAMETER>.is_cheat`: `True` for game-event triggers (all `FUN_*`). Writing raises `ValueError` unless `cheat_mode=True`.
|
|
||||||
- `nucon.<PARAMETER>.enum_type`: Get the enum type of the parameter if it's an enum, otherwise None
|
- `nucon.<PARAMETER>.enum_type`: Get the enum type of the parameter if it's an enum, otherwise None
|
||||||
- `nucon.<PARAMETER>.unit`: Unit string if defined (e.g. `'°C'`, `'bar'`, `'%'`)
|
|
||||||
|
|
||||||
Parameter methods:
|
Parameter methods:
|
||||||
- `nucon.<PARAMETER>.read()`: Get the current value of the parameter (alias for `value`)
|
- `nucon.<PARAMETER>.read()`: Get the current value of the parameter (alias for `value`)
|
||||||
@ -83,23 +80,14 @@ Parameter methods:
|
|||||||
|
|
||||||
Class methods:
|
Class methods:
|
||||||
- `nucon.get(parameter)`: Get the value of a specific parameter. Also accepts string parameter names.
|
- `nucon.get(parameter)`: Get the value of a specific parameter. Also accepts string parameter names.
|
||||||
- `nucon.set(parameter, value, force=False)`: Set the value of a specific parameter. Also accepts string parameter names. `force` bypasses writable/range/cheat checks.
|
- `nucon.set(parameter, value, force=False)`: Set the value of a specific parameter. Also accepts string parameter names. `force` will try to write even if the parameter is known as non-writable or out of known allowed range.
|
||||||
- `nucon.get_all_readable()`: Get a dict of all readable parameters.
|
- `nucon.get_all_readable()`: Get a list of all readable parameters (which is all parameters)
|
||||||
- `nucon.get_all_writable()`: Get a dict of all writable parameters (includes write-only params).
|
- `nucon.get_all_writable()`: Get a list of all writable parameters
|
||||||
- `nucon.get_all()`: Get all readable parameter values as a dictionary.
|
- `nucon.get_all()`: Get all parameter values as a dictionary
|
||||||
- `nucon.get_all_iter()`: Get all readable parameter values as a generator.
|
- `nucon.get_all_iter()`: Get all parameter values as a generator
|
||||||
- `nucon.get_multiple(params)`: Get values for multiple specified parameters.
|
- `nucon.get_multiple(params)`: Get values for multiple specified parameters
|
||||||
- `nucon.get_multiple_iter(params)`: Get values for multiple specified parameters as a generator.
|
- `nucon.get_multiple_iter(params)`: Get values for multiple specified parameters as a generator
|
||||||
- `nucon.get_game_variable_names()`: Query the game for all exposed variable names (GET and POST), excluding special endpoints.
|
- `nucon.set_dummy_mode(dummy_mode)`: Enable or disable dummy mode for testing. In dummy mode we won't connect to the game and just return sensible values for all params and allow but ignore all writes to writable parameters.
|
||||||
- `nucon.set_dummy_mode(dummy_mode)`: In dummy mode, returns sensible values without connecting to the game and silently ignores writes.
|
|
||||||
- `nucon.set_cheat_mode(cheat_mode)`: Enable writing to cheat parameters (`FUN_*` event triggers). Default `False`.
|
|
||||||
|
|
||||||
Valve API (motorized actuators: OPEN/CLOSE powers the motor, OFF holds current position):
|
|
||||||
- `nucon.get_valve(name)`: Get state dict for a single valve (`Value`, `IsOpened`, `IsClosed`, `Stuck`, …).
|
|
||||||
- `nucon.get_valves()`: Get state dict for all 53 valves.
|
|
||||||
- `nucon.open_valve(name)` / `nucon.open_valves(names)`: Power actuator toward open.
|
|
||||||
- `nucon.close_valve(name)` / `nucon.close_valves(names)`: Power actuator toward closed.
|
|
||||||
- `nucon.off_valve(name)` / `nucon.off_valves(names)`: Cut actuator power, hold current position (normal resting state).
|
|
||||||
|
|
||||||
Custom Enum Types:
|
Custom Enum Types:
|
||||||
- `PumpStatus`: Enum for pump status (INACTIVE, ACTIVE_NO_SPEED_REACHED\*, ACTIVE_SPEED_REACHED\*, REQUIRES_MAINTENANCE, NOT_INSTALLED, INSUFFICIENT_ENERGY)
|
- `PumpStatus`: Enum for pump status (INACTIVE, ACTIVE_NO_SPEED_REACHED\*, ACTIVE_SPEED_REACHED\*, REQUIRES_MAINTENANCE, NOT_INSTALLED, INSUFFICIENT_ENERGY)
|
||||||
@ -246,62 +234,41 @@ But theres yet another problem: We do not know the exact simulation dynamics of
|
|||||||
|
|
||||||
## Model Learning (Work in Progress)
|
## Model Learning (Work in Progress)
|
||||||
|
|
||||||
To address the challenge of unknown game dynamics, NuCon provides tools for collecting data, creating datasets, and training models to learn the reactor dynamics. Key features include:
|
To address the challenge of unknown game dynamics, NuCon provides tools for collecting data, creating datasets, and training models to learn the reactor dynamics. This approach allows for more accurate simulations and enables model-based control strategies. Key features include:
|
||||||
|
|
||||||
- **Data Collection**: Gathers state transitions from human play or automated agents. `time_delta` is specified in game-time seconds; wall-clock sleep is automatically adjusted for `GAME_SIM_SPEED` so collected deltas are uniform regardless of simulation speed.
|
- Data Collection: Supports gathering state transitions from both human play and automated agents.
|
||||||
- **Automatic param filtering**: Junk params (GAME_VERSION, TIME, ALARMS_ACTIVE, …) and params from uninstalled subsystems (returns `None`) are automatically excluded from model inputs/outputs.
|
- Dataset Management: Tools for saving, loading, and merging datasets.
|
||||||
- **Two model backends**: Neural network (NN) or k-Nearest Neighbours with GP interpolation (kNN).
|
- Model Training: Train neural network models to predict next states based on current states and time deltas.
|
||||||
- **Uncertainty estimation**: The kNN backend returns a GP posterior standard deviation alongside each prediction — 0 means the query lies on known data, ~1 means it is out of distribution.
|
- Dataset Refinement: Ability to refine datasets by focusing on more challenging or interesting data points.
|
||||||
- **Dataset management**: Tools for saving, loading, merging, and pruning datasets.
|
|
||||||
|
|
||||||
### Additional Dependencies
|
### Additional Dependencies
|
||||||
|
|
||||||
|
To use you'll need to install `torch` and `numpy`. You can do so via
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install -e '.[model]'
|
pip install -e '.[model]'
|
||||||
```
|
```
|
||||||
|
|
||||||
### Usage
|
### Usage:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from nucon.model import NuconModelLearner
|
from nucon.model import NuconModelLearner
|
||||||
|
|
||||||
# --- Data collection ---
|
# Initialize the model learner
|
||||||
learner = NuconModelLearner(
|
learner = NuconModelLearner()
|
||||||
time_delta=10.0, # 10 game-seconds per step (wall sleep auto-scales with sim speed)
|
|
||||||
include_valve_states=False, # set True to include all 53 valve positions as model inputs
|
# Collect data by querying the game
|
||||||
)
|
|
||||||
learner.collect_data(num_steps=1000)
|
learner.collect_data(num_steps=1000)
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
learner.train_model(batch_size=32, num_epochs=10)
|
||||||
|
|
||||||
|
# Refine the dataset
|
||||||
|
learner.refine_dataset(error_threshold=0.1)
|
||||||
|
|
||||||
|
# Save the model and dataset
|
||||||
|
learner.save_model('reactor_model.pth')
|
||||||
learner.save_dataset('reactor_dataset.pkl')
|
learner.save_dataset('reactor_dataset.pkl')
|
||||||
|
|
||||||
# Merge datasets collected across multiple sessions
|
|
||||||
learner.merge_datasets('other_session.pkl')
|
|
||||||
|
|
||||||
# --- Neural network backend ---
|
|
||||||
nn_learner = NuconModelLearner(model_type='nn', dataset_path='reactor_dataset.pkl')
|
|
||||||
nn_learner.train_model(batch_size=32, num_epochs=50)
|
|
||||||
# Drop samples the NN already predicts well (keep hard cases for further training)
|
|
||||||
nn_learner.drop_well_fitted(error_threshold=1.0)
|
|
||||||
nn_learner.save_model('reactor_nn.pth')
|
|
||||||
|
|
||||||
# --- kNN + GP backend ---
|
|
||||||
knn_learner = NuconModelLearner(model_type='knn', knn_k=10, dataset_path='reactor_dataset.pkl')
|
|
||||||
# Drop near-duplicate samples before fitting (keeps diverse coverage).
|
|
||||||
# A sample is dropped only if BOTH its input state AND output transition
|
|
||||||
# are within the given distances of an already-kept sample.
|
|
||||||
knn_learner.drop_redundant(min_state_distance=0.1, min_output_distance=0.05)
|
|
||||||
knn_learner.fit_knn()
|
|
||||||
|
|
||||||
# Point prediction
|
|
||||||
state = knn_learner._get_state()
|
|
||||||
pred = knn_learner.model.forward(state, time_delta=10.0)
|
|
||||||
|
|
||||||
# Prediction with uncertainty
|
|
||||||
pred, uncertainty = knn_learner.predict_with_uncertainty(state, time_delta=10.0)
|
|
||||||
print(f"CORE_TEMP: {pred['CORE_TEMP']:.1f} ± {uncertainty:.3f} (std, GP posterior)")
|
|
||||||
# uncertainty ≈ 0: confident (query near known data)
|
|
||||||
# uncertainty ≈ 1: out of distribution
|
|
||||||
|
|
||||||
knn_learner.save_model('reactor_knn.pkl')
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The trained models can be integrated into the NuconSimulator to provide accurate dynamics based on real game data.
|
The trained models can be integrated into the NuconSimulator to provide accurate dynamics based on real game data.
|
||||||
|
|||||||
123
nucon/core.py
123
nucon/core.py
@ -68,13 +68,13 @@ SPECIAL_VARIABLES = frozenset({
|
|||||||
})
|
})
|
||||||
|
|
||||||
class NuconParameter:
|
class NuconParameter:
|
||||||
def __init__(self, nucon: 'Nucon', id: str, param_type: Type, is_writable: bool, min_val: Optional[Union[int, float]] = None, max_val: Optional[Union[int, float]] = None, unit: Optional[str] = None, is_readable: bool = True, is_cheat: bool = False):
|
def __init__(self, nucon: 'Nucon', id: str, param_type: Type, is_writable: bool, min_val: Optional[Union[int, float]] = None, max_val: Optional[Union[int, float]] = None, unit: Optional[str] = None, is_readable: bool = True, is_admin: bool = False):
|
||||||
self.nucon = nucon
|
self.nucon = nucon
|
||||||
self.id = id
|
self.id = id
|
||||||
self.param_type = param_type
|
self.param_type = param_type
|
||||||
self.is_writable = is_writable
|
self.is_writable = is_writable
|
||||||
self.is_readable = is_readable
|
self.is_readable = is_readable
|
||||||
self.is_cheat = is_cheat
|
self.is_admin = is_admin
|
||||||
self.min_val = min_val
|
self.min_val = min_val
|
||||||
self.max_val = max_val
|
self.max_val = max_val
|
||||||
self.unit = unit
|
self.unit = unit
|
||||||
@ -121,17 +121,17 @@ class NuconParameter:
|
|||||||
unit_str = f", unit='{self.unit}'" if self.unit else ""
|
unit_str = f", unit='{self.unit}'" if self.unit else ""
|
||||||
value_str = f", value={self.value}" if self.is_readable else ""
|
value_str = f", value={self.value}" if self.is_readable else ""
|
||||||
rw_str = "write-only" if not self.is_readable else f"is_writable={self.is_writable}"
|
rw_str = "write-only" if not self.is_readable else f"is_writable={self.is_writable}"
|
||||||
admin_str = ", is_cheat=True" if self.is_cheat else ""
|
admin_str = ", is_admin=True" if self.is_admin else ""
|
||||||
return f"NuconParameter(id='{self.id}'{value_str}, param_type={self.param_type.__name__}, {rw_str}{admin_str}{unit_str})"
|
return f"NuconParameter(id='{self.id}'{value_str}, param_type={self.param_type.__name__}, {rw_str}{admin_str}{unit_str})"
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.id
|
return self.id
|
||||||
|
|
||||||
class Nucon:
|
class Nucon:
|
||||||
def __init__(self, host: str = 'localhost', port: int = 8785, cheat_mode: bool = False):
|
def __init__(self, host: str = 'localhost', port: int = 8785, admin_mode: bool = False):
|
||||||
self.base_url = f'http://{host}:{port}/'
|
self.base_url = f'http://{host}:{port}/'
|
||||||
self.dummy_mode = False
|
self.dummy_mode = False
|
||||||
self.cheat_mode = cheat_mode
|
self.admin_mode = admin_mode
|
||||||
self._parameters = self._create_parameters()
|
self._parameters = self._create_parameters()
|
||||||
|
|
||||||
def _create_parameters(self) -> Dict[str, NuconParameter]:
|
def _create_parameters(self) -> Dict[str, NuconParameter]:
|
||||||
@ -145,7 +145,6 @@ class Nucon:
|
|||||||
'ALARMS_ACTIVE': (str, False),
|
'ALARMS_ACTIVE': (str, False),
|
||||||
'GAME_SIM_SPEED': (float, False),
|
'GAME_SIM_SPEED': (float, False),
|
||||||
'AMBIENT_TEMPERATURE': (float, False, None, None, '°C'),
|
'AMBIENT_TEMPERATURE': (float, False, None, None, '°C'),
|
||||||
'FUN_IS_ENABLED': (bool, False),
|
|
||||||
|
|
||||||
# --- Core thermal/pressure ---
|
# --- Core thermal/pressure ---
|
||||||
'CORE_TEMP': (float, False, 0, 1000, '°C'),
|
'CORE_TEMP': (float, False, 0, 1000, '°C'),
|
||||||
@ -358,48 +357,54 @@ class Nucon:
|
|||||||
'CHEMICAL_CLEANING_PUMP_OVERLOAD_STATUS': (PumpOverloadStatus, False),
|
'CHEMICAL_CLEANING_PUMP_OVERLOAD_STATUS': (PumpOverloadStatus, False),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Write-only params: normal operational commands
|
# Write-only params: normal control setpoints (no admin restriction)
|
||||||
write_only_values = {
|
write_only_values = {
|
||||||
# --- MSCVs (Main Steam Control Valves) setpoints ---
|
# --- MSCVs (Main Steam Control Valves) ---
|
||||||
**{f'MSCV_{i}_OPENING_ORDERED': (float, True, 0, 100, '%') for i in range(3)},
|
**{f'MSCV_{i}_OPENING_ORDERED': (float, True, 0, 100, '%') for i in range(3)},
|
||||||
|
|
||||||
# --- Steam turbine bypass setpoints ---
|
# --- Steam turbine bypass setpoints ---
|
||||||
**{f'STEAM_TURBINE_{i}_BYPASS_ORDERED': (float, True, 0, 100, '%') for i in range(3)},
|
**{f'STEAM_TURBINE_{i}_BYPASS_ORDERED': (float, True, 0, 100, '%') for i in range(3)},
|
||||||
|
|
||||||
# --- Steam ejector valve setpoints (0-100 position, not bool) ---
|
|
||||||
'STEAM_EJECTOR_STARTUP_MOTIVE_VALVE': (int, True, 0, 100, '%'),
|
|
||||||
'STEAM_EJECTOR_OPERATIONAL_MOTIVE_VALVE': (int, True, 0, 100, '%'),
|
|
||||||
'STEAM_EJECTOR_CONDENSER_RETURN_VALVE': (int, True, 0, 100, '%'),
|
|
||||||
|
|
||||||
# --- Generic valve commands (value = valve name e.g. "M01", "M02", "M03") ---
|
|
||||||
'VALVE_OPEN': (str, True),
|
|
||||||
'VALVE_CLOSE': (str, True),
|
|
||||||
'VALVE_OFF': (str, True),
|
|
||||||
|
|
||||||
# --- Pump / generator start/stop ---
|
|
||||||
'CONDENSER_VACUUM_PUMP_START_STOP': (bool, True),
|
|
||||||
'EMERGENCY_GENERATOR_1_START_STOP': (bool, True),
|
|
||||||
'EMERGENCY_GENERATOR_2_START_STOP': (bool, True),
|
|
||||||
|
|
||||||
# --- Chemistry setpoints ---
|
# --- Chemistry setpoints ---
|
||||||
'CHEM_BORON_DOSAGE_ORDERED_RATE': (float, True, 0, 100, '%'),
|
'CHEM_BORON_DOSAGE_ORDERED_RATE': (float, True, 0, 100, '%'),
|
||||||
'CHEM_BORON_FILTER_ORDERED_SPEED': (float, True, 0, 100, '%'),
|
'CHEM_BORON_FILTER_ORDERED_SPEED': (float, True, 0, 100, '%'),
|
||||||
|
}
|
||||||
|
|
||||||
# --- Core safety / operational actions ---
|
# Write-only admin params: destructive/irreversible operations, blocked unless admin_mode=True
|
||||||
|
write_only_admin_values = {
|
||||||
|
# --- Core safety actions ---
|
||||||
'CORE_SCRAM_BUTTON': (bool, True),
|
'CORE_SCRAM_BUTTON': (bool, True),
|
||||||
'CORE_EMERGENCY_STOP': (bool, True),
|
'CORE_EMERGENCY_STOP': (bool, True),
|
||||||
'CORE_END_EMERGENCY_STOP': (bool, True),
|
'CORE_END_EMERGENCY_STOP': (bool, True),
|
||||||
'RESET_AO': (bool, True),
|
'RESET_AO': (bool, True),
|
||||||
'STEAM_TURBINE_TRIP': (bool, True),
|
|
||||||
'RODS_ALL_POS_ORDERED': (float, True, 0, 100, '%'),
|
|
||||||
|
|
||||||
# --- Core bay physical operations ---
|
# --- Core bay physical operations ---
|
||||||
**{f'CORE_BAY_{i}_HATCH': (bool, True) for i in range(1, 10)},
|
**{f'CORE_BAY_{i}_HATCH': (bool, True) for i in range(1, 10)},
|
||||||
**{f'CORE_BAY_{i}_FUEL_LOADING': (int, True) for i in range(1, 10)},
|
**{f'CORE_BAY_{i}_FUEL_LOADING': (int, True) for i in range(1, 10)},
|
||||||
}
|
|
||||||
|
|
||||||
# Write-only cheat params: game event triggers, blocked unless cheat_mode=True
|
# --- Bulk rod override ---
|
||||||
write_only_cheat_values = {
|
'RODS_ALL_POS_ORDERED': (float, True, 0, 100, '%'),
|
||||||
|
|
||||||
|
# --- Steam turbine trip ---
|
||||||
|
'STEAM_TURBINE_TRIP': (bool, True),
|
||||||
|
|
||||||
|
# --- Steam ejector valves ---
|
||||||
|
'STEAM_EJECTOR_CONDENSER_RETURN_VALVE': (bool, True),
|
||||||
|
'STEAM_EJECTOR_OPERATIONAL_MOTIVE_VALVE': (bool, True),
|
||||||
|
'STEAM_EJECTOR_STARTUP_MOTIVE_VALVE': (bool, True),
|
||||||
|
|
||||||
|
# --- Generic valve commands (take valve name as value) ---
|
||||||
|
'VALVE_OPEN': (str, True),
|
||||||
|
'VALVE_CLOSE': (str, True),
|
||||||
|
'VALVE_OFF': (str, True),
|
||||||
|
|
||||||
|
# --- Infrastructure start/stop ---
|
||||||
|
'CONDENSER_VACUUM_PUMP_START_STOP': (bool, True),
|
||||||
|
'EMERGENCY_GENERATOR_1_START_STOP': (bool, True),
|
||||||
|
'EMERGENCY_GENERATOR_2_START_STOP': (bool, True),
|
||||||
|
|
||||||
|
# --- Fun / event triggers (game cheats) ---
|
||||||
|
'FUN_IS_ENABLED': (bool, True),
|
||||||
'FUN_REQUEST_ENABLE': (bool, True),
|
'FUN_REQUEST_ENABLE': (bool, True),
|
||||||
'FUN_AO_SABOTAGE_ONCE': (bool, True),
|
'FUN_AO_SABOTAGE_ONCE': (bool, True),
|
||||||
'FUN_AO_SABOTAGE_TIME': (float, True),
|
'FUN_AO_SABOTAGE_TIME': (float, True),
|
||||||
@ -423,8 +428,8 @@ class Nucon:
|
|||||||
}
|
}
|
||||||
for name, values in write_only_values.items():
|
for name, values in write_only_values.items():
|
||||||
params[name] = NuconParameter(self, name, *values, is_readable=False)
|
params[name] = NuconParameter(self, name, *values, is_readable=False)
|
||||||
for name, values in write_only_cheat_values.items():
|
for name, values in write_only_admin_values.items():
|
||||||
params[name] = NuconParameter(self, name, *values, is_readable=False, is_cheat=True)
|
params[name] = NuconParameter(self, name, *values, is_readable=False, is_admin=True)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def _parse_value(self, parameter: NuconParameter, value: str) -> Union[float, int, bool, str, Enum, None]:
|
def _parse_value(self, parameter: NuconParameter, value: str) -> Union[float, int, bool, str, Enum, None]:
|
||||||
@ -464,8 +469,8 @@ class Nucon:
|
|||||||
|
|
||||||
if not force and not parameter.is_writable:
|
if not force and not parameter.is_writable:
|
||||||
raise ValueError(f"Parameter {parameter} is not writable")
|
raise ValueError(f"Parameter {parameter} is not writable")
|
||||||
if not force and parameter.is_cheat and not self.cheat_mode:
|
if not force and parameter.is_admin and not self.admin_mode:
|
||||||
raise ValueError(f"Parameter {parameter} is a cheat parameter. Enable cheat_mode on the Nucon instance or use force=True")
|
raise ValueError(f"Parameter {parameter} is an admin parameter. Enable admin_mode on the Nucon instance or use force=True")
|
||||||
|
|
||||||
if not force:
|
if not force:
|
||||||
parameter.check_in_range(value, raise_on_oob=True)
|
parameter.check_in_range(value, raise_on_oob=True)
|
||||||
@ -599,59 +604,11 @@ class Nucon:
|
|||||||
def get_all_writable(self) -> List[NuconParameter]:
|
def get_all_writable(self) -> List[NuconParameter]:
|
||||||
return {name: param for name, param in self._parameters.items() if param.is_writable}
|
return {name: param for name, param in self._parameters.items() if param.is_writable}
|
||||||
|
|
||||||
# --- Valve API ---
|
|
||||||
# Valves have a motorized actuator. OPEN/CLOSE power the motor toward that end-state;
|
|
||||||
# OFF cuts power and holds the current position. Normal resting state is OFF.
|
|
||||||
# The Value field (0-100) is the actual live position during travel.
|
|
||||||
|
|
||||||
def _post_valve_command(self, command: str, valve_name: str) -> None:
|
|
||||||
response = requests.post(self.base_url, params={"variable": command, "value": valve_name})
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise Exception(f"Valve command {command} on '{valve_name}' failed. Status: {response.status_code}")
|
|
||||||
|
|
||||||
def get_valve(self, valve_name: str) -> Dict[str, Any]:
|
|
||||||
"""Return current state dict for a single valve (from VALVE_PANEL_JSON)."""
|
|
||||||
valves = self.get_valves()
|
|
||||||
if valve_name not in valves:
|
|
||||||
raise KeyError(f"Valve '{valve_name}' not found")
|
|
||||||
return valves[valve_name]
|
|
||||||
|
|
||||||
def get_valves(self) -> Dict[str, Any]:
|
|
||||||
"""Return state dict for all valves, keyed by valve name."""
|
|
||||||
response = requests.get(self.base_url, params={"variable": "VALVE_PANEL_JSON"})
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise Exception(f"Failed to get valve panel. Status: {response.status_code}")
|
|
||||||
return response.json().get("valves", {})
|
|
||||||
|
|
||||||
def open_valve(self, valve_name: str) -> None:
|
|
||||||
"""Power actuator toward open state. Send off_valve() once target is reached."""
|
|
||||||
self._post_valve_command("VALVE_OPEN", valve_name)
|
|
||||||
|
|
||||||
def close_valve(self, valve_name: str) -> None:
|
|
||||||
"""Power actuator toward closed state. Send off_valve() once target is reached."""
|
|
||||||
self._post_valve_command("VALVE_CLOSE", valve_name)
|
|
||||||
|
|
||||||
def off_valve(self, valve_name: str) -> None:
|
|
||||||
"""Cut actuator power, hold current position. Normal resting state."""
|
|
||||||
self._post_valve_command("VALVE_OFF", valve_name)
|
|
||||||
|
|
||||||
def open_valves(self, valve_names: List[str]) -> None:
|
|
||||||
for name in valve_names:
|
|
||||||
self.open_valve(name)
|
|
||||||
|
|
||||||
def close_valves(self, valve_names: List[str]) -> None:
|
|
||||||
for name in valve_names:
|
|
||||||
self.close_valve(name)
|
|
||||||
|
|
||||||
def off_valves(self, valve_names: List[str]) -> None:
|
|
||||||
for name in valve_names:
|
|
||||||
self.off_valve(name)
|
|
||||||
|
|
||||||
def set_dummy_mode(self, dummy_mode: bool) -> None:
|
def set_dummy_mode(self, dummy_mode: bool) -> None:
|
||||||
self.dummy_mode = dummy_mode
|
self.dummy_mode = dummy_mode
|
||||||
|
|
||||||
def set_cheat_mode(self, cheat_mode: bool) -> None:
|
def set_admin_mode(self, admin_mode: bool) -> None:
|
||||||
self.cheat_mode = cheat_mode
|
self.admin_mode = admin_mode
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
if isinstance(name, int):
|
if isinstance(name, int):
|
||||||
|
|||||||
303
nucon/model.py
303
nucon/model.py
@ -8,15 +8,13 @@ from enum import Enum
|
|||||||
from nucon import Nucon
|
from nucon import Nucon
|
||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
from typing import Union, Tuple, List, Dict
|
from typing import Union, Tuple, List
|
||||||
|
|
||||||
Actors = {
|
Actors = {
|
||||||
'random': lambda nucon: lambda obs: {param.id: random.uniform(param.min_val, param.max_val) if param.min_val is not None and param.max_val is not None else 0 for param in nucon.get_all_writable().values()},
|
'random': lambda nucon: lambda obs: {param.id: random.uniform(param.min_val, param.max_val) if param.min_val is not None and param.max_val is not None else 0 for param in nucon.get_all_writable().values()},
|
||||||
'null': lambda nucon: lambda obs: {},
|
'null': lambda nucon: lambda obs: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
# --- NN-based dynamics model ---
|
|
||||||
|
|
||||||
class ReactorDynamicsNet(nn.Module):
|
class ReactorDynamicsNet(nn.Module):
|
||||||
def __init__(self, input_dim, output_dim):
|
def __init__(self, input_dim, output_dim):
|
||||||
super(ReactorDynamicsNet, self).__init__()
|
super(ReactorDynamicsNet, self).__init__()
|
||||||
@ -37,7 +35,10 @@ class ReactorDynamicsModel(nn.Module):
|
|||||||
super(ReactorDynamicsModel, self).__init__()
|
super(ReactorDynamicsModel, self).__init__()
|
||||||
self.input_params = input_params
|
self.input_params = input_params
|
||||||
self.output_params = output_params
|
self.output_params = output_params
|
||||||
self.net = ReactorDynamicsNet(len(input_params), len(output_params))
|
|
||||||
|
input_dim = len(input_params)
|
||||||
|
output_dim = len(output_params)
|
||||||
|
self.net = ReactorDynamicsNet(input_dim, output_dim)
|
||||||
|
|
||||||
def _state_dict_to_tensor(self, state_dict):
|
def _state_dict_to_tensor(self, state_dict):
|
||||||
return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)
|
return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)
|
||||||
@ -51,142 +52,17 @@ class ReactorDynamicsModel(nn.Module):
|
|||||||
predicted_tensor = self.net(state_tensor, time_delta_tensor)
|
predicted_tensor = self.net(state_tensor, time_delta_tensor)
|
||||||
return self._tensor_to_state_dict(predicted_tensor.squeeze(0))
|
return self._tensor_to_state_dict(predicted_tensor.squeeze(0))
|
||||||
|
|
||||||
# --- kNN-based dynamics model ---
|
|
||||||
|
|
||||||
class ReactorKNNModel:
|
|
||||||
"""
|
|
||||||
Non-parametric dynamics model using k-nearest neighbours.
|
|
||||||
|
|
||||||
For a query (state, game_delta):
|
|
||||||
1. Find the k dataset entries whose *state* is closest (L2 in normalised space).
|
|
||||||
2. For each neighbour compute the per-second rate-of-change:
|
|
||||||
rate_i = (next_state_i - state_i) / game_delta_i
|
|
||||||
3. Linearly scale to the requested game_delta:
|
|
||||||
predicted_delta_i = rate_i * game_delta
|
|
||||||
4. Return the inverse-distance-weighted average of those predicted deltas
|
|
||||||
added to the current output state.
|
|
||||||
|
|
||||||
The linear-in-time assumption means two datapoints at 0.5 s and 2 s contribute
|
|
||||||
equally once normalised by their own game_delta.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, input_params: List[str], output_params: List[str], k: int = 5):
|
|
||||||
self.input_params = input_params
|
|
||||||
self.output_params = output_params
|
|
||||||
self.k = k
|
|
||||||
self._states = None # (n, d_in) normalised state matrix
|
|
||||||
self._rates = None # (n, d_out) (next_out - cur_out) / game_delta
|
|
||||||
self._raw_states = None # unnormalised, for mean/std computation
|
|
||||||
self._mean = None
|
|
||||||
self._std = None
|
|
||||||
|
|
||||||
def fit(self, dataset):
|
|
||||||
"""Build lookup tables from a collected dataset."""
|
|
||||||
raw, rates = [], []
|
|
||||||
for state, _action, next_state, game_delta in dataset:
|
|
||||||
if game_delta <= 0:
|
|
||||||
continue
|
|
||||||
s = np.array([state[p] for p in self.input_params], dtype=np.float32)
|
|
||||||
cur = np.array([state[p] for p in self.output_params], dtype=np.float32)
|
|
||||||
nxt = np.array([next_state[p] for p in self.output_params], dtype=np.float32)
|
|
||||||
raw.append(s)
|
|
||||||
rates.append((nxt - cur) / game_delta)
|
|
||||||
|
|
||||||
self._raw_states = np.array(raw)
|
|
||||||
self._rates = np.array(rates)
|
|
||||||
self._mean = self._raw_states.mean(axis=0)
|
|
||||||
self._std = self._raw_states.std(axis=0) + 1e-8
|
|
||||||
self._states = (self._raw_states - self._mean) / self._std
|
|
||||||
|
|
||||||
def _lookup(self, state_dict: Dict):
|
|
||||||
"""Return (s_norm, idx, k) for the k nearest neighbours."""
|
|
||||||
s = np.array([state_dict[p] for p in self.input_params], dtype=np.float32)
|
|
||||||
s_norm = (s - self._mean) / self._std
|
|
||||||
dists = np.linalg.norm(self._states - s_norm, axis=1)
|
|
||||||
k = min(self.k, len(dists))
|
|
||||||
idx = np.argpartition(dists, k - 1)[:k]
|
|
||||||
return s_norm, idx, k
|
|
||||||
|
|
||||||
def forward(self, state_dict: Dict, time_delta: float) -> Dict:
|
|
||||||
if self._states is None:
|
|
||||||
raise ValueError("Model not fitted. Call fit(dataset) first.")
|
|
||||||
return self.forward_with_uncertainty(state_dict, time_delta)[0]
|
|
||||||
|
|
||||||
def forward_with_uncertainty(self, state_dict: Dict, time_delta: float):
|
|
||||||
"""Return (prediction_dict, uncertainty_scalar).
|
|
||||||
|
|
||||||
Uncertainty is the GP posterior std in normalised input space:
|
|
||||||
0 = query lies exactly on a training point (fully confident)
|
|
||||||
~1 = query is far from all neighbours (maximally uncertain)
|
|
||||||
"""
|
|
||||||
if self._states is None:
|
|
||||||
raise ValueError("Model not fitted. Call fit(dataset) first.")
|
|
||||||
|
|
||||||
s_norm, idx, k = self._lookup(state_dict)
|
|
||||||
X = self._states[idx] # (k, d_in)
|
|
||||||
Y = self._rates[idx] # (k, d_out)
|
|
||||||
|
|
||||||
# RBF kernel (vectorised): k(a,b) = exp(-0.5 ||a-b||^2)
|
|
||||||
def rbf_matrix(A, B):
|
|
||||||
diff = A[:, None, :] - B[None, :, :] # (|A|, |B|, d)
|
|
||||||
return np.exp(-0.5 * (diff ** 2).sum(axis=-1)) # (|A|, |B|)
|
|
||||||
|
|
||||||
K = rbf_matrix(X, X) + 1e-4 * np.eye(k) # (k, k)
|
|
||||||
k_star = rbf_matrix(s_norm[None, :], X)[0] # (k,)
|
|
||||||
|
|
||||||
K_inv = np.linalg.inv(K)
|
|
||||||
mean_rates = k_star @ K_inv @ Y # (d_out,)
|
|
||||||
|
|
||||||
# Posterior variance (scalar, shared across all output dims)
|
|
||||||
var = max(0.0, 1.0 - float(k_star @ K_inv @ k_star))
|
|
||||||
std = float(np.sqrt(var))
|
|
||||||
|
|
||||||
cur_out = np.array([state_dict[p] for p in self.output_params], dtype=np.float32)
|
|
||||||
predicted = cur_out + mean_rates * time_delta
|
|
||||||
|
|
||||||
pred_dict = {p: float(predicted[i]) for i, p in enumerate(self.output_params)}
|
|
||||||
return pred_dict, std
|
|
||||||
|
|
||||||
# --- Learner ---
|
|
||||||
|
|
||||||
class NuconModelLearner:
|
class NuconModelLearner:
|
||||||
def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
|
def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl', time_delta: Union[float, Tuple[float, float]] = 0.1):
|
||||||
time_delta: Union[float, Tuple[float, float]] = 1.0,
|
|
||||||
model_type: str = 'nn', knn_k: int = 5,
|
|
||||||
include_valve_states: bool = False):
|
|
||||||
self.nucon = Nucon() if nucon is None else nucon
|
self.nucon = Nucon() if nucon is None else nucon
|
||||||
self.actor = Actors[actor](self.nucon) if actor in Actors else actor
|
self.actor = Actors[actor](self.nucon) if actor in Actors else actor
|
||||||
self.dataset = self.load_dataset(dataset_path) or []
|
self.dataset = self.load_dataset(dataset_path) or []
|
||||||
self.dataset_path = dataset_path
|
self.dataset_path = dataset_path
|
||||||
self.include_valve_states = include_valve_states
|
|
||||||
|
|
||||||
# Exclude params with no physics signal
|
self.readable_params = list(self.nucon.get_all_readable().keys())
|
||||||
_JUNK_PARAMS = frozenset({'GAME_VERSION', 'TIME', 'TIME_STAMP', 'TIME_DAY',
|
self.non_writable_params = [param.id for param in self.nucon.get_all_readable().values() if not param.is_writable]
|
||||||
'ALARMS_ACTIVE', 'FUN_IS_ENABLED', 'GAME_SIM_SPEED'})
|
|
||||||
candidate_params = {k: p for k, p in self.nucon.get_all_readable().items()
|
|
||||||
if k not in _JUNK_PARAMS and p.param_type != str}
|
|
||||||
# Filter out params that return None (subsystem not installed)
|
|
||||||
test_state = {k: self.nucon.get(k) for k in candidate_params}
|
|
||||||
self.readable_params = [k for k in candidate_params if test_state[k] is not None]
|
|
||||||
self.non_writable_params = [k for k in self.readable_params
|
|
||||||
if not self.nucon.get_all_readable()[k].is_writable]
|
|
||||||
|
|
||||||
# Optionally include valve positions (input only — valves are externally driven)
|
|
||||||
self.valve_keys = []
|
|
||||||
if include_valve_states:
|
|
||||||
valves = self.nucon.get_valves()
|
|
||||||
self.valve_keys = [f'VALVE__{name}' for name in sorted(valves.keys())]
|
|
||||||
self.readable_params = self.readable_params + self.valve_keys
|
|
||||||
# valve positions are input-only (not predicted as outputs)
|
|
||||||
|
|
||||||
if model_type == 'nn':
|
|
||||||
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
||||||
self.optimizer = optim.Adam(self.model.parameters())
|
self.optimizer = optim.Adam(self.model.parameters())
|
||||||
elif model_type == 'knn':
|
|
||||||
self.model = ReactorKNNModel(self.readable_params, self.non_writable_params, k=knn_k)
|
|
||||||
self.optimizer = None
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown model_type '{model_type}'. Use 'nn' or 'knn'.")
|
|
||||||
|
|
||||||
if isinstance(time_delta, (int, float)):
|
if isinstance(time_delta, (int, float)):
|
||||||
self.time_delta = lambda: time_delta
|
self.time_delta = lambda: time_delta
|
||||||
@ -197,180 +73,87 @@ class NuconModelLearner:
|
|||||||
|
|
||||||
def _get_state(self):
|
def _get_state(self):
|
||||||
state = {}
|
state = {}
|
||||||
for param_id in self.readable_params:
|
for param_id, param in self.nucon.get_all_readable().items():
|
||||||
if param_id in self.valve_keys:
|
value = self.nucon.get(param)
|
||||||
continue # filled below
|
|
||||||
value = self.nucon.get(param_id)
|
|
||||||
if isinstance(value, Enum):
|
if isinstance(value, Enum):
|
||||||
value = value.value
|
value = value.value
|
||||||
state[param_id] = value
|
state[param_id] = value
|
||||||
if self.valve_keys:
|
|
||||||
valves = self.nucon.get_valves()
|
|
||||||
for key in self.valve_keys:
|
|
||||||
name = key[len('VALVE__'):]
|
|
||||||
state[key] = valves.get(name, {}).get('Value', 0.0)
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def collect_data(self, num_steps):
|
def collect_data(self, num_steps):
|
||||||
"""
|
|
||||||
Collect state-transition tuples from the live game.
|
|
||||||
|
|
||||||
Sleeps wall_time = target_game_delta / sim_speed so that each stored
|
|
||||||
game_delta is uniform regardless of the game's simulation speed setting.
|
|
||||||
"""
|
|
||||||
state = self._get_state()
|
state = self._get_state()
|
||||||
for _ in range(num_steps):
|
for _ in range(num_steps):
|
||||||
action = self.actor(state)
|
action = self.actor(state)
|
||||||
|
start_time = time.time()
|
||||||
for param_id, value in action.items():
|
for param_id, value in action.items():
|
||||||
self.nucon.set(param_id, value)
|
self.nucon.set(param_id, value)
|
||||||
|
time_delta = self.time_delta()
|
||||||
target_game_delta = self.time_delta()
|
time.sleep(time_delta)
|
||||||
sim_speed = self.nucon.GAME_SIM_SPEED.value or 1.0
|
|
||||||
time.sleep(target_game_delta / sim_speed)
|
|
||||||
next_state = self._get_state()
|
next_state = self._get_state()
|
||||||
|
|
||||||
self.dataset.append((state, action, next_state, target_game_delta))
|
self.dataset.append((state, action, next_state, time_delta))
|
||||||
state = next_state
|
state = next_state
|
||||||
|
|
||||||
self.save_dataset()
|
self.save_dataset()
|
||||||
|
|
||||||
|
def refine_dataset(self, error_threshold):
|
||||||
|
refined_data = []
|
||||||
|
for state, action, next_state, time_delta in self.dataset:
|
||||||
|
predicted_next_state = self.model(state, time_delta)
|
||||||
|
|
||||||
|
error = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
||||||
|
if error > error_threshold:
|
||||||
|
refined_data.append((state, action, next_state, time_delta))
|
||||||
|
|
||||||
|
self.dataset = refined_data
|
||||||
|
self.save_dataset()
|
||||||
|
|
||||||
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
|
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
|
||||||
"""Train the NN model. For kNN, call fit_knn() instead."""
|
|
||||||
if not isinstance(self.model, ReactorDynamicsModel):
|
|
||||||
raise ValueError("train_model() is for the NN model. Use fit_knn() for kNN.")
|
|
||||||
random.shuffle(self.dataset)
|
random.shuffle(self.dataset)
|
||||||
split_idx = int(len(self.dataset) * (1 - test_split))
|
split_idx = int(len(self.dataset) * (1 - test_split))
|
||||||
train_data = self.dataset[:split_idx]
|
train_data = self.dataset[:split_idx]
|
||||||
test_data = self.dataset[split_idx:]
|
test_data = self.dataset[split_idx:]
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
train_loss = self._train_epoch(train_data, batch_size)
|
train_loss = self._train_epoch(train_data, batch_size)
|
||||||
test_loss = self._test_epoch(test_data)
|
test_loss = self._test_epoch(test_data)
|
||||||
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
||||||
|
|
||||||
def fit_knn(self):
|
|
||||||
"""Fit the kNN/GP model from the current dataset (instantaneous, no gradient steps)."""
|
|
||||||
if not isinstance(self.model, ReactorKNNModel):
|
|
||||||
raise ValueError("fit_knn() is for the kNN model. Use train_model() for NN.")
|
|
||||||
self.model.fit(self.dataset)
|
|
||||||
print(f"kNN model fitted on {len(self.dataset)} samples.")
|
|
||||||
|
|
||||||
def predict_with_uncertainty(self, state_dict: Dict, time_delta: float):
|
|
||||||
"""Return (prediction_dict, uncertainty_std). Only available for kNN model."""
|
|
||||||
if not isinstance(self.model, ReactorKNNModel):
|
|
||||||
raise ValueError("predict_with_uncertainty() requires model_type='knn'.")
|
|
||||||
return self.model.forward_with_uncertainty(state_dict, time_delta)
|
|
||||||
|
|
||||||
def drop_well_fitted(self, error_threshold: float):
|
|
||||||
"""Drop samples the current model already predicts well (MSE < threshold).
|
|
||||||
|
|
||||||
Keeps only hard/surprising transitions. Useful for NN training to focus
|
|
||||||
capacity on difficult regions of state space.
|
|
||||||
"""
|
|
||||||
kept = []
|
|
||||||
for state, action, next_state, time_delta in self.dataset:
|
|
||||||
pred = self.model.forward(state, time_delta)
|
|
||||||
error = sum((pred[p] - next_state[p]) ** 2 for p in self.non_writable_params)
|
|
||||||
if error > error_threshold:
|
|
||||||
kept.append((state, action, next_state, time_delta))
|
|
||||||
dropped = len(self.dataset) - len(kept)
|
|
||||||
self.dataset = kept
|
|
||||||
self.save_dataset()
|
|
||||||
print(f"drop_well_fitted: kept {len(kept)}, dropped {dropped} samples.")
|
|
||||||
|
|
||||||
def drop_redundant(self, min_state_distance: float, min_output_distance: float = 0.0):
|
|
||||||
"""Drop near-duplicate samples, keeping only those that add coverage.
|
|
||||||
|
|
||||||
A sample is dropped only if *both* its input state and its output
|
|
||||||
transition are within the given distances of an already-kept sample
|
|
||||||
(L2 in z-scored space). If two samples share the same input state but
|
|
||||||
have different transitions they represent genuinely different dynamics
|
|
||||||
and are both kept regardless of `min_output_distance`.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
min_state_distance: minimum L2 distance in z-scored input space.
|
|
||||||
min_output_distance: minimum L2 distance in z-scored output-delta
|
|
||||||
space. Defaults to 0 (only input distance matters).
|
|
||||||
"""
|
|
||||||
if not self.dataset:
|
|
||||||
return
|
|
||||||
|
|
||||||
in_params = [p for p in self.readable_params if p not in self.valve_keys]
|
|
||||||
out_params = self.non_writable_params
|
|
||||||
|
|
||||||
all_states = np.array([[s[p] for p in in_params] for s, *_ in self.dataset], dtype=np.float32)
|
|
||||||
all_deltas = np.array([[ns[p] - s[p] for p in out_params]
|
|
||||||
for s, _, ns, gd in self.dataset], dtype=np.float32)
|
|
||||||
|
|
||||||
s_mean, s_std = all_states.mean(0), all_states.std(0) + 1e-8
|
|
||||||
d_mean, d_std = all_deltas.mean(0), all_deltas.std(0) + 1e-8
|
|
||||||
|
|
||||||
s_norm = (all_states - s_mean) / s_std
|
|
||||||
d_norm = (all_deltas - d_mean) / d_std
|
|
||||||
|
|
||||||
kept_idx = [0]
|
|
||||||
kept_s = [s_norm[0]]
|
|
||||||
kept_d = [d_norm[0]]
|
|
||||||
|
|
||||||
for i in range(1, len(self.dataset)):
|
|
||||||
s_dists = np.linalg.norm(np.array(kept_s) - s_norm[i], axis=1)
|
|
||||||
d_dists = np.linalg.norm(np.array(kept_d) - d_norm[i], axis=1)
|
|
||||||
# Drop only if close in BOTH spaces
|
|
||||||
if not np.any((s_dists < min_state_distance) & (d_dists < min_output_distance)):
|
|
||||||
kept_idx.append(i)
|
|
||||||
kept_s.append(s_norm[i])
|
|
||||||
kept_d.append(d_norm[i])
|
|
||||||
|
|
||||||
dropped = len(self.dataset) - len(kept_idx)
|
|
||||||
self.dataset = [self.dataset[i] for i in kept_idx]
|
|
||||||
self.save_dataset()
|
|
||||||
print(f"drop_redundant: kept {len(self.dataset)}, dropped {dropped} samples.")
|
|
||||||
|
|
||||||
def _train_epoch(self, data, batch_size):
|
def _train_epoch(self, data, batch_size):
|
||||||
out_indices = [self.readable_params.index(p) if p in self.readable_params else None
|
|
||||||
for p in self.non_writable_params]
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
for i in range(0, len(data), batch_size):
|
for i in range(0, len(data), batch_size):
|
||||||
batch = data[i:i+batch_size]
|
batch = data[i:i+batch_size]
|
||||||
|
states, _, next_states, time_deltas = zip(*batch)
|
||||||
|
|
||||||
|
loss = 0
|
||||||
|
for state, next_state, time_delta in zip(states, next_states, time_deltas):
|
||||||
|
predicted_next_state = self.model(state, time_delta)
|
||||||
|
loss += sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
||||||
|
|
||||||
|
loss /= len(batch)
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss = torch.tensor(0.0)
|
|
||||||
for state, _, next_state, time_delta in batch:
|
|
||||||
state_t = self.model._state_dict_to_tensor(state).unsqueeze(0)
|
|
||||||
td_t = torch.tensor([[time_delta]], dtype=torch.float32)
|
|
||||||
pred = self.model.net(state_t, td_t).squeeze(0)
|
|
||||||
target = torch.tensor([next_state[p] for p in self.non_writable_params],
|
|
||||||
dtype=torch.float32)
|
|
||||||
loss = loss + torch.nn.functional.mse_loss(pred, target)
|
|
||||||
loss = loss / len(batch)
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
return total_loss / max(1, len(data) // batch_size)
|
|
||||||
|
return total_loss / (len(data) // batch_size)
|
||||||
|
|
||||||
def _test_epoch(self, data):
|
def _test_epoch(self, data):
|
||||||
total_loss = 0.0
|
total_loss = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for state, _, next_state, time_delta in data:
|
for state, _, next_state, time_delta in data:
|
||||||
state_t = self.model._state_dict_to_tensor(state).unsqueeze(0)
|
predicted_next_state = self.model(state, time_delta)
|
||||||
td_t = torch.tensor([[time_delta]], dtype=torch.float32)
|
loss = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
||||||
pred = self.model.net(state_t, td_t).squeeze(0)
|
total_loss += loss
|
||||||
target = torch.tensor([next_state[p] for p in self.non_writable_params],
|
|
||||||
dtype=torch.float32)
|
|
||||||
total_loss += torch.nn.functional.mse_loss(pred, target).item()
|
|
||||||
return total_loss / len(data)
|
return total_loss / len(data)
|
||||||
|
|
||||||
def save_model(self, path):
|
def save_model(self, path):
|
||||||
if isinstance(self.model, ReactorDynamicsModel):
|
|
||||||
torch.save(self.model.state_dict(), path)
|
torch.save(self.model.state_dict(), path)
|
||||||
else:
|
|
||||||
with open(path, 'wb') as f:
|
|
||||||
pickle.dump(self.model, f)
|
|
||||||
|
|
||||||
def load_model(self, path):
|
def load_model(self, path):
|
||||||
if isinstance(self.model, ReactorDynamicsModel):
|
|
||||||
self.model.load_state_dict(torch.load(path))
|
self.model.load_state_dict(torch.load(path))
|
||||||
else:
|
|
||||||
with open(path, 'rb') as f:
|
|
||||||
self.model = pickle.load(f)
|
|
||||||
|
|
||||||
def save_dataset(self, path=None):
|
def save_dataset(self, path=None):
|
||||||
path = path or self.dataset_path
|
path = path or self.dataset_path
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user