Implemented Model Learning
This commit is contained in:
parent 132c47ff21
commit 60cd44cc9e

174 nucon/model.py Normal file
@@ -0,0 +1,174 @@
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
import random
from enum import Enum
from nucon import Nucon
import pickle
import os
from typing import Union, Tuple, List

Actors = {
    'random': lambda nucon: lambda obs: {
        param.id: random.uniform(param.min_val, param.max_val)
        if param.min_val is not None and param.max_val is not None else 0
        for param in nucon.get_all_writable().values()
    },
    'null': lambda nucon: lambda obs: {},
}

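# An actor maps the current observation dict to a dict of {param_id: value}
# writes. As a hedged illustration (not part of this file), a custom actor
# that holds every writable parameter at its midpoint could be registered as:
#
#   Actors['midpoint'] = lambda nucon: lambda obs: {
#       param.id: (param.min_val + param.max_val) / 2
#       for param in nucon.get_all_writable().values()
#       if param.min_val is not None and param.max_val is not None
#   }
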
class ReactorDynamicsNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ReactorDynamicsNet, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(input_dim + 1, 128),  # +1 for time_delta
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )

    def forward(self, state, time_delta):
        x = torch.cat([state, time_delta], dim=-1)
        return self.network(x)

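# Shape sketch: with input_dim readable parameters, forward expects state of
# shape (batch, input_dim) and time_delta of shape (batch, 1); torch.cat
# along the last dim yields (batch, input_dim + 1), matching the first
# nn.Linear(input_dim + 1, 128) layer above.
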
class ReactorDynamicsModel(nn.Module):
    def __init__(self, input_params: List[str], output_params: List[str]):
        super(ReactorDynamicsModel, self).__init__()
        self.input_params = input_params
        self.output_params = output_params

        input_dim = len(input_params)
        output_dim = len(output_params)
        self.net = ReactorDynamicsNet(input_dim, output_dim)

    def _state_dict_to_tensor(self, state_dict):
        return torch.tensor([state_dict[p] for p in self.input_params], dtype=torch.float32)

    def _tensor_to_state_dict(self, tensor):
        # Keep 0-dim tensors rather than calling .item(): detaching here
        # would break the gradient flow that _train_epoch's loss.backward()
        # relies on.
        return {p: tensor[i] for i, p in enumerate(self.output_params)}

    def forward(self, state_dict, time_delta):
        state_tensor = self._state_dict_to_tensor(state_dict).unsqueeze(0)
        time_delta_tensor = torch.tensor([time_delta], dtype=torch.float32).unsqueeze(0)
        predicted_tensor = self.net(state_tensor, time_delta_tensor)
        return self._tensor_to_state_dict(predicted_tensor.squeeze(0))

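# Usage sketch with hypothetical parameter ids (real ids come from Nucon):
#
#   model = ReactorDynamicsModel(['core_temp', 'rod_pos'], ['core_temp'])
#   model({'core_temp': 350.0, 'rod_pos': 0.5}, 0.1)
#   # -> {'core_temp': tensor(...)}  (0-dim tensor; call .item() for a float)
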
class NuconModelLearner:
    def __init__(self, nucon=None, actor='null', dataset_path='nucon_dataset.pkl', time_delta: Union[float, Tuple[float, float]] = 0.1):
        self.nucon = Nucon() if nucon is None else nucon
        self.actor = Actors[actor](self.nucon) if actor in Actors else actor
        self.dataset_path = dataset_path
        self.dataset = self.load_dataset(dataset_path) or []

        self.readable_params = list(self.nucon.get_all_readable().keys())
        # The model predicts only the parameters that cannot be set directly.
        self.non_writable_params = [param.id for param in self.nucon.get_all_readable().values() if not param.is_writable]
        self.model = ReactorDynamicsModel(self.readable_params, self.non_writable_params)
        self.optimizer = optim.Adam(self.model.parameters())

        if isinstance(time_delta, (int, float)):
            self.time_delta = lambda: time_delta
        elif isinstance(time_delta, tuple) and len(time_delta) == 2:
            self.time_delta = lambda: random.uniform(*time_delta)
        else:
            raise ValueError("time_delta must be a float or a tuple of two floats")

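    # Construction examples (illustrative, assuming a live Nucon):
    #   NuconModelLearner(actor='random', time_delta=0.1)         # fixed 0.1 s steps
    #   NuconModelLearner(actor='null', time_delta=(0.05, 0.2))   # uniform random delay per step
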
    def _get_state(self):
        state = {}
        for param_id, param in self.nucon.get_all_readable().items():
            value = self.nucon.get(param)
            if isinstance(value, Enum):
                # Flatten Enum readings to their underlying value so the
                # state dict stays tensor-friendly.
                value = value.value
            state[param_id] = value
        return state

    def collect_data(self, num_steps):
        state = self._get_state()
        for _ in range(num_steps):
            action = self.actor(state)
            start_time = time.time()
            for param_id, value in action.items():
                self.nucon.set(param_id, value)
            time.sleep(self.time_delta())
            next_state = self._get_state()
            # Record the actual elapsed time (sleep plus set/get latency),
            # not just the requested delay.
            time_delta = time.time() - start_time

            self.dataset.append((state, action, next_state, time_delta))
            state = next_state

        self.save_dataset()

    def refine_dataset(self, error_threshold):
        # Keep only the transitions the current model predicts poorly; the
        # rest carry little remaining training signal.
        refined_data = []
        with torch.no_grad():
            for state, action, next_state, time_delta in self.dataset:
                predicted_next_state = self.model(state, time_delta)
                error = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
                if error.item() > error_threshold:
                    refined_data.append((state, action, next_state, time_delta))

        self.dataset = refined_data
        self.save_dataset()

    def train_model(self, batch_size=32, num_epochs=10, test_split=0.2):
        random.shuffle(self.dataset)
        split_idx = int(len(self.dataset) * (1 - test_split))
        train_data = self.dataset[:split_idx]
        test_data = self.dataset[split_idx:]

        for epoch in range(num_epochs):
            train_loss = self._train_epoch(train_data, batch_size)
            test_loss = self._test_epoch(test_data)
            print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    def _train_epoch(self, data, batch_size):
        total_loss = 0
        for i in range(0, len(data), batch_size):
            batch = data[i:i+batch_size]
            states, _, next_states, time_deltas = zip(*batch)

            # Sum-of-squared-errors over the non-writable parameters,
            # averaged over the batch. loss stays a tensor because
            # _tensor_to_state_dict returns tensors, so backward() works.
            loss = 0
            for state, next_state, time_delta in zip(states, next_states, time_deltas):
                predicted_next_state = self.model(state, time_delta)
                loss += sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)

            loss /= len(batch)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return total_loss / max(1, len(data) // batch_size)

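    # Written out, one optimizer step minimizes
    #   L(B) = (1/|B|) * sum_{(s, a, s', dt) in B} sum_{p in non_writable} (model(s, dt)[p] - s'[p])^2
    # Note that the recorded action a is discarded (the `_` in zip(*batch));
    # the model predicts from the pre-action state s and the elapsed time alone.
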
    def _test_epoch(self, data):
        total_loss = 0
        with torch.no_grad():
            for state, _, next_state, time_delta in data:
                predicted_next_state = self.model(state, time_delta)
                loss = sum((predicted_next_state[p] - next_state[p])**2 for p in self.non_writable_params)
                total_loss += loss.item()

        return total_loss / max(1, len(data))

    def save_model(self, path):
        torch.save(self.model.state_dict(), path)

    def load_model(self, path):
        self.model.load_state_dict(torch.load(path))

    def save_dataset(self, path=None):
        path = path or self.dataset_path
        with open(path, 'wb') as f:
            pickle.dump(self.dataset, f)

    def load_dataset(self, path=None):
        path = path or self.dataset_path
        if os.path.exists(path):
            with open(path, 'rb') as f:
                return pickle.load(f)
        return None

    def merge_datasets(self, other_dataset_path):
        other_dataset = self.load_dataset(other_dataset_path)
        if other_dataset:
            self.dataset.extend(other_dataset)
            self.save_dataset()
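
A minimal end-to-end usage sketch (an illustration under stated assumptions, not part of the commit: it assumes a reachable Nucon instance, and collect_data does write the actor's actions to the live simulator):

    learner = NuconModelLearner(actor='random', time_delta=(0.05, 0.2))
    learner.collect_data(num_steps=1000)           # gather (state, action, next_state, dt) tuples
    learner.train_model(batch_size=32, num_epochs=10)
    learner.refine_dataset(error_threshold=0.01)   # keep only poorly-predicted transitions
    learner.save_model('reactor_dynamics.pt')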