From 60cd44cc9e392762e4bab92d7efc277b826204e2 Mon Sep 17 00:00:00 2001
From: Dominik Roth
Date: Thu, 3 Oct 2024 21:55:59 +0200
Subject: [PATCH] Implemented Model Learning

---
 nucon/model.py | 174 +++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 174 insertions(+)
 create mode 100644 nucon/model.py

diff --git a/nucon/model.py b/nucon/model.py
new file mode 100644
index 0000000..1c9f2d8
--- /dev/null
+++ b/nucon/model.py
@@ -0,0 +1,174 @@
+import time
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import random
+from enum import Enum
+from nucon import Nucon
+import pickle
+import os
+from typing import Union, Tuple, List
+
+# Built-in actors: 'random' samples every bounded writable parameter uniformly
+# within its bounds (unbounded parameters fall back to 0); 'null' does nothing.
+Actors = {
+    'random': lambda nucon: lambda obs: {
+        param.id: random.uniform(param.min_val, param.max_val)
+        if param.min_val is not None and param.max_val is not None else 0
+        for param in nucon.get_all_writable().values()
+    },
+    'null': lambda nucon: lambda obs: {},
+}
+
+class ReactorDynamicsNet(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super(ReactorDynamicsNet, self).__init__()
+        self.network = nn.Sequential(
+            nn.Linear(input_dim + 1, 128),  # +1 for time_delta
+            nn.ReLU(),
+            nn.Linear(128, 128),
+            nn.ReLU(),
+            nn.Linear(128, output_dim)
+        )
+
+    def forward(self, state, time_delta):
+        x = torch.cat([state, time_delta], dim=-1)
+        return self.network(x)
+
+class ReactorDynamicsModel(nn.Module):
+    def __init__(self, input_params: List[str], output_params: List[str]):
+        super(ReactorDynamicsModel, self).__init__()
+        self.input_params = input_params
+        self.output_params = output_params
+
+        input_dim = len(input_params)
+        output_dim = len(output_params)
+        self.net = ReactorDynamicsNet(input_dim, output_dim)
+
+    def _state_dict_to_tensor(self, state_dict):
+        return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)
+
+    def _tensor_to_state_dict(self, tensor):
+        # Keep the values as 0-dim tensors so gradients can flow through the
+        # prediction during training; callers use .item() where floats are needed.
+        return {p: tensor[i] for i, p in enumerate(self.output_params)}
+
+    def forward(self, state_dict, time_delta):
+        state_tensor = self._state_dict_to_tensor(state_dict).unsqueeze(0)
+        time_delta_tensor = torch.tensor([time_delta], dtype=torch.float32).unsqueeze(0)
+        predicted_tensor = self.net(state_tensor, time_delta_tensor)
+        return self._tensor_to_state_dict(predicted_tensor.squeeze(0))
+
+class NuconModelLearner:
+    def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl', time_delta: Union[float, Tuple[float, float]] = 0.1):
+        self.nucon = Nucon() if nucon is None else nucon
+        self.actor = Actors[actor](self.nucon) if actor in Actors else actor
+        self.dataset = self.load_dataset(dataset_path) or []
+        self.dataset_path = dataset_path
+
+        # The model maps the full readable state to its non-writable part:
+        # writable parameters are controlled, not predicted.
+        self.readable_params = list(self.nucon.get_all_readable().keys())
+        self.non_writable_params = [param.id for param in self.nucon.get_all_readable().values() if not param.is_writable]
+        self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
+        self.optimizer = optim.Adam(self.model.parameters())
+
+        if isinstance(time_delta, (int, float)):
+            self.time_delta = lambda: time_delta
+        elif isinstance(time_delta, tuple) and len(time_delta) == 2:
+            self.time_delta = lambda: random.uniform(*time_delta)
+        else:
+            raise ValueError("time_delta must be a float or a tuple of two floats")
+
+    def _get_state(self):
+        state = {}
+        for param_id, param in self.nucon.get_all_readable().items():
+            value = self.nucon.get(param)
+            if isinstance(value, Enum):
+                value = value.value
+            state[param_id] = value
+        return state
+
+    def collect_data(self, num_steps):
+        # Note: `state` is read before the action is applied.
+        state = self._get_state()
+        for _ in range(num_steps):
+            action = self.actor(state)
+            for param_id, value in action.items():
+                self.nucon.set(param_id, value)
+            time_delta = self.time_delta()
+            time.sleep(time_delta)
+            next_state = self._get_state()
+
+            self.dataset.append((state, action, next_state, time_delta))
+            state = next_state
+
+        self.save_dataset()
+
+    def refine_dataset(self, error_threshold):
+        # Keep only transitions the current model still predicts poorly,
+        # dropping samples the model already explains well.
+        refined_data = []
+        with torch.no_grad():
+            for state, action, next_state, time_delta in self.dataset:
+                predicted_next_state = self.model(state, time_delta)
+                error = sum((predicted_next_state[p].item() - next_state[p]) ** 2 for p in self.non_writable_params)
+                if error > error_threshold:
+                    refined_data.append((state, action, next_state, time_delta))
+
+        self.dataset = refined_data
+        self.save_dataset()
+
+    def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
+        random.shuffle(self.dataset)
+        split_idx = int(len(self.dataset) * (1 - test_split))
+        train_data = self.dataset[:split_idx]
+        test_data = self.dataset[split_idx:]
+
+        for epoch in range(num_epochs):
+            train_loss = self._train_epoch(train_data, batch_size)
+            test_loss = self._test_epoch(test_data)
+            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
+
+    def _train_epoch(self, data, batch_size):
+        total_loss = 0
+        num_batches = 0
+        for i in range(0, len(data), batch_size):
+            batch = data[i:i+batch_size]
+            states, _, next_states, time_deltas = zip(*batch)
+
+            # Mean squared error over the predicted (non-writable) parameters.
+            # Predictions are 0-dim tensors, so the summed loss stays attached
+            # to the autograd graph and backward() works.
+            loss = 0
+            for state, next_state, time_delta in zip(states, next_states, time_deltas):
+                predicted_next_state = self.model(state, time_delta)
+                loss += sum((predicted_next_state[p] - next_state[p]) ** 2 for p in self.non_writable_params)
+
+            loss /= len(batch)
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+            total_loss += loss.item()
+            num_batches += 1
+
+        return total_loss / max(num_batches, 1)
+
+    def _test_epoch(self, data):
+        total_loss = 0
+        with torch.no_grad():
+            for state, _, next_state, time_delta in data:
+                predicted_next_state = self.model(state, time_delta)
+                loss = sum((predicted_next_state[p].item() - next_state[p]) ** 2 for p in self.non_writable_params)
+                total_loss += loss
+
+        return total_loss / max(len(data), 1)
+
+    def save_model(self, path):
+        torch.save(self.model.state_dict(), path)
+
+    def load_model(self, path):
+        self.model.load_state_dict(torch.load(path))
+
+    def save_dataset(self, path=None):
+        path = path or self.dataset_path
+        with open(path, 'wb') as f:
+            pickle.dump(self.dataset, f)
+
+    def load_dataset(self, path=None):
+        path = path or self.dataset_path
+        if os.path.exists(path):
+            with open(path, 'rb') as f:
+                return pickle.load(f)
+        return None
+
+    def merge_datasets(self, other_dataset_path):
+        other_dataset = self.load_dataset(other_dataset_path)
+        if other_dataset:
+            self.dataset.extend(other_dataset)
+            self.save_dataset()
\ No newline at end of file
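
For review context, a minimal usage sketch of the new module. This is a hypothetical driver script, not part of the patch; the actor choice, step counts, threshold, and file name below are illustrative assumptions, and it presumes a reachable Nucon instance with writable parameters.

    from nucon.model import NuconModelLearner

    # Collect transitions under the built-in 'random' actor with a randomized
    # step interval, then fit the dynamics model on the recorded dataset.
    learner = NuconModelLearner(actor='random', time_delta=(0.1, 0.5))
    learner.collect_data(num_steps=1000)
    learner.train_model(batch_size=32, num_epochs=10, test_split=0.2)
    learner.save_model('reactor_dynamics.pt')

    # Optionally prune transitions the trained model already predicts well
    # and continue training on the remaining, harder samples.
    learner.refine_dataset(error_threshold=1.0)
    learner.train_model(num_epochs=5)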