import socket
import threading
import json
import time
import logging
import uuid
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from enum import Enum
from concurrent.futures import ThreadPoolExecutor

# Module-level logger shared by the classes in this module.
logger = logging.getLogger(__name__)

class NodeType(Enum):
    """Performance tier assigned to a GPU worker node.

    The string value is what gets logged at worker registration
    (``node_type.value``).
    """
    HIGH_PERFORMANCE = "high_perf"
    MEDIUM_PERFORMANCE = "medium"
    LOW_PERFORMANCE = "low_perf"
    EDGE_NODE = "edge"

class PipelineStrategy(Enum):
    """Names of the parallelism strategies a pipeline can use.

    NOTE(review): the message handlers in this module compare
    ``pipeline['strategy']`` against literal strings such as
    ``'hybrid_parallel'`` and ``'single'``, not against these enum values --
    confirm which spelling is authoritative.
    """
    SINGLE_GPU = "single"
    HYBRID_PARALLEL = "hybrid"
    PIPELINE_PARALLEL = "pipeline"
    TENSOR_PARALLEL = "tensor"
    MASSIVE_SCALE = "massive"

@dataclass
class GPUResource:
    """Description and live state of one worker's GPU."""
    # Id of the worker that owns this GPU; used as the registry key.
    worker_id: str
    # GPU name as reported by the worker.
    gpu_name: str
    # GPU memory in gigabytes.
    memory_gb: float
    # Relative capability score; higher scores are preferred by the load balancer.
    performance_score: float
    # Performance tier derived from the score and memory.
    node_type: NodeType
    # Load fraction (presumably 0.0-1.0); GPUs at >= 0.8 are not selected for allocation.
    current_load: float
    # Model layer indices assigned to this GPU (empty at registration).
    assigned_layers: List[int]
    # time.time() timestamp of the last heartbeat / registration.
    last_heartbeat: float

@dataclass
class ModelShard:
    """One GPU's slice of the model produced by DynamicModelSharding."""
    # Shard identifier, e.g. "shard_0".
    shard_id: str
    # Logical GPU id ("gpu_<index>"), not a worker id.
    gpu_id: str
    # Indices of the transformer layers held by this shard.
    layers: List[int]
    # Indices of the attention heads held by this shard.
    attention_heads: List[int]
    # Approximate parameter count for this shard.
    parameters_count: int
    # Estimated memory need (presumably GB -- confirm).
    memory_required: float

@dataclass
class PipelineStage:
    """One stage of an inference pipeline built by PipelineManager."""
    # Unique stage identifier within a pipeline (e.g. "stage_0").
    stage_id: str
    # Free-form label describing the work, e.g. "complete_generation" or "layers_0_8".
    stage_type: str
    # Logical GPU ids ("gpu_<index>") that execute this stage.
    assigned_gpus: List[str]
    # Payload for the stage; PipelineManager always leaves it as None.
    input_data: Any
    # True for the stage whose output is the pipeline's result.
    is_final_stage: bool
    # stage_ids this stage depends on (the previous stage in the layouts built here).
    dependencies: List[str]

class HierarchicalLoadBalancer:
    """Track registered GPUs and choose which ones should serve a request."""

    # GPUs whose current_load is at or above this value are considered
    # saturated and are never selected.
    MAX_SELECTABLE_LOAD = 0.8

    def __init__(self):
        self.gpu_resources: Dict[str, GPUResource] = {}
        # Not read or written by this class; kept for external users.
        self.performance_metrics = {}
        self.load_distribution = {}

    def add_gpu_resource(self, gpu_resource: GPUResource):
        """Register a GPU, replacing any previous entry with the same worker_id."""
        self.gpu_resources[gpu_resource.worker_id] = gpu_resource

    def get_optimal_gpu_allocation(self, total_gpus_needed: int, complexity: str) -> List[GPUResource]:
        """Select GPUs based on performance score and current load.

        Args:
            total_gpus_needed: maximum number of GPUs to return. Values <= 0
                yield an empty list.
            complexity: ``"medium"`` asks for a balanced mix of performance
                tiers; every other value (including ``"high"``) takes the
                best-scoring GPUs first.

        Returns:
            The selected GPUs, ordered best score first (ties broken by lower
            load). The list may be shorter than requested when not enough
            unsaturated GPUs exist.
        """
        if total_gpus_needed <= 0:
            # A negative slice bound would otherwise mean "all but the last N".
            return []

        available_gpus = [gpu for gpu in self.gpu_resources.values()
                          if gpu.current_load < self.MAX_SELECTABLE_LOAD]
        if not available_gpus:
            return []

        sorted_gpus = sorted(available_gpus,
                             key=lambda gpu: (-gpu.performance_score, gpu.current_load))

        if complexity == "medium":
            return self._get_balanced_mix(sorted_gpus, total_gpus_needed)
        return sorted_gpus[:total_gpus_needed]

    def _get_balanced_mix(self, gpus: List[GPUResource], count: int) -> List[GPUResource]:
        """Return up to ``count`` GPUs mixing high- and medium-performance nodes.

        Up to half of the slots go to HIGH_PERFORMANCE nodes and the rest to
        MEDIUM_PERFORMANCE nodes. If those tiers cannot fill ``count`` slots,
        the remaining slots are filled from the other GPUs in their input
        order, so the result never has fewer GPUs than are available.
        """
        if len(gpus) <= count:
            return gpus

        high_perf = [g for g in gpus if g.node_type == NodeType.HIGH_PERFORMANCE][:count // 2]
        medium_perf = [g for g in gpus
                       if g.node_type == NodeType.MEDIUM_PERFORMANCE][:count - len(high_perf)]
        selection = high_perf + medium_perf

        if len(selection) < count:
            # GPUResource is an (unhashable) dataclass, so track picks by identity.
            chosen = {id(g) for g in selection}
            fillers = [g for g in gpus if id(g) not in chosen]
            selection.extend(fillers[:count - len(selection)])
        return selection

class DynamicModelSharding:
    """Split a model's layers and attention heads across a number of GPUs.

    Results are memoised per (GPU count, layer count) in ``sharding_cache``.
    """

    # Model-wide totals divided evenly across shards.
    TOTAL_PARAMETERS = 7000000000
    TOTAL_MEMORY_GB = 4.0
    TOTAL_ATTENTION_HEADS = 32

    def __init__(self):
        # Cache key "<total_gpus>_<model_layers>" -> list of ModelShard.
        # Only the sharding cache lives here (no central-server integration).
        self.sharding_cache = {}

    def calculate_optimal_sharding(self, total_gpus: int, model_layers: int = 32) -> List[ModelShard]:
        """Compute a contiguous layer/head split of the model over ``total_gpus`` GPUs.

        Each GPU gets ``model_layers // total_gpus`` consecutive layers and the
        last GPU also takes the remainder. When there are more GPUs than
        layers, the surplus GPUs get an empty layer list instead of indices
        past the end of the model.

        Args:
            total_gpus: number of GPUs to shard over. Values <= 0 return [].
            model_layers: number of transformer layers in the model.

        Returns:
            A fresh list of independent ModelShard objects. Callers may mutate
            it without corrupting the cache.
        """
        import copy

        if total_gpus <= 0:
            return []

        cache_key = f"{total_gpus}_{model_layers}"
        if cache_key in self.sharding_cache:
            return copy.deepcopy(self.sharding_cache[cache_key])

        shards = []
        layers_per_gpu = max(model_layers // total_gpus, 1)
        last_idx = total_gpus - 1

        for gpu_idx in range(total_gpus):
            # Clamp so surplus GPUs get an empty range rather than out-of-range layers.
            start_layer = min(gpu_idx * layers_per_gpu, model_layers)
            if gpu_idx == last_idx:
                end_layer = model_layers
            else:
                end_layer = min(start_layer + layers_per_gpu, model_layers)

            shards.append(ModelShard(
                shard_id=f"shard_{gpu_idx}",
                gpu_id=f"gpu_{gpu_idx}",
                layers=list(range(start_layer, end_layer)),
                attention_heads=self._assign_attention_heads(gpu_idx, total_gpus),
                parameters_count=self.TOTAL_PARAMETERS // total_gpus,
                memory_required=self.TOTAL_MEMORY_GB / total_gpus
            ))

        self.sharding_cache[cache_key] = shards
        return copy.deepcopy(shards)

    def _assign_attention_heads(self, gpu_idx: int, total_gpus: int, total_heads: int = 32) -> List[int]:
        """Return the contiguous block of attention-head indices for ``gpu_idx``.

        The last GPU takes the remainder. GPUs beyond the head count get an
        empty list instead of indices past ``total_heads``.
        """
        heads_per_gpu = max(total_heads // total_gpus, 1)
        start_head = min(gpu_idx * heads_per_gpu, total_heads)
        if gpu_idx == total_gpus - 1:
            end_head = total_heads
        else:
            end_head = min(start_head + heads_per_gpu, total_heads)
        return list(range(start_head, end_head))

class PipelineManager:
    """Build a list of PipelineStage objects sized to the available GPU count.

    GPUs are referenced by logical ids ``gpu_0`` .. ``gpu_<n-1>``.
    """

    # Layer count used for the layer-range labels of the small pipeline.
    MODEL_LAYERS = 32

    def __init__(self):
        self.pipeline_templates = {}
        self.performance_history = {}

    def create_optimal_pipeline(self, available_gpus: int, prompt_complexity: str) -> List[PipelineStage]:
        """Create the pipeline layout for ``available_gpus`` GPUs.

        Layout by GPU count:
            1        -> one full-inference stage
            2        -> input-processing stage followed by output generation
            3-8      -> one GPU per sequential layer-range stage
            9-32     -> up to 8 sequential stages of GPU groups
            33-100   -> up to 16 sequential stages of GPU groups
            >100     -> massive layout (see _create_massive_pipeline)

        ``prompt_complexity`` is accepted but not used. Returns an empty list
        when ``available_gpus`` < 1.
        """
        if available_gpus < 1:
            # The massive layout would divide by zero for 0 GPUs.
            return []
        if available_gpus == 1:
            return self._create_single_gpu_pipeline()
        elif available_gpus == 2:
            return self._create_hybrid_pipeline()
        elif 3 <= available_gpus <= 8:
            return self._create_small_pipeline(available_gpus)
        elif 9 <= available_gpus <= 32:
            return self._create_medium_pipeline(available_gpus)
        elif 33 <= available_gpus <= 100:
            return self._create_large_pipeline(available_gpus)
        else:
            return self._create_massive_pipeline(available_gpus)

    def _create_single_gpu_pipeline(self) -> List[PipelineStage]:
        """One stage that runs the whole generation on gpu_0."""
        return [PipelineStage(
            stage_id="full_inference",
            stage_type="complete_generation",
            assigned_gpus=["gpu_0"],
            input_data=None,
            is_final_stage=True,
            dependencies=[]
        )]

    def _create_hybrid_pipeline(self) -> List[PipelineStage]:
        """Two stages: input processing on gpu_0, then output generation on gpu_1."""
        return [
            PipelineStage(
                stage_id="input_processing",
                stage_type="input_processor",
                assigned_gpus=["gpu_0"],
                input_data=None,
                is_final_stage=False,
                dependencies=[]
            ),
            PipelineStage(
                stage_id="output_generation",
                stage_type="output_generator",
                assigned_gpus=["gpu_1"],
                input_data=None,
                is_final_stage=True,
                dependencies=["input_processing"]
            )
        ]

    def _create_small_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """One GPU per sequential stage, each labelled with its layer range.

        The last stage's label extends to MODEL_LAYERS so remainder layers
        (e.g. 30-31 with 3 GPUs) are not left out.
        """
        total_layers = self.MODEL_LAYERS
        layers_per_stage = max(total_layers // gpu_count, 1)
        last_idx = gpu_count - 1
        stages = []

        for i in range(gpu_count):
            start = i * layers_per_stage
            end = total_layers if i == last_idx else min((i + 1) * layers_per_stage, total_layers)
            stages.append(PipelineStage(
                stage_id=f"stage_{i}",
                stage_type=f"layers_{start}_{end}",
                assigned_gpus=[f"gpu_{i}"],
                input_data=None,
                is_final_stage=(i == last_idx),
                dependencies=[f"stage_{i-1}"] if i > 0 else []
            ))

        return stages

    def _create_medium_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """9-32 GPUs: up to 8 sequential stages, GPUs spread across them."""
        return self._create_grouped_pipeline(gpu_count, min(gpu_count, 8))

    def _create_large_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """33-100 GPUs: up to 16 sequential stages of roughly 4+ GPUs each."""
        return self._create_grouped_pipeline(gpu_count, min(gpu_count // 4, 16))

    def _create_grouped_pipeline(self, gpu_count: int, pipeline_depth: int) -> List[PipelineStage]:
        """Split ``gpu_count`` GPUs into ``pipeline_depth`` sequential GPU groups.

        Group sizes differ by at most one, so every GPU is assigned.
        """
        pipeline_depth = max(1, min(pipeline_depth, gpu_count))
        base_size, extra = divmod(gpu_count, pipeline_depth)

        logger.info(f"Creating grouped pipeline: gpus={gpu_count}, depth={pipeline_depth}")

        stages = []
        next_gpu = 0
        for stage_idx in range(pipeline_depth):
            # The first `extra` stages absorb one leftover GPU each.
            group_size = base_size + (1 if stage_idx < extra else 0)
            gpu_group = [f"gpu_{next_gpu + i}" for i in range(group_size)]
            next_gpu += group_size

            stages.append(PipelineStage(
                stage_id=f"pipeline_stage_{stage_idx}",
                stage_type=f"tensor_parallel_block_{stage_idx}",
                assigned_gpus=gpu_group,
                input_data=None,
                is_final_stage=(stage_idx == pipeline_depth - 1),
                dependencies=[f"pipeline_stage_{stage_idx-1}"] if stage_idx > 0 else []
            ))

        return stages

    def _create_massive_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """Pipeline for 100+ GPUs using equal-size tensor-parallel groups.

        Uses ``min(gpu_count // 8, 16)`` stages of ``gpu_count // depth`` GPUs
        each. The ``gpu_count % depth`` leftover GPUs are not assigned.
        """
        stages = []

        pipeline_depth = min(gpu_count // 8, 16)
        tensor_parallelism = gpu_count // pipeline_depth

        logger.info(f"Creating massive pipeline: depth={pipeline_depth}, tensor_parallelism={tensor_parallelism}")

        for stage_idx in range(pipeline_depth):
            gpu_group = [f"gpu_{stage_idx * tensor_parallelism + i}"
                         for i in range(tensor_parallelism)]

            stages.append(PipelineStage(
                stage_id=f"pipeline_stage_{stage_idx}",
                stage_type=f"tensor_parallel_block_{stage_idx}",
                assigned_gpus=gpu_group,
                input_data=None,
                is_final_stage=(stage_idx == pipeline_depth - 1),
                dependencies=[f"pipeline_stage_{stage_idx-1}"] if stage_idx > 0 else []
            ))

        return stages

class SessionManager:
    """Keep per-session conversation history and build prompt context from it."""

    def __init__(self):
        # session_id -> {'created', 'history', 'last_activity'}
        self.sessions: Dict[str, Dict] = {}
        # Seconds; not consulted by the methods of this class.
        self.session_timeout = 1800

    def get_session_context(self, session_id: str, prompt: str) -> str:
        """Return the context string for ``prompt`` within ``session_id``.

        Creates the session on first use, refreshes its ``last_activity``
        timestamp, and trims the stored history to the 10 newest entries.
        """
        if session_id not in self.sessions:
            created_at = time.time()
            self.sessions[session_id] = {
                'created': created_at,
                'history': [],
                'last_activity': created_at,
            }

        session = self.sessions[session_id]
        session['last_activity'] = time.time()

        history = session['history']
        if len(history) > 10:
            session['history'] = history[-10:]

        return self._build_context(session, prompt)

    def _build_context(self, session: Dict, prompt: str) -> str:
        """Format the last three (question, answer) pairs followed by the new prompt."""
        recent_turns = session['history'][-3:]
        transcript = "".join(
            f"Q{n}: {question}\nA{n}: {answer}\n"
            for n, (question, answer) in enumerate(recent_turns, start=1)
        )
        return f"Contesto conversazione:\n{transcript}\nNuova domanda: {prompt}\nRisposta:"

class PerformanceTracker:
    """Rolling history of the most recent 1000 timestamped metric snapshots."""

    def __init__(self):
        self.metrics_history = []
        # Reference thresholds; not read by this class itself.
        self.performance_thresholds = {'high_load': 0.8, 'low_performance': 0.5}

    def track_metrics(self, metrics: Dict):
        """Append ``metrics`` with the current timestamp, keeping only the newest 1000."""
        entry = {'timestamp': time.time(), 'metrics': metrics}
        self.metrics_history.append(entry)

        overflow = len(self.metrics_history) - 1000
        if overflow > 0:
            self.metrics_history = self.metrics_history[overflow:]

class AutoScaler:
    """Compare the coordinator's average load with thresholds and trigger scaling hooks."""

    def __init__(self, coordinator):
        self.coordinator = coordinator
        # 'scale_up_memory' is defined but not consulted by check_and_scale.
        self.scaling_thresholds = {
            'scale_up_cpu': 0.8,
            'scale_down_cpu': 0.3,
            'scale_up_memory': 0.9
        }

    def check_and_scale(self):
        """Call _scale_up above the up threshold, _scale_down below the down threshold.

        Reads ``average_load`` from the coordinator's
        ``_get_detailed_system_stats()``; loads within the band do nothing.
        """
        average_load = self.coordinator._get_detailed_system_stats()['average_load']
        limits = self.scaling_thresholds

        if average_load > limits['scale_up_cpu']:
            self._scale_up()
        elif average_load < limits['scale_down_cpu']:
            self._scale_down()

    def _scale_up(self):
        """Scale-up hook; currently it only logs."""
        logger.info("Auto-scaling: Scaling up cluster")

    def _scale_down(self):
        """Scale-down hook; currently it only logs."""
        logger.info("Auto-scaling: Scaling down cluster")

class MessageProcessor:
    """Dispatch decoded protocol messages to a ScalableCoordinator.

    Legacy ``model_shard_response`` / ``worker_response`` messages are either
    used to complete the matching active pipeline directly or re-packaged as
    ``pipeline_response`` messages for the coordinator.
    """

    # Substrings marking worker status output rather than user-facing text;
    # results containing any of them are skipped when picking the answer.
    _SYSTEM_MARKERS = ('INPUT_PROCESSED', 'PROCESSED', 'SHARD_COMPLETE')

    def __init__(self, scalable_coordinator):
        self.scalable_coordinator = scalable_coordinator

    def process_message(self, client_socket: socket.socket, address: tuple, message: dict):
        """Route ``message`` to a handler chosen by its ``'type'`` field.

        Coordinator handlers are looked up only when their message type
        arrives. Unknown types are logged and ignored.

        NOTE(review): ``_handle_worker_error`` is not among the coordinator
        methods visible in this file chunk (its own dispatch table maps
        'worker_error' to ``_handle_error_message``) -- confirm it exists.
        """
        message_type = message.get('type')

        if message_type == 'worker_register':
            self.scalable_coordinator._handle_worker_registration(client_socket, address, message)
        elif message_type == 'inference_request':
            self.scalable_coordinator._handle_inference_request(client_socket, address, message)
        elif message_type == 'pipeline_response':
            self.scalable_coordinator._handle_pipeline_response(client_socket, address, message)
        elif message_type == 'model_shard_response':
            self._handle_model_shard_response_compat(message)
        elif message_type == 'worker_response':
            self._handle_worker_response_compat(message)
        elif message_type == 'worker_error':
            self.scalable_coordinator._handle_worker_error(client_socket, address, message)
        elif message_type == 'heartbeat':
            self.scalable_coordinator._handle_heartbeat(client_socket, address, message)
        elif message_type == 'get_system_stats':
            self.scalable_coordinator._handle_system_stats_request(client_socket, address, message)
        else:
            logger.warning(f"Unknown message type: {message_type}")

    def _handle_model_shard_response_compat(self, message: dict):
        """Handle a legacy ``model_shard_response`` message.

        For an active ``hybrid_parallel`` pipeline the shard result is stored.
        Once ``total_expected`` responses have arrived, the combined answer is
        sent to the client and the pipeline is removed. Any other message is
        converted to a ``pipeline_response`` and handed to the coordinator.
        Errors are logged, never raised.
        """
        try:
            request_id = message.get('request_id', 'unknown')
            worker_id = message.get('worker_id', 'unknown')
            shard_id = message.get('shard_id', 'unknown')
            is_final_shard = message.get('is_final_shard', False)

            logger.info(f"Model shard response: request={request_id}, worker={worker_id}, shard={shard_id}, final={is_final_shard}")

            active_pipelines = self.scalable_coordinator.active_pipelines
            pipeline = active_pipelines.get(request_id)

            if pipeline is not None:
                logger.info(f"Pipeline found: strategy={pipeline.get('strategy')}, received={pipeline.get('received_responses', 0)}/{pipeline.get('total_expected', 0)}")

                if pipeline.get('strategy') == 'hybrid_parallel':
                    pipeline['received_responses'] = pipeline.get('received_responses', 0) + 1
                    pipeline.setdefault('results', {})[worker_id] = message.get('result', '')

                    logger.info(f"Updated pipeline: {pipeline['received_responses']}/{pipeline['total_expected']} responses")

                    if pipeline['received_responses'] >= pipeline['total_expected']:
                        # Claim the pipeline before replying: a concurrent handler
                        # for the same request then finds nothing to pop and does
                        # not send a second response, and a failed send does not
                        # leave the pipeline registered forever.
                        if active_pipelines.pop(request_id, None) is not None:
                            logger.info(f"All responses received for {request_id}")
                            final_result = self._combine_hybrid_results(pipeline)

                            response = {
                                'type': 'inference_response',
                                'request_id': request_id,
                                'result': final_result,
                                'worker_id': 'hybrid_system',
                                'strategy': 'hybrid_parallel',
                                'workers_used': pipeline['workers_used'],
                                'session_id': pipeline.get('session_id', 'default'),
                                'timestamp': time.time()
                            }

                            self.scalable_coordinator._send_to_client(pipeline['client_socket'], response)
                            logger.info(f"Hybrid inference completed for {request_id}")
                    return

            pipeline_response = {
                'type': 'pipeline_response',
                'pipeline_id': request_id,
                'stage_id': f"shard_{shard_id}",
                'worker_id': worker_id,
                'result': message.get('result', ''),
                'is_final_stage': is_final_shard,
                'timestamp': time.time()
            }

            self.scalable_coordinator._handle_pipeline_response(None, None, pipeline_response)

        except Exception as e:
            logger.error(f"Model shard response compatibility error: {e}")

    def _combine_hybrid_results(self, pipeline):
        """Pick the best single answer from a hybrid pipeline's per-worker results.

        Preference order:
        1. the longest string result (> 20 chars stripped) not containing a
           system marker;
        2. the most recently inserted string result > 20 chars stripped;
        3. all results > 5 chars joined with spaces;
        4. a fixed completion message.
        Returns a fallback message if anything goes wrong.
        """
        try:
            results = list(pipeline['results'].values())

            if not results:
                return "Elaborazione completata ma nessun risultato ricevuto"

            logger.info(f"Combining {len(results)} results: {list(pipeline['results'].keys())}")

            best_result = ""
            max_length = 0

            for worker_id, result in pipeline['results'].items():
                if result and isinstance(result, str):
                    if any(marker in result for marker in self._SYSTEM_MARKERS):
                        logger.info(f"Skipping system message from {worker_id}")
                        continue

                    result_length = len(result.strip())
                    if result_length > max_length and result_length > 20:
                        best_result = result
                        max_length = result_length
                        logger.info(f"Better result found from {worker_id}: {result_length} chars")

            if best_result:
                logger.info(f"Using result with {max_length} characters")
                return best_result

            for result in reversed(results):
                # Non-string results would otherwise raise on .strip().
                if result and isinstance(result, str) and len(result.strip()) > 20:
                    logger.info(f"Using fallback result: {len(result.strip())} chars")
                    return result

            final_fallback = " ".join([str(r) for r in results if r and len(str(r).strip()) > 5])
            if final_fallback:
                logger.info(f"Using concatenated result: {len(final_fallback)} chars")
                return final_fallback

            return "Elaborazione completata con successo - Richiedi più dettagli se necessario"

        except Exception as e:
            logger.error(f"Results combination error: {e}")
            return "Risposta generata dal sistema - Si prega di riformulare la domanda per maggiori dettagli"

    def _handle_worker_response_compat(self, message: dict):
        """Handle a legacy ``worker_response`` message.

        For an active ``single`` pipeline the result is sent straight to the
        client and the pipeline is removed. Otherwise the message is wrapped as
        a final-stage ``pipeline_response`` for pipeline id
        ``single_<request_id>``. Errors are logged, never raised.
        """
        try:
            request_id = message.get('request_id', 'unknown')
            active_pipelines = self.scalable_coordinator.active_pipelines
            pipeline = active_pipelines.get(request_id)

            if pipeline is not None and pipeline.get('strategy') == 'single':
                # Claim first so duplicate/concurrent responses are not sent twice.
                if active_pipelines.pop(request_id, None) is None:
                    return

                response = {
                    'type': 'inference_response',
                    'request_id': request_id,
                    'result': message.get('result', ''),
                    'worker_id': message.get('worker_id', 'unknown'),
                    'strategy': 'single',
                    'timestamp': time.time()
                }

                self.scalable_coordinator._send_to_client(pipeline['client_socket'], response)
                logger.info(f"Single GPU response sent for {request_id}")
                return

            pipeline_response = {
                'type': 'pipeline_response',
                'pipeline_id': f"single_{request_id}",
                'stage_id': 'single_stage',
                'worker_id': message.get('worker_id', 'unknown'),
                'result': message.get('result', ''),
                'is_final_stage': True,
                'timestamp': time.time()
            }

            self.scalable_coordinator._handle_pipeline_response(None, None, pipeline_response)

        except Exception as e:
            logger.error(f"Worker response compatibility error: {e}")
class ScalableCoordinator:
    """
    Coordinator completamente riscritto per scalabilità massiva (2-1000+ GPU)
    Architettura gerarchica senza colli di bottiglia
    """
    
    def __init__(self, host: str = "0.0.0.0", port: int = 8765, blockchain=None):
        """Create the coordinator and start central-server registration.

        Args:
            host: interface to bind when ``start()`` is called.
            port: TCP port to listen on.
            blockchain: optional object stored as ``self.blockchain``.

        Side effects: constructs CentralServerIntegration, starts its
        registration loop, and schedules one extra registration attempt
        (``force_immediate_registration``) 2 seconds later on a daemon timer.
        """
        self.host = host
        self.port = port
        self.blockchain = blockchain

        self.load_balancer = HierarchicalLoadBalancer()
        self.model_sharding = DynamicModelSharding()
        self.pipeline_manager = PipelineManager()

        self.worker_registry: Dict[str, Dict] = {}
        self.gpu_resources: Dict[str, GPUResource] = {}
        self.active_pipelines: Dict[str, Any] = {}
        self.session_manager = SessionManager()

        self.socket = None
        self.running = False
        self.thread_pool = ThreadPoolExecutor(max_workers=100)

        self.performance_tracker = PerformanceTracker()
        self.auto_scaler = AutoScaler(self)

        self.coordinator_id = f"coord_{uuid.uuid4().hex[:8]}"

        # === Central server integration (with debug output) ===
        print(f"🔧 DEBUG: Initializing CentralServerIntegration for {self.coordinator_id}")
        self.central_integration = CentralServerIntegration(self)
        print("🔧 DEBUG: Starting registration loop...")
        self.central_integration.start_registration_loop()
        print("✅ DEBUG: CentralServerIntegration started successfully")

        # === Forced immediate registration ===
        print("🚀 DEBUG: Forcing immediate registration...")
        # Daemon timer: a pending registration must not keep the interpreter
        # alive (or fire a network call) after the program has decided to exit.
        self._registration_timer = threading.Timer(2.0, self.force_immediate_registration)
        self._registration_timer.daemon = True
        self._registration_timer.start()

        logger.info(f"Scalable Coordinator {self.coordinator_id} initialized")

    def force_immediate_registration(self):
        """Make one central-server registration attempt and print the outcome.

        Failures are printed together with their traceback and never raised.
        """
        print("🚀 DEBUG: Executing forced immediate registration...")
        try:
            self.central_integration.register_with_central_server()
        except Exception as exc:
            import traceback
            print(f"❌ DEBUG: Forced registration failed: {exc}")
            print(f"🔍 DEBUG: Traceback: {traceback.format_exc()}")
        else:
            print("✅ DEBUG: Forced registration completed successfully")

    def start(self):
        """Bind the listening socket, start background services, then serve.

        Blocks in ``_accept_connections_loop`` while ``self.running`` is true.

        Raises:
            Exception: any startup failure is logged and re-raised after the
                listening socket is closed and ``running`` is cleared, so the
                port is not leaked and background loops can stop.
        """
        server_socket = None
        try:
            server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket = server_socket
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            server_socket.bind((self.host, self.port))
            server_socket.listen(1000)
            # Short timeout so the accept loop can re-check self.running.
            server_socket.settimeout(1.0)
            self.running = True

            self._start_background_services()

            logger.info(f"Scalable Coordinator listening on {self.host}:{self.port}")
            logger.info("Ready for 1000+ GPU cluster")

            self._accept_connections_loop()

        except Exception as e:
            logger.error(f"Coordinator startup failed: {e}")
            self.running = False
            if server_socket is not None:
                server_socket.close()
            raise

    def _start_background_services(self):
        """Start the cleanup, monitoring and auto-scaling loops as daemon threads.

        Central-server registration is not started here; it is driven by
        ``self.central_integration``, created in ``__init__``.
        """
        for loop in (self._cleanup_loop, self._monitoring_loop, self._auto_scaling_loop):
            threading.Thread(target=loop, daemon=True).start()


    def calculate_worker_rewards(self, worker_id, inference_time, model_complexity):
        """Return the AIT reward for one inference.

        reward = 0.1 (base) + inference_time / 60 * 0.01 + model_complexity * 0.05.
        ``inference_time`` is presumably in seconds (it is divided by 60) --
        confirm. ``worker_id`` is accepted but does not affect the result.
        """
        base_reward = 0.1
        per_minute_rate = 0.01
        per_complexity_rate = 0.05

        reward = base_reward + inference_time / 60 * per_minute_rate
        reward = reward + model_complexity * per_complexity_rate
        return reward

    def _accept_connections_loop(self):
        """Accept clients until ``self.running`` is cleared.

        Each connection is handed to the thread pool. Accept timeouts just
        re-check ``self.running``. Other errors are logged and followed by a
        short back-off so a persistent failure (e.g. too many open files or a
        closed socket) does not turn into a busy loop.
        """
        while self.running:
            try:
                client_socket, address = self.socket.accept()
                logger.info(f"New connection from {address}")

                self.thread_pool.submit(
                    self._handle_client_connection,
                    client_socket,
                    address
                )

            except socket.timeout:
                continue
            except Exception as e:
                if not self.running:
                    break
                logger.error(f"Connection acceptance error: {e}")
                time.sleep(0.1)

    def _handle_client_connection(self, client_socket: socket.socket, address: tuple):
        """Read newline-delimited JSON messages from one client until it disconnects.

        Data is buffered as bytes and split on b"\\n" before decoding. A
        multi-byte UTF-8 character split across two recv() calls therefore no
        longer raises and drops the connection. A line that is not valid
        UTF-8 is logged and skipped. The connection is always cleaned up via
        ``_cleanup_client_connection``.
        """
        buffer = b""
        try:
            client_socket.settimeout(1.0)

            while self.running:
                try:
                    chunk = client_socket.recv(65536)
                    if not chunk:
                        break

                    buffer += chunk

                    while b"\n" in buffer:
                        raw_line, buffer = buffer.split(b"\n", 1)
                        try:
                            line = raw_line.decode('utf-8').strip()
                        except UnicodeDecodeError as e:
                            logger.warning(f"Invalid UTF-8 from {address}: {e}")
                            continue
                        if line:
                            self._process_message_async(client_socket, address, line)

                except socket.timeout:
                    continue
                except Exception as e:
                    logger.error(f"Client {address} error: {e}")
                    break

        except Exception as e:
            logger.error(f"Error handling client {address}: {e}")
        finally:
            self._cleanup_client_connection(client_socket, address)

    def _handle_error_message(self, client_socket: socket.socket, address: tuple, message: dict):
        """Record a worker-reported error against its active pipeline.

        The error counts as one response and is stored as "[ERROR] <msg>".
        Once ``received_responses`` reaches ``total_expected`` the pipeline is
        finalized with whatever results exist. Used for both 'error' and
        'worker_error' messages. Failures are logged, never raised.
        """
        try:
            worker_id = message.get('worker_id', 'unknown')
            request_id = message.get('request_id', 'unknown')
            error_msg = message.get('error', 'Unknown error')

            logger.error(f"Error from worker {worker_id} for request {request_id}: {error_msg}")

            if request_id not in self.active_pipelines:
                return

            pipeline = self.active_pipelines[request_id]
            received = pipeline.get('received_responses', 0) + 1
            pipeline['received_responses'] = received
            pipeline['results'][worker_id] = f"[ERROR] {error_msg}"

            expected = pipeline['total_expected']
            logger.warning(f"Error in pipeline {request_id}, responses: {received}/{expected}")

            # Enough responses (errors included): finish with what we have.
            if received >= expected:
                logger.info(f"Combining results despite errors for {request_id}")
                self._finalize_pipeline_early(request_id)

        except Exception as e:
            logger.error(f"Error message handling failed: {e}")

    def _finalize_pipeline_early(self, pipeline_id: str):
        """Reply to the client with the results the pipeline has so far.

        Results that are not strings, are empty, or start with "[ERROR]" are
        ignored. If none remain, an apology message is sent instead.

        The pipeline is removed from ``active_pipelines`` *before* replying.
        A concurrent caller for the same id therefore finds nothing and cannot
        send a second response, and a failed send no longer leaves the
        pipeline registered. Errors are logged, never raised.
        """
        # Claim the pipeline for this thread (a single dict.pop call).
        pipeline = self.active_pipelines.pop(pipeline_id, None)
        if pipeline is None:
            return

        try:
            logger.info(f"Finalizing pipeline {pipeline_id} with {pipeline.get('received_responses', 0)}/{pipeline['total_expected']} responses")

            available_results = {
                worker_id: result
                for worker_id, result in pipeline['results'].items()
                if isinstance(result, str) and result and not result.startswith('[ERROR]')
            }

            if available_results:
                final_result = self._combine_available_results(available_results)
                logger.info(f"Using {len(available_results)} successful responses")
            else:
                # Every worker reported an error.
                final_result = "Mi dispiace, tutti i worker hanno riportato errori. Riprova con una domanda più semplice."

            response = {
                'type': 'inference_response',
                'request_id': pipeline.get('request_id', pipeline_id),
                'result': final_result,
                'pipeline_id': pipeline_id,
                'workers_used': list(pipeline.get('workers_used', [])),
                'successful_workers': len(available_results),
                'total_workers': pipeline['total_expected'],
                'processing_time': time.time() - pipeline['start_time'],
                'timestamp': time.time()
            }

            self._send_to_client(pipeline['client_socket'], response)
            logger.info(f"Pipeline {pipeline_id} completed with {len(available_results)}/{pipeline['total_expected']} successful workers")

        except Exception as e:
            logger.error(f"Early pipeline finalization error: {e}")

    def _combine_available_results(self, results: Dict[str, str]) -> str:
        """Pick the longest result whose stripped text exceeds 20 characters.

        Lengths are compared on stripped text for both candidates. (Previously
        the stripped candidate was compared against the unstripped best.) If
        no result qualifies, the first result is returned. Returns a fallback
        message on any error.
        """
        try:
            if not results:
                return "Nessun risultato disponibile dai worker"

            best_result = ""
            best_length = 0
            for result in results.values():
                if not result:
                    continue
                length = len(result.strip())
                if length > best_length and length > 20:
                    best_result = result
                    best_length = length

            return best_result if best_result else next(iter(results.values()))

        except Exception as e:
            logger.error(f"Results combination error: {e}")
            return "Risposta generata dal sistema"
        
    def _handle_system_stats_api(self, client_socket: socket.socket, address: tuple, message: dict):
        """Reply to a 'get_system_stats_api' request with the detailed system stats."""
        self._send_to_client(client_socket, {
            'type': 'api_response',
            'endpoint': 'system_stats',
            'stats': self._get_detailed_system_stats(),
            'timestamp': time.time(),
        })

    def _process_message_async(self, client_socket: socket.socket, address: tuple, message_str: str):
        """Decode one JSON message line and call the handler for its 'type'.

        Despite the name, the handler runs synchronously on the calling
        thread. Invalid JSON and unknown types are logged. Handler exceptions
        are logged, never raised.
        """
        try:
            message = json.loads(message_str)
            message_type = message.get('type')

            dispatch = {
                'get_system_stats_api': self._handle_system_stats_api,
                'worker_register': self._handle_worker_registration,
                'inference_request': self._handle_inference_request,
                'pipeline_response': self._handle_pipeline_response,
                'model_shard_response': self._handle_model_shard_response,
                'worker_response': self._handle_worker_response,
                'heartbeat': self._handle_heartbeat,
                'get_system_stats': self._handle_system_stats_request,
                # Both error message types share one handler.
                'error': self._handle_error_message,
                'worker_error': self._handle_error_message,
            }

            handler = dispatch.get(message_type)
            if handler is None:
                logger.warning(f"Unknown message type: {message_type}")
                return
            handler(client_socket, address, message)

        except json.JSONDecodeError as e:
            logger.warning(f"Invalid JSON from {address}: {e}")
        except Exception as e:
            logger.error(f"Message processing error: {e}")

    def _handle_worker_registration(self, client_socket: socket.socket, address: tuple, message: dict):
        """Register a worker, record its GPU capabilities and confirm to the worker.

        A random ``worker_<hex>`` id is used when the message carries none.
        The worker is added to ``worker_registry``, the load balancer and
        ``gpu_resources``. Errors are logged, never raised.
        """
        try:
            worker_id = message.get('worker_id', f'worker_{uuid.uuid4().hex[:8]}')
            gpu_info = message.get('gpu_info', {})

            resource = self._analyze_gpu_capabilities(worker_id, gpu_info)

            registry_entry = {
                'socket': client_socket,
                'address': address,
                'gpu_info': gpu_info,
                'last_heartbeat': time.time(),
                'resource': resource,
            }
            self.worker_registry[worker_id] = registry_entry
            self.load_balancer.add_gpu_resource(resource)
            self.gpu_resources[worker_id] = resource

            logger.info(f"Worker registered: {worker_id}")
            logger.info(f"GPU: {resource.gpu_name} | Score: {resource.performance_score:.2f}")
            logger.info(f"Memory: {resource.memory_gb}GB | Type: {resource.node_type.value}")

            self._send_to_client(client_socket, {
                'type': 'registration_confirmed',
                'worker_id': worker_id,
                'coordinator_id': self.coordinator_id,
                'timestamp': time.time(),
            })

        except Exception as e:
            logger.error(f"Worker registration error: {e}")

    def _analyze_gpu_capabilities(self, worker_id: str, gpu_info: dict) -> GPUResource:
        """Build a GPUResource describing the GPU a worker reported.

        A lower-cased copy of the reported ``gpu_name`` is used for scoring.
        The stored ``gpu_name`` keeps the original spelling; it defaults to
        'Unknown' when missing. Memory is parsed by ``_extract_gpu_memory``.
        The node tier comes from ``_determine_node_type``. The resource starts
        with zero load, no assigned layers and a fresh heartbeat timestamp.
        """
        reported_name = gpu_info.get('gpu_name', 'Unknown')
        scoring_name = reported_name.lower()
        memory = self._extract_gpu_memory(gpu_info)

        score = self._calculate_performance_score(scoring_name, memory)
        tier = self._determine_node_type(score, memory)

        return GPUResource(
            worker_id=worker_id,
            gpu_name=reported_name,
            memory_gb=memory,
            performance_score=score,
            node_type=tier,
            current_load=0.0,
            assigned_layers=[],
            last_heartbeat=time.time(),
        )

    def _calculate_performance_score(self, gpu_name: str, memory_gb: float) -> float:
        """Calcola score prestazioni avanzato"""
        score = 1.0
        
        if memory_gb >= 16:
            score *= 2.0
        elif memory_gb >= 12:
            score *= 1.7
        elif memory_gb >= 8:
            score *= 1.3
        elif memory_gb <= 4:
            score *= 0.6
            
        performance_factors = {
            'rtx 4090': 2.8, 'rtx 4080': 2.4, 'rtx 4070': 2.0,
            'rtx 3090': 2.5, 'rtx 3080': 2.2, 'rtx 3070': 1.8,
            'rtx 3060': 1.5, 'rtx 3050': 1.0,
            'a100': 3.5, 'h100': 4.0, 'v100': 3.0
        }
        
        for model, factor in performance_factors.items():
            if model in gpu_name:
                score *= factor
                break
                
        if 'rtx 40' in gpu_name:
            score *= 1.4
        elif 'rtx 30' in gpu_name:
            score *= 1.2
            
        return max(0.5, min(4.0, score))

    def _determine_node_type(self, performance_score: float, memory_gb: float) -> NodeType:
        """Classify a GPU into a NodeType tier from its score and memory.

        Tiers are tested from strongest to weakest:
        - HIGH_PERFORMANCE: score >= 2.5 and at least 12 GB;
        - MEDIUM_PERFORMANCE: score >= 1.5 and at least 8 GB;
        - EDGE_NODE: score below 1.0 or less than 4 GB;
        - LOW_PERFORMANCE: everything else.
        """
        if performance_score >= 2.5 and memory_gb >= 12:
            return NodeType.HIGH_PERFORMANCE
        if performance_score >= 1.5 and memory_gb >= 8:
            return NodeType.MEDIUM_PERFORMANCE

        is_edge_device = performance_score < 1.0 or memory_gb < 4
        return NodeType.EDGE_NODE if is_edge_device else NodeType.LOW_PERFORMANCE

    def _extract_gpu_memory(self, gpu_info: dict) -> float:
        """Estrae memoria GPU in GB"""
        try:
            memory_str = str(gpu_info.get('total_memory', '0'))
            
            if 'GB' in memory_str:
                return float(memory_str.replace('GB', '').strip())
            elif 'MB' in memory_str:
                return float(memory_str.replace('MB', '').strip()) / 1024
            elif 'MiB' in memory_str:
                return float(memory_str.replace('MiB', '').strip()) / 1024
            elif 'GiB' in memory_str:
                return float(memory_str.replace('GiB', '').strip())
            else:
                return float(memory_str) / (1024**3)
                
        except:
            return 4.0

    def _handle_inference_request(self, client_socket: socket.socket, address: tuple, message: dict):
        """Handle an inference request coming from a client.

        Reads ``request_id``, ``prompt`` and ``session_id`` from the message.
        A random ``req_<hex>`` id, an empty prompt and the 'default' session
        are used when absent. A strategy is then chosen with
        ``_select_optimal_strategy`` and the matching executor is invoked. Any
        strategy without a dedicated executor (e.g. TENSOR_PARALLEL) goes to
        ``_execute_adaptive_inference``. Exceptions are logged and reported to
        the client.
        """
        try:
            request_id = message.get('request_id', f'req_{uuid.uuid4().hex[:8]}')
            prompt = message.get('prompt', '')
            session_id = message.get('session_id', 'default')

            logger.info(f"Inference request {request_id}: {prompt[:50]}...")

            chosen = self._select_optimal_strategy(prompt)

            executors = {
                PipelineStrategy.SINGLE_GPU: self._execute_single_gpu_inference,
                PipelineStrategy.HYBRID_PARALLEL: self._execute_hybrid_inference,
                PipelineStrategy.PIPELINE_PARALLEL: self._execute_pipeline_inference,
                PipelineStrategy.MASSIVE_SCALE: self._execute_massive_scale_inference,
            }
            run = executors.get(chosen, self._execute_adaptive_inference)
            run(client_socket, request_id, prompt, session_id)

        except Exception as exc:
            logger.error(f"Inference request error: {exc}")
            self._send_error(client_socket, message.get('request_id', 'unknown'), str(exc))

    def _select_optimal_strategy(self, prompt: str) -> PipelineStrategy:
        """Seleziona la strategia ottimale basata su risorse e complessità"""
        available_gpus = len(self.gpu_resources)
        prompt_complexity = self._analyze_prompt_complexity(prompt)
        
        logger.info(f"Strategy selection: {available_gpus} GPUs, complexity: {prompt_complexity}")
        
        if available_gpus == 0:
            raise Exception("No available workers")
        elif available_gpus == 1:
            return PipelineStrategy.SINGLE_GPU
        elif available_gpus == 2:
            return PipelineStrategy.HYBRID_PARALLEL
        elif 3 <= available_gpus <= 8:
            return PipelineStrategy.PIPELINE_PARALLEL
        elif available_gpus >= 9:
            return PipelineStrategy.MASSIVE_SCALE
        else:
            return PipelineStrategy.TENSOR_PARALLEL

    def _analyze_prompt_complexity(self, prompt: str) -> str:
        """Analizza la complessità del prompt"""
        word_count = len(prompt.split())
        has_complex_questions = any(term in prompt.lower() for term in 
                                  ['explain', 'analyze', 'compare', 'describe in detail'])
        requires_reasoning = any(term in prompt.lower() for term in
                               ['why', 'how', 'what if', 'consequence'])
        
        if word_count > 200 or (has_complex_questions and requires_reasoning):
            return "very_high"
        elif word_count > 100 or has_complex_questions:
            return "high"
        elif word_count > 50:
            return "medium"
        else:
            return "low"

    def _execute_hybrid_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui inferenza ibrida OTTIMIZZATA per risposte complete"""
        try:
            available_workers = list(self.worker_registry.keys())
            
            if len(available_workers) < 2:
                logger.warning("Worker insufficienti, uso single GPU")
                self._execute_single_gpu_inference(client_socket, request_id, prompt, session_id)
                return
            
            worker1, worker2 = available_workers[:2]
            
            input_task = {
                'type': 'model_shard_task',
                'request_id': request_id,
                'shard_id': 'shard_0',
                'shard_config': {
                    'type': 'input_processor', 
                    'role': 'process_input',
                    'generation_params': {
                        'max_length': 1024,
                        'min_length': 50,
                        'temperature': 0.7
                    }
                },
                'prompt': prompt,
                'total_shards': 2,
                'current_shard': 0,
                'is_final_shard': False,
                'timestamp': time.time()
            }
            
            output_task = {
                'type': 'model_shard_task', 
                'request_id': request_id,
                'shard_id': 'shard_1',
                'shard_config': {
                    'type': 'output_generator',
                    'role': 'generate_output', 
                    'generation_params': {
                        'max_length': 2048,
                        'min_length': 200,
                        'temperature': 0.8,
                        'do_sample': True
                    }
                },
                'prompt': prompt,
                'total_shards': 2,
                'current_shard': 1,
                'is_final_shard': True,
                'timestamp': time.time()
            }
            
            self._send_to_worker(worker1, input_task)
            self._send_to_worker(worker2, output_task)
            
            self.active_pipelines[request_id] = {
                'client_socket': client_socket,
                'request_id': request_id,
                'session_id': session_id,
                'strategy': 'hybrid_parallel',
                'start_time': time.time(),
                'workers_used': [worker1, worker2],
                'received_responses': 0,
                'total_expected': 2,
                'results': {},
                'completed_stages': set(),
                'stages': [
                    {'stage_id': 'shard_0', 'is_final_stage': False},
                    {'stage_id': 'shard_1', 'is_final_stage': True}
                ]
            }
            
            logger.info(f"Inferenza ibrida avviata per {request_id}")
            
        except Exception as e:
            logger.error(f"Inferenza ibrida fallita: {e}")
            self._send_error(client_socket, request_id, f"Inferenza ibrida fallita: {e}")

    def _execute_single_gpu_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui inferenza single GPU"""
        try:
            available_gpus = list(self.gpu_resources.keys())
            if not available_gpus:
                self._send_error(client_socket, request_id, "No workers available")
                return

            gpu_id = available_gpus[0]
            
            task = {
                'type': 'inference_task',
                'request_id': request_id,
                'prompt': prompt,
                'session_id': session_id,
                'timestamp': time.time()
            }
            self._send_to_worker(gpu_id, task)

            self.active_pipelines[request_id] = {
                'client_socket': client_socket,
                'request_id': request_id, 
                'session_id': session_id,
                'strategy': 'single',
                'start_time': time.time(),
                'workers_used': [gpu_id],
                'received_responses': 0,
                'total_expected': 1,
                'results': {},
                'completed_stages': set(),
                'stages': [
                    {'stage_id': 'single_stage', 'is_final_stage': True}
                ]
            }

            logger.info(f"Single GPU inference started for {request_id}")

        except Exception as e:
            logger.error(f"Single GPU inference failed: {e}")
            self._send_error(client_socket, request_id, f"Single GPU inference failed: {e}")

    def _execute_pipeline_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui pipeline inference"""
        self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)

    def _execute_massive_scale_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui massive scale inference"""
        self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)

    def _execute_adaptive_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui adaptive inference"""
        self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)

    def _handle_pipeline_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Gestisce risposte dalla pipeline"""
        try:
            pipeline_id = message.get('pipeline_id')
            stage_id = message.get('stage_id')
            result = message.get('result')
            worker_id = message.get('worker_id')
            
            if pipeline_id not in self.active_pipelines:
                logger.warning(f"Unknown pipeline: {pipeline_id}")
                return
                
            pipeline = self.active_pipelines[pipeline_id]
            pipeline['completed_stages'].add(stage_id)
            pipeline['results'][stage_id] = result
            
            if self._is_pipeline_completed(pipeline):
                self._finalize_pipeline(pipeline_id)
                
        except Exception as e:
            logger.error(f"Pipeline response error: {e}")

    def _is_pipeline_completed(self, pipeline: dict) -> bool:
        """Verifica se tutti gli stage della pipeline sono completati"""
        completed = pipeline['completed_stages']
        total_stages = len(pipeline['stages'])
        return len(completed) == total_stages

    def _finalize_pipeline(self, pipeline_id: str):
        """Finalizza pipeline e rilascia risorse"""
        try:
            pipeline = self.active_pipelines[pipeline_id]
            
            final_result = self._combine_pipeline_results(pipeline)
            
            if 'initial_loads' in pipeline:
                for worker_id, initial_load in pipeline['initial_loads'].items():
                    if worker_id in self.gpu_resources:
                        current_load = self.gpu_resources[worker_id].current_load
                        new_load = max(0.0, current_load - 0.3)
                        self.update_gpu_load(worker_id, new_load)
                        logger.info(f"Released load for {worker_id}: {current_load:.2f} -> {new_load:.2f}")
            
            response = {
                'type': 'inference_response',
                'request_id': pipeline['request_id'],
                'result': final_result,
                'pipeline_id': pipeline_id,
                'workers_used': list(pipeline.get('gpu_allocation', {}).keys()),
                'processing_time': time.time() - pipeline['start_time'],
                'final_loads': {
                    worker_id: self.gpu_resources[worker_id].current_load 
                    for worker_id in pipeline.get('workers_used', [])
                    if worker_id in self.gpu_resources
                },
                'timestamp': time.time()
            }
            
            self._send_to_client(pipeline['client_socket'], response)
            logger.info(f"Pipeline completed with load balancing")
            
            del self.active_pipelines[pipeline_id]
            
        except Exception as e:
            logger.error(f"Pipeline finalization error: {e}")

    def _combine_pipeline_results(self, pipeline):
        """Combina risultati della pipeline"""
        try:
            results = list(pipeline['results'].values())
            if results:
                return results[-1]
            return "Nessun risultato disponibile"
        except Exception as e:
            logger.error(f"Pipeline results combination error: {e}")
            return "Errore nella combinazione dei risultati"

    def _handle_model_shard_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Gestisce model_shard_response"""
        try:
            request_id = message.get('request_id', 'unknown')
            
            if request_id in self.active_pipelines:
                pipeline = self.active_pipelines[request_id]
                pipeline['received_responses'] += 1
                pipeline['results'][message.get('worker_id', 'unknown')] = message.get('result', '')
                
                if pipeline['received_responses'] >= pipeline['total_expected']:
                    final_result = self._combine_hybrid_results(pipeline)
                    
                    response = {
                        'type': 'inference_response',
                        'request_id': request_id,
                        'result': final_result,
                        'worker_id': 'hybrid_system',
                        'strategy': 'hybrid_parallel',
                        'workers_used': pipeline['workers_used'],
                        'session_id': pipeline.get('session_id', 'default'),
                        'timestamp': time.time()
                    }
                    
                    self._send_to_client(pipeline['client_socket'], response)
                    logger.info(f"Hybrid inference completed for {request_id}")
                    
                    del self.active_pipelines[request_id]
                    
        except Exception as e:
            logger.error(f"Model shard response error: {e}")

    def _combine_hybrid_results(self, pipeline):
        """Combina risultati hybrid"""
        try:
            results = list(pipeline['results'].values())
            
            if not results:
                return "Elaborazione completata ma nessun risultato ricevuto"
            
            logger.info(f"Combining {len(results)} results: {list(pipeline['results'].keys())}")
            
            best_result = ""
            max_length = 0
            
            for worker_id, result in pipeline['results'].items():
                if result and isinstance(result, str):
                    if any(msg in result for msg in ['INPUT_PROCESSED', 'PROCESSED', 'SHARD_COMPLETE']):
                        logger.info(f"Skipping system message from {worker_id}")
                        continue
                    
                    result_length = len(result.strip())
                    if result_length > max_length and result_length > 20:
                        best_result = result
                        max_length = result_length
                        logger.info(f"Better result found from {worker_id}: {result_length} chars")
            
            if best_result:
                logger.info(f"Using result with {max_length} characters")
                return best_result
            
            for result in reversed(results):
                if result and len(result.strip()) > 20:
                    logger.info(f"Using fallback result: {len(result.strip())} chars")
                    return result
            
            final_fallback = " ".join([str(r) for r in results if r and len(str(r).strip()) > 5])
            if final_fallback:
                logger.info(f"Using concatenated result: {len(final_fallback)} chars")
                return final_fallback
                
            return "Elaborazione completata con successo - Richiedi più dettagli se necessario"
            
        except Exception as e:
            logger.error(f"Results combination error: {e}")
            return "Risposta generata dal sistema - Si prega di riformulare la domanda per maggiori dettagli"

    def _handle_worker_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Gestisce worker_response"""
        try:
            request_id = message.get('request_id', 'unknown')
            
            if request_id in self.active_pipelines:
                pipeline = self.active_pipelines[request_id]
                
                if pipeline.get('strategy') == 'single':
                    response = {
                        'type': 'inference_response',
                        'request_id': request_id,
                        'result': message.get('result', ''),
                        'worker_id': message.get('worker_id', 'unknown'),
                        'strategy': 'single',
                        'timestamp': time.time()
                    }
                    
                    self._send_to_client(pipeline['client_socket'], response)
                    logger.info(f"Single GPU response sent for {request_id}")
                    
                    del self.active_pipelines[request_id]
                    return
            
            pipeline_response = {
                'type': 'pipeline_response',
                'pipeline_id': f"single_{request_id}",
                'stage_id': 'single_stage',
                'worker_id': message.get('worker_id', 'unknown'),
                'result': message.get('result', ''),
                'is_final_stage': True,
                'timestamp': time.time()
            }
            
            self._handle_pipeline_response(None, None, pipeline_response)
            
        except Exception as e:
            logger.error(f"Worker response error: {e}")

    def _handle_worker_error(self, client_socket: socket.socket, address: tuple, message: dict):
        """Gestisce errori dai worker"""
        try:
            worker_id = message.get('worker_id', 'unknown')
            error_msg = message.get('error', 'Unknown error')
            logger.warning(f"Worker {worker_id} error: {error_msg}")
            
            if client_socket and 'request_id' in message:
                self._send_error(client_socket, message['request_id'], f"Worker error: {error_msg}")
                
        except Exception as e:
            logger.error(f"Worker error handling error: {e}")

    def _handle_heartbeat(self, client_socket: socket.socket, address: tuple, message: dict):
        """Gestisce heartbeat dei worker"""
        worker_id = message.get('worker_id')
        if worker_id in self.worker_registry:
            self.worker_registry[worker_id]['last_heartbeat'] = time.time()
            if worker_id in self.gpu_resources:
                self.gpu_resources[worker_id].last_heartbeat = time.time()

    def _handle_system_stats_request(self, client_socket: socket.socket, address: tuple, message: dict):
        """Fornisce statistiche di sistema dettagliate"""
        stats = self._get_detailed_system_stats()
        response = {
            'type': 'system_stats',
            'stats': stats,
            'timestamp': time.time()
        }
        self._send_to_client(client_socket, response)

    def _get_detailed_system_stats(self) -> Dict[str, Any]:
        """Restituisce statistiche di sistema dettagliate"""
        total_gpus = len(self.gpu_resources)
        active_pipelines = len(self.active_pipelines)
        
        avg_load = sum(gpu.current_load for gpu in self.gpu_resources.values()) / total_gpus if total_gpus > 0 else 0
        
        gpu_types = {}
        for gpu in self.gpu_resources.values():
            gpu_type = gpu.node_type.value
            gpu_types[gpu_type] = gpu_types.get(gpu_type, 0) + 1
        
        return {
            'total_gpus': total_gpus,
            'active_pipelines': active_pipelines,
            'average_load': avg_load,
            'gpu_distribution': gpu_types,
            'performance_score': self._calculate_cluster_performance(),
            'coordinator_id': self.coordinator_id,
            'uptime': time.time() - getattr(self, 'start_time', time.time()),
            'memory_usage_mb': self._get_memory_usage(),
            'requests_processed': getattr(self, 'requests_processed', 0)
        }

    def _calculate_cluster_performance(self) -> float:
        """Calcola performance complessiva del cluster"""
        if not self.gpu_resources:
            return 0.0
        total_score = sum(gpu.performance_score for gpu in self.gpu_resources.values())
        return total_score / len(self.gpu_resources)

    def _get_memory_usage(self) -> int:
        """Stima uso memoria"""
        import psutil
        return psutil.Process().memory_info().rss // 1024 // 1024

    def _cleanup_loop(self):
        """Pulizia periodica delle risorse"""
        while self.running:
            time.sleep(60)
            try:
                self._cleanup_inactive_workers()
                self._cleanup_expired_sessions()
                self._cleanup_stalled_pipelines()
            except Exception as e:
                logger.error(f"Cleanup error: {e}")

    def _cleanup_inactive_workers(self):
        """Rimuove worker inattivi"""
        current_time = time.time()
        inactive_workers = []
        
        for worker_id, worker_info in self.worker_registry.items():
            if current_time - worker_info['last_heartbeat'] > 120:
                inactive_workers.append(worker_id)
        
        for worker_id in inactive_workers:
            self._remove_worker(worker_id)
            logger.info(f"Removed inactive worker: {worker_id}")

    def _remove_worker(self, worker_id: str):
        """Rimuove worker dal sistema"""
        if worker_id in self.worker_registry:
            try:
                socket = self.worker_registry[worker_id]['socket']
                socket.close()
            except:
                pass
            finally:
                del self.worker_registry[worker_id]
                if worker_id in self.gpu_resources:
                    del self.gpu_resources[worker_id]

    def _cleanup_expired_sessions(self):
        """Pulizia sessioni scadute"""
        try:
            if hasattr(self, 'session_manager'):
                current_time = time.time()
                sessions_to_remove = []
                
                for session_id, session_data in self.session_manager.sessions.items():
                    if current_time - session_data['last_activity'] > self.session_manager.session_timeout:
                        sessions_to_remove.append(session_id)
                
                for session_id in sessions_to_remove:
                    del self.session_manager.sessions[session_id]
                    logger.info(f"Expired session cleaned: {session_id}")
                    
        except Exception as e:
            logger.error(f"Session cleanup error: {e}")

    def _cleanup_stalled_pipelines(self):
        """Pulizia pipeline bloccate"""
        try:
            current_time = time.time()
            stalled_pipelines = []
            
            for pipeline_id, pipeline_data in self.active_pipelines.items():
                elapsed = current_time - pipeline_data['start_time']
                if elapsed > 300:
                    stalled_pipelines.append(pipeline_id)
            
            for pipeline_id in stalled_pipelines:
                logger.warning(f"Removing stalled pipeline: {pipeline_id}")
                try:
                    self._send_error(
                        self.active_pipelines[pipeline_id]['client_socket'],
                        pipeline_id,
                        "Pipeline timeout - processing took too long"
                    )
                except:
                    pass
                finally:
                    del self.active_pipelines[pipeline_id]
                    
        except Exception as e:
            logger.error(f"Stalled pipeline cleanup error: {e}")

    def _cleanup_client_connection(self, client_socket: socket.socket, address: tuple):
        """Pulizia connessione client"""
        workers_to_remove = []
        for worker_id, worker_info in self.worker_registry.items():
            if worker_info['socket'] == client_socket:
                workers_to_remove.append(worker_id)
        
        for worker_id in workers_to_remove:
            self._remove_worker(worker_id)
        
        try:
            client_socket.close()
        except:
            pass
            
        logger.info(f"Client disconnected: {address}")

    def _send_to_client(self, client_socket: socket.socket, message: dict):
        """Invia messaggio a client"""
        try:
            message_str = json.dumps(message) + '\n'
            client_socket.send(message_str.encode('utf-8'))
        except Exception as e:
            logger.error(f"Send error: {e}")

    def _send_to_worker(self, worker_id: str, message: dict):
        """Invia messaggio a worker specifico"""
        if worker_id in self.worker_registry:
            self._send_to_client(self.worker_registry[worker_id]['socket'], message)
        else:
            logger.error(f"Worker not found: {worker_id}")

    def _send_error(self, client_socket: socket.socket, request_id: str, error_msg: str):
        """Invia messaggio di errore"""
        error_response = {
            'type': 'error',
            'request_id': request_id,
            'error': error_msg,
            'timestamp': time.time()
        }
        self._send_to_client(client_socket, error_response)

    def _register_with_central_server(self):
        """Register this coordinator with the central server over HTTP.

        POSTs the coordinator's id, address, worker count and cluster
        statistics to ``http://localhost:5000/api/register_coordinator``.
        The statistics are total GPU memory in MB and the mean performance
        score. On HTTP 200 the workers are registered as well, but only if
        this object provides ``_register_all_workers_with_central_server``.
        Calling that method unconditionally raised AttributeError, which was
        then logged as a registration failure even though the coordinator
        had registered. All errors are logged, not raised. ``requests`` is
        imported lazily, so its absence is also just logged.
        """
        try:
            import requests

            # Aggregate statistics from the currently known GPUs.
            total_gpu_memory = sum(gpu.memory_gb for gpu in self.gpu_resources.values()) * 1024
            performance_score = self._calculate_cluster_performance()

            coord_data = {
                'coordinator_id': self.coordinator_id,
                'host': self.host,
                'port': self.port,
                'workers_capacity': 1000,
                'current_workers': len(self.worker_registry),
                'total_models': len(self.active_pipelines),
                'performance_score': performance_score,
                'total_gpu_memory': total_gpu_memory,
                'capabilities': {
                    'max_gpus': 1000,
                    'pipeline_parallelism': True,
                    'tensor_parallelism': True
                }
            }

            logger.info(f"📤 Registering coordinator with {len(self.worker_registry)} workers...")

            response = requests.post(
                "http://localhost:5000/api/register_coordinator",
                json=coord_data,
                timeout=10
            )

            if response.status_code == 200:
                logger.info("✅ Coordinator registered successfully!")
                register_workers = getattr(self, '_register_all_workers_with_central_server', None)
                if callable(register_workers):
                    register_workers()
                else:
                    logger.warning("Worker registration with the central server is not available; skipping")
            else:
                logger.error(f"❌ Coordinator registration failed: {response.text}")

        except Exception as e:
            logger.error(f"❌ Registration error: {e}")

    def _monitoring_loop(self):
        """Loop di monitoring con metriche di carico"""
        while self.running:
            time.sleep(30)
            try:
                stats = self._get_detailed_system_stats()
                
                load_distribution = {}
                for gpu in self.gpu_resources.values():
                    load_level = "high" if gpu.current_load > 0.7 else "medium" if gpu.current_load > 0.4 else "low"
                    load_distribution[load_level] = load_distribution.get(load_level, 0) + 1
                
                logger.info(f"Cluster Stats: {stats['total_gpus']} GPUs, "
                        f"Avg Load: {stats['average_load']:.2f}, "
                        f"Load Dist: {load_distribution}")
                
                high_load_gpus = [gpu.worker_id for gpu in self.gpu_resources.values() 
                                if gpu.current_load > 0.85]
                if high_load_gpus:
                    logger.warning(f"High load GPUs: {high_load_gpus}")
                    
            except Exception as e:
                logger.error(f"Monitoring error: {e}")

    def _auto_scaling_loop(self):
        """Loop di auto-scaling"""
        while self.running:
            time.sleep(60)
            try:
                self.auto_scaler.check_and_scale()
            except Exception as e: 
                logger.error(f"Auto-scaling error: {e}")

    def update_gpu_load(self, worker_id: str, new_load: float):
        """Set the current load of ``worker_id``'s GPU, clamped to [0.0, 1.0].

        Unknown workers are ignored. The log line reports the value actually
        stored, not the raw requested value.
        """
        if worker_id in self.gpu_resources:
            clamped = max(0.0, min(1.0, new_load))
            self.gpu_resources[worker_id].current_load = clamped
            logger.info(f"Updated {worker_id} load: {clamped:.2f}")

    def update_gpu_load_from_task(self, worker_id: str, task_type: str, duration: float):
        """Increase a GPU's load according to the task type and its duration.

        The per-type base load is 0.3 for 'inference_task', 0.6 for
        'model_shard_task', 0.8 for 'pipeline_task' and 0.5 for anything else.
        It is scaled by ``duration / 60`` (seconds, capped at one minute) and
        added to the current load, which is capped at 0.95. The time factor
        is also floored at 0. A negative duration (e.g. from clock skew) would
        otherwise lower the load, possibly below zero. Unknown workers are
        ignored.
        """
        if worker_id in self.gpu_resources:
            load_factors = {
                'inference_task': 0.3,
                'model_shard_task': 0.6,
                'pipeline_task': 0.8
            }
            base_load = load_factors.get(task_type, 0.5)

            time_factor = max(0.0, min(duration / 60.0, 1.0))
            new_load = min(0.95, self.gpu_resources[worker_id].current_load + base_load * time_factor)

            self.gpu_resources[worker_id].current_load = new_load
            logger.info(f"{worker_id} load updated to {new_load:.2f} after {task_type}")

    def stop(self):
        """Stop the coordinator.

        Steps:
        1. Clear ``running``, so loops that poll it exit after their current
           sleep.
        2. Close every worker socket, best-effort.
        3. Close the listening socket, best-effort.
        4. Wait for the thread pool to shut down.

        Steps 2 and 3 are best-effort so that a failing close cannot skip the
        thread-pool shutdown. The worker registry is iterated over a snapshot
        because handler threads may remove workers concurrently.
        """
        self.running = False

        for worker_info in list(self.worker_registry.values()):
            try:
                worker_info['socket'].close()
            except Exception as exc:
                logger.debug(f"Error closing worker socket: {exc}")

        server_socket = getattr(self, 'socket', None)
        if server_socket:
            try:
                server_socket.close()
            except Exception as exc:
                logger.warning(f"Error closing server socket: {exc}")

        self.thread_pool.shutdown(wait=True)

        logger.info("Scalable Coordinator stopped gracefully")
class CentralServerIntegration:
    """Periodically register a coordinator and its workers with a Central Server.

    Each cycle POSTs a coordinator summary to
    ``<central_server_url>/api/register_coordinator``. If that returns HTTP
    200, every worker in the coordinator's ``worker_registry`` that has a
    ``'resource'`` entry is POSTed to ``<central_server_url>/api/register_node``.
    ``requests`` is imported lazily, so the module loads without it; a missing
    package is reported as a failed registration.
    """

    def __init__(self, coordinator, central_server_url: str = "http://localhost:5000",
                 registration_interval: float = 30):
        """Bind the integration to *coordinator*.

        Args:
            coordinator: the coordinator to advertise. It must expose
                ``coordinator_id``, ``host``, ``port``, ``worker_registry``,
                ``active_pipelines``, ``gpu_resources`` and
                ``_calculate_cluster_performance()``. An optional ``running``
                flag stops the registration loop when cleared.
            central_server_url: base URL of the Central Server. A trailing
                ``/`` is stripped so endpoint paths join cleanly.
            registration_interval: seconds to wait between registration cycles.
        """
        self.coordinator = coordinator
        self.central_server_url = central_server_url.rstrip("/")
        self.registration_interval = registration_interval
        logger.debug("CentralServerIntegration initialized with URL: %s",
                     self.central_server_url)

    def start_registration_loop(self) -> threading.Thread:
        """Start the periodic registration loop in a daemon thread.

        Returns:
            The started thread. Callers that ignored the previous ``None``
            return are unaffected.
        """
        thread = threading.Thread(target=self._registration_loop,
                                  name="central-server-registration",
                                  daemon=True)
        thread.start()
        logger.debug("Registration thread started with ID: %s", thread.ident)
        return thread

    def _registration_loop(self):
        """Run registration cycles while the coordinator's ``running`` flag is set."""
        logger.debug("Registration loop STARTED for %s", self.coordinator.coordinator_id)
        cycle_count = 0
        # A coordinator without a ``running`` attribute keeps the loop alive forever.
        while getattr(self.coordinator, 'running', True):
            cycle_count += 1
            logger.debug("Registration cycle #%d starting...", cycle_count)
            try:
                succeeded = self.register_with_central_server()
            except Exception:
                # Last-resort guard so the background thread never dies; retry sooner.
                logger.exception("Registration error in cycle #%d", cycle_count)
                time.sleep(10)
                continue
            if succeeded:
                logger.debug("Registration cycle #%d completed", cycle_count)
            else:
                logger.debug("Registration cycle #%d failed; will retry", cycle_count)
            logger.debug("Waiting %ss for next registration", self.registration_interval)
            time.sleep(self.registration_interval)

    def register_with_central_server(self) -> bool:
        """Register the coordinator and, on success, all of its workers.

        This method never raises. Errors are logged and reported through the
        return value.

        Returns:
            True if the Central Server answered the coordinator registration
            with HTTP 200, otherwise False.
        """
        try:
            import requests
        except ImportError:
            logger.error("Registration skipped: the 'requests' package is not installed")
            return False

        try:
            coord_data = self._build_coordinator_payload()
            logger.debug("Sending coordinator registration to %s (workers: %s, memory: %sMB)",
                         self.central_server_url, coord_data['current_workers'],
                         coord_data['total_gpu_memory'])
            response = requests.post(
                f"{self.central_server_url}/api/register_coordinator",
                json=coord_data,
                timeout=10,
            )
        except Exception as e:
            # Broad on purpose: payload building reads live coordinator state,
            # and this runs on a background thread that must keep going.
            logger.error("Registration exception: %s", e)
            return False

        if response.status_code != 200:
            logger.error("Coordinator registration failed: %s - %s",
                         response.status_code, response.text)
            return False

        logger.info("Coordinator %s registered successfully", self.coordinator.coordinator_id)
        self._register_all_workers()
        return True

    def _build_coordinator_payload(self) -> Dict[str, Any]:
        """Return the JSON body sent to ``/api/register_coordinator``."""
        coordinator = self.coordinator
        # memory_gb * 1024 -> MB, the same unit as the per-worker ``gpu_memory``.
        total_gpu_memory = sum(
            gpu.memory_gb for gpu in list(coordinator.gpu_resources.values())
        ) * 1024
        return {
            'coordinator_id': coordinator.coordinator_id,
            'host': coordinator.host,
            'port': coordinator.port,
            'workers_capacity': 1000,
            'current_workers': len(coordinator.worker_registry),
            'total_models': len(coordinator.active_pipelines),
            'performance_score': coordinator._calculate_cluster_performance(),
            'total_gpu_memory': total_gpu_memory,
            'capabilities': {
                'max_gpus': 1000,
                'pipeline_parallelism': True,
                'tensor_parallelism': True,
            },
        }

    def _build_node_payload(self, worker_id: str, gpu_resource) -> Dict[str, Any]:
        """Return the JSON body sent to ``/api/register_node`` for one worker."""
        return {
            'node_id': worker_id,
            'coordinator_id': self.coordinator.coordinator_id,
            'gpu_name': gpu_resource.gpu_name,
            'gpu_memory': gpu_resource.memory_gb * 1024,  # GB -> MB
            'performance_score': gpu_resource.performance_score,
            'node_type': gpu_resource.node_type.value,
            'status': 'active',
        }

    def _register_all_workers(self) -> int:
        """Register every worker that has a ``'resource'`` entry.

        A failure for one worker is logged and does not stop the others.

        Returns:
            The number of workers the Central Server accepted (HTTP 200).
        """
        try:
            import requests
        except ImportError:
            logger.error("Worker registration skipped: the 'requests' package is not installed")
            return 0

        # Snapshot the registry. It may be modified concurrently (presumably by
        # connection-handling threads), and iterating the live dict could raise
        # "dictionary changed size during iteration".
        workers = list(self.coordinator.worker_registry.items())
        logger.debug("Registering %d workers...", len(workers))

        registered = 0
        for worker_id, worker_info in workers:
            gpu_resource = worker_info.get('resource')
            if not gpu_resource:
                continue
            try:
                node_data = self._build_node_payload(worker_id, gpu_resource)
                logger.debug("Registering worker %s - %s", worker_id, gpu_resource.gpu_name)
                response = requests.post(
                    f"{self.central_server_url}/api/register_node",
                    json=node_data,
                    timeout=5,
                )
            except Exception as e:
                # One unreachable or malformed worker must not block the rest.
                logger.warning("Worker %s registration error: %s", worker_id, e)
                continue

            if response.status_code == 200:
                registered += 1
                logger.debug("Worker %s registered successfully", worker_id)
            else:
                logger.warning("Worker %s registration failed: %s",
                               worker_id, response.status_code)

        logger.debug("Worker registration completed: %d/%d accepted", registered, len(workers))
        return registered

class Coordinator(ScalableCoordinator):
    """
    Backward-compatible alias of ``ScalableCoordinator`` for existing code.

    On top of the ScalableCoordinator initialization, it attaches a
    ``MessageProcessor`` bound to this instance.
    """

    def __init__(self, host: str, port: int, blockchain):
        super().__init__(host, port, blockchain)
        self.message_processor = MessageProcessor(self)
        logger.info("Legacy Coordinator alias initialized with full compatibility")


# Script entry point. This must live at module level: when it was nested inside
# Coordinator.__init__, running the file as a script never started a coordinator.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # NOTE(review): Coordinator passes a `blockchain` argument to
    # ScalableCoordinator.__init__; confirm it has a default, otherwise this
    # call raises TypeError.
    coordinator = ScalableCoordinator(host="0.0.0.0", port=8765)
    try:
        coordinator.start()
    except KeyboardInterrupt:
        coordinator.stop()



