import time
import torch
import torch.nn as nn
import torch.optim as optim
import random
from enum import Enum
from nucon import Nucon
import pickle
import os
from typing import Union, Tuple, List

# Built-in actor policies. Each entry is a factory that binds a Nucon
# instance and returns a callable mapping an observation to an action dict.
Actors = {
    'random': lambda nucon: lambda obs: {
        param.id: random.uniform(param.min_val, param.max_val)
        if param.min_val is not None and param.max_val is not None
        else 0
        for param in nucon.get_all_writable().values()
    },
    'null': lambda nucon: lambda obs: {},
}


class ReactorDynamicsNet(nn.Module):
    """MLP mapping (state, time_delta) to a predicted next state."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + 1, 128),  # +1 for time_delta
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, state, time_delta):
        x = torch.cat([state, time_delta], dim=-1)
        return self.network(x)


class ReactorDynamicsModel(nn.Module):
    """Dict-in/dict-out wrapper around ReactorDynamicsNet."""

    def __init__(self, input_params: List[str], output_params: List[str]):
        super().__init__()
        self.input_params = input_params
        self.output_params = output_params
        self.net = ReactorDynamicsNet(len(input_params), len(output_params))

    def _state_dict_to_tensor(self, state_dict):
        return torch.tensor([state_dict[p] for p in self.input_params],
                            dtype=torch.float32)

    def _tensor_to_state_dict(self, tensor):
        return {p: tensor[i].item() for i, p in enumerate(self.output_params)}

    def predict_tensor(self, state_dict, time_delta):
        # Returns the raw prediction tensor with the autograd graph intact,
        # so it can be used to build a differentiable training loss.
        state_tensor = self._state_dict_to_tensor(state_dict).unsqueeze(0)
        time_delta_tensor = torch.tensor([[time_delta]], dtype=torch.float32)
        return self.net(state_tensor, time_delta_tensor).squeeze(0)

    def forward(self, state_dict, time_delta):
        # Convenience path for inference; the .item() calls in the dict
        # conversion detach the result from the graph.
        return self._tensor_to_state_dict(self.predict_tensor(state_dict, time_delta))


class NuconModelLearner:
    def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
                 time_delta: Union[float, Tuple[float, float]] = 0.1):
        self.nucon = Nucon() if nucon is None else nucon
        self.actor = Actors[actor](self.nucon) if actor in Actors else actor
        self.dataset_path = dataset_path
        self.dataset = self.load_dataset(dataset_path) or []
        self.readable_params = list(self.nucon.get_all_readable().keys())
        # The model predicts only the parameters the actor cannot set directly.
        self.non_writable_params = [param.id
                                    for param in self.nucon.get_all_readable().values()
                                    if not param.is_writable]
        self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
        self.optimizer = optim.Adam(self.model.parameters())

        if isinstance(time_delta, (int, float)):
            self.time_delta = lambda: time_delta
        elif isinstance(time_delta, tuple) and len(time_delta) == 2:
            self.time_delta = lambda: random.uniform(*time_delta)
        else:
            raise ValueError("time_delta must be a float or a tuple of two floats")

    def _get_state(self):
        state = {}
        for param_id, param in self.nucon.get_all_readable().items():
            value = self.nucon.get(param)
            if isinstance(value, Enum):
                value = value.value
            state[param_id] = value
        return state

    def collect_data(self, num_steps):
        """Roll out the actor and record (state, action, next_state, dt) tuples."""
        state = self._get_state()
        for _ in range(num_steps):
            action = self.actor(state)
            start_time = time.time()
            for param_id, value in action.items():
                self.nucon.set(param_id, value)
            time.sleep(self.time_delta())
            next_state = self._get_state()
            # Store the measured wall-clock interval between observations
            # rather than the requested sleep, since the writes also take time.
            elapsed = time.time() - start_time
            self.dataset.append((state, action, next_state, elapsed))
            state = next_state
        self.save_dataset()

    def refine_dataset(self, error_threshold):
        """Keep only transitions the current model still predicts poorly."""
        refined_data = []
        with torch.no_grad():
            for state, action, next_state, time_delta in self.dataset:
                predicted_next_state = self.model(state, time_delta)
                error = sum((predicted_next_state[p] - next_state[p]) ** 2
                            for p in self.non_writable_params)
                if error > error_threshold:
                    refined_data.append((state, action, next_state, time_delta))
        self.dataset = refined_data
        self.save_dataset()

    def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
        random.shuffle(self.dataset)
        split_idx = int(len(self.dataset) * (1 - test_split))
        train_data = self.dataset[:split_idx]
        test_data = self.dataset[split_idx:]
        for epoch in range(num_epochs):
            train_loss = self._train_epoch(train_data, batch_size)
            test_loss = self._test_epoch(test_data)
            print(f"Epoch {epoch + 1}/{num_epochs}, "
                  f"Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    def _train_epoch(self, data, batch_size):
        total_loss = 0.0
        num_batches = 0
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            loss = torch.tensor(0.0)
            for state, _, next_state, time_delta in batch:
                # Use the tensor path so gradients flow; forward() would
                # detach the prediction via .item() and break backward().
                predicted = self.model.predict_tensor(state, time_delta)
                target = torch.tensor([next_state[p] for p in self.non_writable_params],
                                      dtype=torch.float32)
                loss = loss + ((predicted - target) ** 2).sum()
            loss = loss / len(batch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        return total_loss / max(num_batches, 1)

    def _test_epoch(self, data):
        total_loss = 0.0
        with torch.no_grad():
            for state, _, next_state, time_delta in data:
                predicted_next_state = self.model(state, time_delta)
                total_loss += sum((predicted_next_state[p] - next_state[p]) ** 2
                                  for p in self.non_writable_params)
        return total_loss / max(len(data), 1)

    def save_model(self, path):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path))

    def save_dataset(self, path=None):
        path = path or self.dataset_path
        with open(path, 'wb') as f:
            pickle.dump(self.dataset, f)

    def load_dataset(self, path=None):
        path = path or self.dataset_path
        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def merge_datasets(self, other_dataset_path):
        other_dataset = self.load_dataset(other_dataset_path)
        if other_dataset:
            self.dataset.extend(other_dataset)
            self.save_dataset()
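

if __name__ == "__main__":
    # Minimal usage sketch, assuming a reachable reactor exposing the Nucon
    # parameter API. The step count, threshold, and output path below are
    # illustrative values, not part of the original module: collect a short
    # random rollout, fit the dynamics model, then prune transitions the
    # model already predicts well before saving the weights.
    learner = NuconModelLearner(actor='random', time_delta=(0.1, 0.5))
    learner.collect_data(num_steps=100)
    learner.train_model(batch_size=32, num_epochs=10)
    learner.refine_dataset(error_threshold=1.0)
    learner.save_model('reactor_dynamics.pt')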