Implemenetd Model Learning
This commit is contained in:
parent
132c47ff21
commit
60cd44cc9e
174
nucon/model.py
Normal file
174
nucon/model.py
Normal file
@ -0,0 +1,174 @@
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import random
|
||||
from enum import Enum
|
||||
from nucon import Nucon
|
||||
import pickle
|
||||
import os
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
Actors = {
|
||||
'random': lambda nucon: lambda obs: {param.id: random.uniform(param.min_val, param.max_val) if param.min_val is not None and param.max_val is not None else 0 for param in nucon.get_all_writable().values()},
|
||||
'null': lambda nucon: lambda obs: {},
|
||||
}
|
||||
|
||||
class ReactorDynamicsNet(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(ReactorDynamicsNet, self).__init__()
|
||||
self.network = nn.Sequential(
|
||||
nn.Linear(input_dim + 1, 128), # +1 for time_delta
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, 128),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128, output_dim)
|
||||
)
|
||||
|
||||
def forward(self, state, time_delta):
|
||||
x = torch.cat([state, time_delta], dim=-1)
|
||||
return self.network(x)
|
||||
|
||||
class ReactorDynamicsModel(nn.Module):
|
||||
def __init__(self, input_params: List[str], output_params: List[str]):
|
||||
super(ReactorDynamicsModel, self).__init__()
|
||||
self.input_params = input_params
|
||||
self.output_params = output_params
|
||||
|
||||
input_dim = len(input_params)
|
||||
output_dim = len(output_params)
|
||||
self.net = ReactorDynamicsNet(input_dim, output_dim)
|
||||
|
||||
def _state_dict_to_tensor(self, state_dict):
|
||||
return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)
|
||||
|
||||
def _tensor_to_state_dict(self, tensor):
|
||||
return {p: tensor[i].item() for i, p in enumerate(self.output_params)}
|
||||
|
||||
def forward(self, state_dict, time_delta):
|
||||
state_tensor = self._state_dict_to_tensor(state_dict).unsqueeze(0)
|
||||
time_delta_tensor = torch.tensor([time_delta], dtype=torch.float32).unsqueeze(0)
|
||||
predicted_tensor = self.net(state_tensor, time_delta_tensor)
|
||||
return self._tensor_to_state_dict(predicted_tensor.squeeze(0))
|
||||
|
||||
class NuconModelLearner:
|
||||
def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl', time_delta: Union[float, Tuple[float, float]] = 0.1):
|
||||
self.nucon = Nucon() if nucon is None else nucon
|
||||
self.actor = Actors[actor](self.nucon) if actor in Actors else actor
|
||||
self.dataset = self.load_dataset(dataset_path) or []
|
||||
self.dataset_path = dataset_path
|
||||
|
||||
self.readable_params = list(self.nucon.get_all_readable().keys())
|
||||
self.non_writable_params = [param.id for param in self.nucon.get_all_readable().values() if not param.is_writable]
|
||||
self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
|
||||
self.optimizer = optim.Adam(self.model.parameters())
|
||||
|
||||
if isinstance(time_delta, (int, float)):
|
||||
self.time_delta = lambda: time_delta
|
||||
elif isinstance(time_delta, tuple) and len(time_delta) == 2:
|
||||
self.time_delta = lambda: random.uniform(*time_delta)
|
||||
else:
|
||||
raise ValueError("time_delta must be a float or a tuple of two floats")
|
||||
|
||||
def _get_state(self):
|
||||
state = {}
|
||||
for param_id, param in self.nucon.get_all_readable().items():
|
||||
value = self.nucon.get(param)
|
||||
if isinstance(value, Enum):
|
||||
value = value.value
|
||||
state[param_id] = value
|
||||
return state
|
||||
|
||||
def collect_data(self, num_steps):
|
||||
state = self._get_state()
|
||||
for _ in range(num_steps):
|
||||
action = self.actor(state)
|
||||
start_time = time.time()
|
||||
for param_id, value in action.items():
|
||||
self.nucon.set(param_id, value)
|
||||
time_delta = self.time_delta()
|
||||
time.sleep(time_delta)
|
||||
next_state = self._get_state()
|
||||
|
||||
self.dataset.append((state, action, next_state, time_delta))
|
||||
state = next_state
|
||||
|
||||
self.save_dataset()
|
||||
|
||||
def refine_dataset(self, error_threshold):
|
||||
refined_data = []
|
||||
for state, action, next_state, time_delta in self.dataset:
|
||||
predicted_next_state = self.model(state, time_delta)
|
||||
|
||||
error = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
||||
if error > error_threshold:
|
||||
refined_data.append((state, action, next_state, time_delta))
|
||||
|
||||
self.dataset = refined_data
|
||||
self.save_dataset()
|
||||
|
||||
def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
|
||||
random.shuffle(self.dataset)
|
||||
split_idx = int(len(self.dataset) * (1 - test_split))
|
||||
train_data = self.dataset[:split_idx]
|
||||
test_data = self.dataset[split_idx:]
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
train_loss = self._train_epoch(train_data, batch_size)
|
||||
test_loss = self._test_epoch(test_data)
|
||||
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
||||
|
||||
def _train_epoch(self, data, batch_size):
|
||||
total_loss = 0
|
||||
for i in range(0, len(data), batch_size):
|
||||
batch = data[i:i+batch_size]
|
||||
states, _, next_states, time_deltas = zip(*batch)
|
||||
|
||||
loss = 0
|
||||
for state, next_state, time_delta in zip(states, next_states, time_deltas):
|
||||
predicted_next_state = self.model(state, time_delta)
|
||||
loss += sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
||||
|
||||
loss /= len(batch)
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
return total_loss / (len(data) // batch_size)
|
||||
|
||||
def _test_epoch(self, data):
|
||||
total_loss = 0
|
||||
with torch.no_grad():
|
||||
for state, _, next_state, time_delta in data:
|
||||
predicted_next_state = self.model(state, time_delta)
|
||||
loss = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
|
||||
total_loss += loss
|
||||
|
||||
return total_loss / len(data)
|
||||
|
||||
def save_model(self, path):
|
||||
torch.save(self.model.state_dict(), path)
|
||||
|
||||
def load_model(self, path):
|
||||
self.model.load_state_dict(torch.load(path))
|
||||
|
||||
def save_dataset(self, path=None):
|
||||
path = path or self.dataset_path
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump(self.dataset, f)
|
||||
|
||||
def load_dataset(self, path=None):
|
||||
path = path or self.dataset_path
|
||||
if os.path.exists(path):
|
||||
with open(path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
return None
|
||||
|
||||
def merge_datasets(self, other_dataset_path):
|
||||
other_dataset = self.load_dataset(other_dataset_path)
|
||||
if other_dataset:
|
||||
self.dataset.extend(other_dataset)
|
||||
self.save_dataset()
|
Loading…
Reference in New Issue
Block a user