import socket
import json
import threading
import time
import torch
import logging
from .llm_manager import LLMManager

logger = logging.getLogger(__name__)

class GPUWorker:
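    """Worker that connects to a coordinator over TCP, registers its GPU and
    model information, sends periodic heartbeats, and executes inference
    tasks through an LLMManager instance on the assigned GPU.
    """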
    def __init__(self, coordinator_host: str, coordinator_port: int, model_path: str, gpu_id: int = 0):
        self.coordinator_host = coordinator_host
        self.coordinator_port = coordinator_port
        self.model_path = model_path
        self.gpu_id = gpu_id
        self.worker_id = f"worker_{gpu_id}_{int(time.time())}"
        self.socket = None
        self.llm_manager = None
        self.running = False
        self.connected = False
        
    def connect(self):
        """Connette al coordinator"""
        try:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.connect((self.coordinator_host, self.coordinator_port))
            self.connected = True
            
            # Initialize the LLM on the assigned GPU
            self.llm_manager = LLMManager(self.model_path, self.gpu_id)
            
            # Register this worker with the coordinator
            self._register_with_coordinator()
            
            # Start the heartbeat thread
            heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
            heartbeat_thread.start()
            
            # Main loop for receiving tasks (blocks until disconnect or error)
            self._receive_loop()
            
        except Exception as e:
            logger.error(f"Failed to connect to coordinator: {e}")
            self.connected = False
    
    def disconnect(self):
        """Disconnette dal coordinator"""
        self.running = False
        if self.socket:
            self.socket.close()
        self.connected = False
    
    def _register_with_coordinator(self):
        """Registra il worker con il coordinator"""
        gpu_info = self._get_gpu_info()
        model_info = self.llm_manager.get_model_info()
        
        registration_message = {
            'type': 'worker_register',
            'worker_id': self.worker_id,
            'gpu_info': gpu_info,
            'model_info': model_info,
            'performance_score': self._calculate_performance_score(gpu_info)
        }
        
        self.socket.sendall(json.dumps(registration_message).encode('utf-8'))
        logger.info(f"Worker {self.worker_id} registered with coordinator")
    
    def _get_gpu_info(self) -> dict:
        """Ottiene informazioni sulla GPU"""
        if torch.cuda.is_available():
            gpu_props = torch.cuda.get_device_properties(self.gpu_id)
            return {
                'gpu_name': gpu_props.name,
                'total_memory': gpu_props.total_memory,
                # multi_processor_count is the number of SMs, not individual CUDA cores
                'cuda_cores': gpu_props.multi_processor_count,
                'compute_capability': f"{gpu_props.major}.{gpu_props.minor}"
            }
        return {'gpu_name': 'CPU', 'total_memory': 0, 'cuda_cores': 0, 'compute_capability': 'N/A'}
    
    def _calculate_performance_score(self, gpu_info: dict) -> float:
        """Calcola score prestazionale per load balancing"""
        if gpu_info['gpu_name'] == 'CPU':
            return 0.5
        
        # Score based on GPU memory and SM count, normalized against a 16 GB / 100-SM reference
        memory_gb = gpu_info['total_memory'] / (1024**3)
        cuda_cores = gpu_info['cuda_cores']
        
        score = (memory_gb / 16) * 0.6 + (cuda_cores / 100) * 0.4
        return max(0.1, min(1.0, score))
    
    def _heartbeat_loop(self):
        """Invia heartbeat periodico al coordinator"""
        while self.connected:
            try:
                heartbeat_message = {
                    'type': 'heartbeat',
                    'worker_id': self.worker_id,
                    'timestamp': time.time()
                }
                self.socket.sendall(json.dumps(heartbeat_message).encode('utf-8'))
                time.sleep(10)  # Heartbeat every 10 seconds
            except Exception as e:
                logger.error(f"Error sending heartbeat: {e}")
                self.connected = False
                break
    
    def _receive_loop(self):
        """Loop principale per ricevere compiti"""
        self.running = True
        while self.running and self.connected:
            try:
                # Assumes each coordinator message fits in a single recv() call
                data = self.socket.recv(4096).decode('utf-8')
                if not data:
                    break
                    
                message = json.loads(data)
                self._handle_message(message)
                    
            except Exception as e:
                logger.error(f"Error receiving data: {e}")
                self.connected = False
                break
    
    def _handle_message(self, message: dict):
        """Gestisce messaggi dal coordinator"""
        message_type = message.get('type')
        
        if message_type == 'inference_task':
            self._handle_inference_task(message)
        elif message_type == 'registration_confirmed':
            logger.info("Registration confirmed by coordinator")
    
    def _handle_inference_task(self, message: dict):
        """Gestisce task di inferenza"""
        request_id = message['request_id']
        prompt_chunk = message['prompt_chunk']
        chunk_id = message['chunk_id']
        
        try:
            # Run inference on the assigned prompt chunk
            result = self.llm_manager.generate(prompt_chunk)
            
            # Send the result back to the coordinator
            response_message = {
                'type': 'worker_response',
                'request_id': request_id,
                'worker_id': self.worker_id,
                'chunk_id': chunk_id,
                'result': result
            }
            
            self.socket.sendall(json.dumps(response_message).encode('utf-8'))
            logger.info(f"Worker {self.worker_id} completed task for request {request_id}")
            
        except Exception as e:
            logger.error(f"Error during inference: {e}")
            # Report the error to the coordinator
            error_message = {
                'type': 'worker_error',
                'request_id': request_id,
                'worker_id': self.worker_id,
                'error': str(e)
            }
            self.socket.sendall(json.dumps(error_message).encode('utf-8'))
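

# Minimal usage sketch: how a single worker might be started. The coordinator
# address, port, and model path below are placeholder assumptions, not values
# defined elsewhere in this file; since the module uses a relative import, run
# it as part of its package, e.g. `python -m <package>.gpu_worker`.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    worker = GPUWorker(
        coordinator_host="127.0.0.1",   # assumed coordinator address
        coordinator_port=5000,          # assumed coordinator port
        model_path="/path/to/model",    # placeholder path passed to LLMManager
        gpu_id=0,
    )
    try:
        # connect() registers the worker and then blocks in _receive_loop()
        # until the coordinator closes the connection or an error occurs.
        worker.connect()
    except KeyboardInterrupt:
        pass
    finally:
        worker.disconnect()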