Implemented Model Learning

This commit is contained in:
Dominik Moritz Roth 2024-10-03 21:55:59 +02:00
parent 132c47ff21
commit 60cd44cc9e

174 nucon/model.py Normal file

@@ -0,0 +1,174 @@
import os
import pickle
import random
import time
from enum import Enum
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.optim as optim

from nucon import Nucon

# Registry of baseline actors. Each entry maps a name to a factory that takes
# a Nucon instance and returns a function from observation to action dict.
Actors = {
    'random': lambda nucon: lambda obs: {
        param.id: random.uniform(param.min_val, param.max_val)
        if param.min_val is not None and param.max_val is not None
        else 0
        for param in nucon.get_all_writable().values()
    },
    'null': lambda nucon: lambda obs: {},
}

class ReactorDynamicsNet(nn.Module):
    """MLP that maps (state, time_delta) to a predicted next state."""

    def __init__(self, input_dim, output_dim):
        super(ReactorDynamicsNet, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + 1, 128),  # +1 for time_delta
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, state, time_delta):
        x = torch.cat([state, time_delta], dim=-1)
        return self.network(x)
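
# Illustrative shape check (the dimensions here are assumed, not taken from
# the Nucon parameter set): with 12 inputs and 7 outputs, a batch of 4 states
# plus one time delta per sample yields a (4, 7) prediction:
#   net = ReactorDynamicsNet(input_dim=12, output_dim=7)
#   out = net(torch.randn(4, 12), torch.full((4, 1), 0.1))  # -> shape (4, 7)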

class ReactorDynamicsModel(nn.Module):
    """Wraps ReactorDynamicsNet with dict <-> tensor conversion for named parameters."""

    def __init__(self, input_params: List[str], output_params: List[str]):
        super(ReactorDynamicsModel, self).__init__()
        self.input_params = input_params
        self.output_params = output_params
        input_dim = len(input_params)
        output_dim = len(output_params)
        self.net = ReactorDynamicsNet(input_dim, output_dim)

    def _state_dict_to_tensor(self, state_dict):
        return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)

    def _tensor_to_state_dict(self, tensor):
        # Keep entries as 0-dim tensors so gradients can flow back through the
        # network during training; callers use .item() where a float is needed.
        return {p: tensor[i] for i, p in enumerate(self.output_params)}

    def forward(self, state_dict, time_delta):
        state_tensor = self._state_dict_to_tensor(state_dict).unsqueeze(0)
        time_delta_tensor = torch.tensor([time_delta], dtype=torch.float32).unsqueeze(0)
        predicted_tensor = self.net(state_tensor, time_delta_tensor)
        return self._tensor_to_state_dict(predicted_tensor.squeeze(0))

class NuconModelLearner:
    def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl',
                 time_delta: Union[float, Tuple[float, float]] = 0.1):
        self.nucon = Nucon() if nucon is None else nucon
        # 'actor' may be a name from the Actors registry or a callable obs -> action dict.
        self.actor = Actors[actor](self.nucon) if actor in Actors else actor
        self.dataset_path = dataset_path
        self.dataset = self.load_dataset(dataset_path) or []
        self.readable_params = list(self.nucon.get_all_readable().keys())
        # The model predicts only the parameters we cannot set directly.
        self.non_writable_params = [param.id for param in self.nucon.get_all_readable().values()
                                    if not param.is_writable]
        self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
        self.optimizer = optim.Adam(self.model.parameters())
        # A fixed float gives a constant step; a (low, high) tuple samples uniformly.
        if isinstance(time_delta, (int, float)):
            self.time_delta = lambda: time_delta
        elif isinstance(time_delta, tuple) and len(time_delta) == 2:
            self.time_delta = lambda: random.uniform(*time_delta)
        else:
            raise ValueError("time_delta must be a float or a tuple of two floats")

    def _get_state(self):
        # Snapshot all readable parameters, unwrapping Enum values to raw numbers.
        state = {}
        for param_id, param in self.nucon.get_all_readable().items():
            value = self.nucon.get(param)
            if isinstance(value, Enum):
                value = value.value
            state[param_id] = value
        return state

    def collect_data(self, num_steps):
        # Roll out the actor on the live system and record
        # (state, action, next_state, time_delta) transitions.
        state = self._get_state()
        for _ in range(num_steps):
            action = self.actor(state)
            start_time = time.time()
            for param_id, value in action.items():
                self.nucon.set(param_id, value)
            time_delta = self.time_delta()
            # Sleep only for the remainder, so the elapsed time between state
            # reads matches the recorded time_delta despite set() latency.
            time.sleep(max(0.0, time_delta - (time.time() - start_time)))
            next_state = self._get_state()
            self.dataset.append((state, action, next_state, time_delta))
            state = next_state
        self.save_dataset()

    def refine_dataset(self, error_threshold):
        # Keep only transitions the current model predicts poorly; well-predicted
        # transitions add little information for further training.
        refined_data = []
        with torch.no_grad():
            for state, action, next_state, time_delta in self.dataset:
                predicted_next_state = self.model(state, time_delta)
                error = sum((predicted_next_state[p] - next_state[p]) ** 2
                            for p in self.non_writable_params).item()
                if error > error_threshold:
                    refined_data.append((state, action, next_state, time_delta))
        self.dataset = refined_data
        self.save_dataset()

    def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
        random.shuffle(self.dataset)
        split_idx = int(len(self.dataset) * (1 - test_split))
        train_data = self.dataset[:split_idx]
        test_data = self.dataset[split_idx:]
        for epoch in range(num_epochs):
            train_loss = self._train_epoch(train_data, batch_size)
            test_loss = self._test_epoch(test_data)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    def _train_epoch(self, data, batch_size):
        total_loss = 0
        num_batches = 0
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            states, _, next_states, time_deltas = zip(*batch)
            # Accumulate squared prediction error over the batch, one sample
            # at a time (the model's dict interface is per-sample).
            loss = 0
            for state, next_state, time_delta in zip(states, next_states, time_deltas):
                predicted_next_state = self.model(state, time_delta)
                loss += sum((predicted_next_state[p] - next_state[p]) ** 2
                            for p in self.non_writable_params)
            loss /= len(batch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        return total_loss / max(1, num_batches)

    def _test_epoch(self, data):
        total_loss = 0
        with torch.no_grad():
            for state, _, next_state, time_delta in data:
                predicted_next_state = self.model(state, time_delta)
                loss = sum((predicted_next_state[p] - next_state[p]) ** 2
                           for p in self.non_writable_params)
                total_loss += loss.item()
        return total_loss / max(1, len(data))

    def save_model(self, path):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path))

    def save_dataset(self, path=None):
        path = path or self.dataset_path
        with open(path, 'wb') as f:
            pickle.dump(self.dataset, f)

    def load_dataset(self, path=None):
        path = path or self.dataset_path
        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def merge_datasets(self, other_dataset_path):
        other_dataset = self.load_dataset(other_dataset_path)
        if other_dataset:
            self.dataset.extend(other_dataset)
            self.save_dataset()
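
# Minimal end-to-end usage sketch (assumes a reachable Nucon instance; the
# step count, threshold, and file name below are illustrative, not prescribed):
if __name__ == "__main__":
    learner = NuconModelLearner(actor='random', time_delta=(0.1, 0.5))
    learner.collect_data(num_steps=100)          # roll out the random actor
    learner.train_model(batch_size=32, num_epochs=10)
    learner.refine_dataset(error_threshold=1.0)  # keep poorly-predicted transitions
    learner.save_model('reactor_dynamics.pt')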