import asyncio
import heapq
import json
import logging
import re
import socket
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Any, Tuple, Optional

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class NodeType(Enum):
    """Capability tier assigned to a registered GPU worker.

    The string value (``node_type.value``) is what appears in registration
    logs and in the GPU distribution of the system statistics.
    """

    HIGH_PERFORMANCE = "high_perf"
    MEDIUM_PERFORMANCE = "medium"
    LOW_PERFORMANCE = "low_perf"
    EDGE_NODE = "edge"

class PipelineStrategy(Enum):
    """Execution strategy the coordinator picks for an inference request.

    The choice is driven by how many GPUs are registered; the string values
    are short identifiers for each strategy.
    """

    SINGLE_GPU = "single"
    HYBRID_PARALLEL = "hybrid"
    PIPELINE_PARALLEL = "pipeline"
    TENSOR_PARALLEL = "tensor"
    MASSIVE_SCALE = "massive"

@dataclass
class GPUResource:
    """The coordinator's view of one worker's GPU.

    Created at worker registration; the same instance is stored both in the
    coordinator's ``gpu_resources`` and in the load balancer, so load and
    heartbeat updates made through either are visible to both.
    """

    worker_id: str              # key of this worker in the coordinator registries
    gpu_name: str               # GPU name as reported by the worker
    memory_gb: float            # total GPU memory, in GB
    performance_score: float    # heuristic score used to rank GPUs (higher is better)
    node_type: NodeType         # capability tier derived from score and memory
    current_load: float         # utilisation estimate; 0.0 at registration
    assigned_layers: List[int]  # empty at registration
    last_heartbeat: float       # time.time() of the last registration/heartbeat

@dataclass
class ModelShard:
    """One slice of the model assigned to a single GPU slot."""

    shard_id: str                # "shard_<index>"
    gpu_id: str                  # GPU slot label, "gpu_<index>"
    layers: List[int]            # transformer layer indices held by this shard
    attention_heads: List[int]   # attention head indices held by this shard
    parameters_count: int        # number of model parameters attributed to the shard
    memory_required: float       # memory needed by the shard, in GB

@dataclass
class PipelineStage:
    """One step of an inference pipeline and the GPU slots that execute it."""

    stage_id: str               # unique id within the pipeline
    stage_type: str             # label describing what the stage does
    assigned_gpus: List[str]    # GPU slot labels / worker ids running the stage
    input_data: Any             # payload for the stage (None when built by PipelineManager)
    is_final_stage: bool        # True for the stage whose result is returned
    dependencies: List[str]     # stage_ids that come before this stage

class HierarchicalLoadBalancer:
    """Keep a registry of GPU workers and choose which ones serve a request.

    GPUs whose ``current_load`` is at or above ``load_threshold`` are never
    allocated. Eligible GPUs are ranked by ``performance_score`` (highest
    first), then by ``current_load`` (lowest first).
    """

    def __init__(self, load_threshold: float = 0.8):
        # worker_id -> GPUResource; registering the same worker again replaces it.
        self.gpu_resources: Dict[str, GPUResource] = {}
        # Not read or written by this class; kept for external users.
        self.performance_metrics = {}
        self.load_distribution = {}
        # GPUs with current_load >= this value are excluded from allocation.
        self.load_threshold = load_threshold

    def add_gpu_resource(self, gpu_resource: GPUResource):
        """Register (or replace) ``gpu_resource`` under its ``worker_id``."""
        self.gpu_resources[gpu_resource.worker_id] = gpu_resource

    def get_optimal_gpu_allocation(self, total_gpus_needed: int, complexity: str) -> List[GPUResource]:
        """Select up to ``total_gpus_needed`` GPUs for a request.

        Args:
            total_gpus_needed: maximum number of GPUs to return; values <= 0
                yield an empty list.
            complexity: "medium" selects a balanced mix of high- and
                medium-tier GPUs (topped up with the next best eligible GPUs);
                any other value ("high", "very_high", "low", ...) takes the
                best-ranked GPUs.

        Returns:
            The selected GPUs; fewer than requested when not enough GPUs are
            below the load threshold.
        """
        if total_gpus_needed <= 0:
            # A negative slice bound would otherwise mean "all but the last N".
            return []

        available_gpus = [gpu for gpu in self.gpu_resources.values()
                          if gpu.current_load < self.load_threshold]
        if not available_gpus:
            return []

        # Best performance first, then least loaded.
        sorted_gpus = sorted(available_gpus,
                             key=lambda gpu: (-gpu.performance_score, gpu.current_load))

        if complexity == "medium":
            return self._get_balanced_mix(sorted_gpus, total_gpus_needed)
        return sorted_gpus[:total_gpus_needed]

    def _get_balanced_mix(self, gpus: List[GPUResource], count: int) -> List[GPUResource]:
        """Return ``count`` GPUs mixing high- and medium-tier nodes.

        Takes up to ``count // 2`` HIGH_PERFORMANCE GPUs, fills with
        MEDIUM_PERFORMANCE ones, and, if that is still short of ``count``,
        tops up with the remaining GPUs in their ranked order, so a cluster
        without enough high/medium nodes still gets a full allocation.

        Args:
            gpus: eligible GPUs, already ranked best-first.
            count: number of GPUs wanted.
        """
        if len(gpus) <= count:
            return gpus

        high_perf = [g for g in gpus if g.node_type == NodeType.HIGH_PERFORMANCE][:count // 2]
        medium_perf = [g for g in gpus
                       if g.node_type == NodeType.MEDIUM_PERFORMANCE][:count - len(high_perf)]
        selected = high_perf + medium_perf

        if len(selected) < count:
            chosen = {g.worker_id for g in selected}
            filler = [g for g in gpus if g.worker_id not in chosen]
            selected.extend(filler[:count - len(selected)])

        return selected

class DynamicModelSharding:
    """Split a model's layers and attention heads across a number of GPUs.

    Layers and heads are divided into contiguous ranges whose sizes differ by
    at most one; the first ``n % total_gpus`` shards get the extra item. With
    more GPUs than layers or heads, the trailing shards get empty ranges.
    Results are memoised in ``sharding_cache``. Each call returns a new list,
    but the ModelShard objects themselves are shared with the cache.
    """

    def __init__(self):
        # (total_gpus, model_layers, total_parameters, total_memory_gb) -> List[ModelShard]
        self.sharding_cache = {}

    def calculate_optimal_sharding(self, total_gpus: int, model_layers: int = 32,
                                   total_parameters: int = 7_000_000_000,
                                   total_memory_gb: float = 4.0) -> List[ModelShard]:
        """Return one ModelShard per GPU.

        Args:
            total_gpus: number of shards to produce; must be >= 1.
            model_layers: number of transformer layers to distribute.
            total_parameters: model parameter count, split with floor division
                across shards (default: 7B).
            total_memory_gb: model memory footprint in GB, split evenly across
                shards (default: 4.0).

        Returns:
            A new list whose i-th shard has ids ``shard_i`` / ``gpu_i``.

        Raises:
            ValueError: if ``total_gpus`` < 1.
        """
        if total_gpus < 1:
            raise ValueError(f"total_gpus must be >= 1, got {total_gpus}")

        cache_key = (total_gpus, model_layers, total_parameters, total_memory_gb)
        cached = self.sharding_cache.get(cache_key)
        if cached is not None:
            return list(cached)

        shards = [
            ModelShard(
                shard_id=f"shard_{gpu_idx}",
                gpu_id=f"gpu_{gpu_idx}",
                layers=self._partition(model_layers, total_gpus, gpu_idx),
                attention_heads=self._assign_attention_heads(gpu_idx, total_gpus),
                parameters_count=total_parameters // total_gpus,
                memory_required=total_memory_gb / total_gpus,  # GB per shard
            )
            for gpu_idx in range(total_gpus)
        ]

        self.sharding_cache[cache_key] = shards
        return list(shards)

    def _assign_attention_heads(self, gpu_idx: int, total_gpus: int,
                                total_heads: int = 32) -> List[int]:
        """Return the balanced, contiguous range of attention heads for shard ``gpu_idx``."""
        return self._partition(total_heads, total_gpus, gpu_idx)

    @staticmethod
    def _partition(total_items: int, parts: int, index: int) -> List[int]:
        """Return the slice of ``range(total_items)`` owned by part ``index`` of ``parts``.

        Part sizes differ by at most one, with the extra items going to the
        lowest indices; parts beyond ``total_items`` receive an empty list.
        """
        base, extra = divmod(total_items, parts)
        start = index * base + min(index, extra)
        end = start + base + (1 if index < extra else 0)
        return list(range(start, end))

class PipelineManager:
    """Build the list of PipelineStage objects that describes one inference run.

    GPU slots in the returned stages are placeholder labels ``gpu_<n>`` with
    n in ``0 .. available_gpus - 1``; mapping them to real workers is left to
    the caller.
    """

    # Number of transformer layers split across stages by the small pipeline.
    MODEL_LAYERS = 32

    def __init__(self):
        # Not read or written by this class; kept for external users.
        self.pipeline_templates = {}
        self.performance_history = {}

    def create_optimal_pipeline(self, available_gpus: int, prompt_complexity: str) -> List[PipelineStage]:
        """Return the pipeline layout for ``available_gpus`` GPUs.

        Layout by GPU count: 1 -> one full-inference stage; 2 -> input and
        output stages; 3-8 -> one layer-range stage per GPU; 9-32, 33-100 and
        >100 -> tensor-parallel GPU groups chained as pipeline stages.

        Args:
            available_gpus: number of GPUs to lay out; must be >= 1.
            prompt_complexity: currently unused; kept for interface stability.

        Raises:
            ValueError: if ``available_gpus`` < 1.
        """
        if available_gpus < 1:
            raise ValueError(f"available_gpus must be >= 1, got {available_gpus}")
        if available_gpus == 1:
            return self._create_single_gpu_pipeline()
        if available_gpus == 2:
            return self._create_hybrid_pipeline()
        if available_gpus <= 8:
            return self._create_small_pipeline(available_gpus)
        if available_gpus <= 32:
            return self._create_medium_pipeline(available_gpus)
        if available_gpus <= 100:
            return self._create_large_pipeline(available_gpus)
        return self._create_massive_pipeline(available_gpus)

    def _create_single_gpu_pipeline(self) -> List[PipelineStage]:
        """One stage running the complete generation on ``gpu_0``."""
        return [PipelineStage(
            stage_id="full_inference",
            stage_type="complete_generation",
            assigned_gpus=["gpu_0"],
            input_data=None,
            is_final_stage=True,
            dependencies=[]
        )]

    def _create_hybrid_pipeline(self) -> List[PipelineStage]:
        """Input processing on ``gpu_0`` followed by output generation on ``gpu_1``."""
        return [
            PipelineStage(
                stage_id="input_processing",
                stage_type="input_processor",
                assigned_gpus=["gpu_0"],
                input_data=None,
                is_final_stage=False,
                dependencies=[]
            ),
            PipelineStage(
                stage_id="output_generation",
                stage_type="output_generator",
                assigned_gpus=["gpu_1"],
                input_data=None,
                is_final_stage=True,
                dependencies=["input_processing"]
            )
        ]

    def _create_small_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """One stage per GPU, each owning a contiguous range of layers.

        Layers are split with divmod so stage sizes differ by at most one and
        every layer in ``[0, MODEL_LAYERS)`` is covered. ``stage_type`` is
        ``layers_<start>_<end>`` with ``end`` exclusive.
        """
        stages = []
        base, extra = divmod(self.MODEL_LAYERS, gpu_count)
        start = 0

        for i in range(gpu_count):
            end = start + base + (1 if i < extra else 0)
            stages.append(PipelineStage(
                stage_id=f"stage_{i}",
                stage_type=f"layers_{start}_{end}",
                assigned_gpus=[f"gpu_{i}"],
                input_data=None,
                is_final_stage=(i == gpu_count - 1),
                dependencies=[f"stage_{i-1}"] if i > 0 else []
            ))
            start = end

        return stages

    def _create_medium_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """Pipeline for 9-32 GPUs: tensor-parallel groups chained as stages."""
        return self._create_tensor_parallel_pipeline(gpu_count)

    def _create_large_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """Pipeline for 33-100 GPUs: tensor-parallel groups chained as stages."""
        return self._create_tensor_parallel_pipeline(gpu_count)

    def _create_massive_pipeline(self, gpu_count: int) -> List[PipelineStage]:
        """Pipeline for 100+ GPUs with tensor parallelism inside each stage."""
        pipeline_depth = max(1, min(gpu_count // 8, 16))
        tensor_parallelism = gpu_count // pipeline_depth

        logger.info(f"🚀 Creating massive pipeline: depth={pipeline_depth}, tensor_parallelism={tensor_parallelism}")

        return self._create_tensor_parallel_pipeline(gpu_count)

    def _create_tensor_parallel_pipeline(self, gpu_count: int, max_depth: int = 16,
                                         gpus_per_stage: int = 8) -> List[PipelineStage]:
        """Split ``gpu_count`` GPU slots into chained tensor-parallel groups.

        Depth is ``gpu_count // gpus_per_stage`` clamped to ``[1, max_depth]``.
        Group widths differ by at most one, so every GPU slot is used (the
        remainder GPUs are not left idle).
        """
        depth = max(1, min(gpu_count // gpus_per_stage, max_depth))
        base, extra = divmod(gpu_count, depth)

        stages = []
        next_gpu = 0
        for stage_idx in range(depth):
            width = base + (1 if stage_idx < extra else 0)
            gpu_group = [f"gpu_{next_gpu + i}" for i in range(width)]
            next_gpu += width

            stages.append(PipelineStage(
                stage_id=f"pipeline_stage_{stage_idx}",
                stage_type=f"tensor_parallel_block_{stage_idx}",
                assigned_gpus=gpu_group,
                input_data=None,
                is_final_stage=(stage_idx == depth - 1),
                dependencies=[f"pipeline_stage_{stage_idx-1}"] if stage_idx > 0 else []
            ))

        return stages

class ScalableCoordinator:
    """
    Coordinator completamente riscritto per scalabilità massiva (2-1000+ GPU)
    Architettura gerarchica senza colli di bottiglia
    """
    
    def __init__(self, host: str = "0.0.0.0", port: int = 8765, blockchain=None):
        """Create the coordinator and its helper components.

        The listening socket is not created here; ``start()`` creates it.

        Args:
            host: interface that ``start()`` binds the listening socket to.
            port: TCP port that ``start()`` binds the listening socket to.
            blockchain: optional object stored as ``self.blockchain``; it is
                not used in the code visible here.
        """
        self.host = host
        self.port = port
        self.blockchain = blockchain
        
        # Core components
        self.load_balancer = HierarchicalLoadBalancer()
        self.model_sharding = DynamicModelSharding()
        self.pipeline_manager = PipelineManager()
        
        # Resource tracking
        # worker_id -> {'socket', 'address', 'gpu_info', 'last_heartbeat', 'resource'}
        self.worker_registry: Dict[str, Dict] = {}
        # worker_id -> GPUResource (same objects as held by the load balancer)
        self.gpu_resources: Dict[str, GPUResource] = {}
        # pipeline_id -> pipeline state dict
        self.active_pipelines: Dict[str, Any] = {}
        # SessionManager is presumably defined elsewhere in this module — confirm.
        self.session_manager = SessionManager()
        
        # Connection handling
        self.socket = None
        self.running = False
        # Shared executor capped at 100 worker threads.
        self.thread_pool = ThreadPoolExecutor(max_workers=100)
        
        # Metrics and monitoring
        self.performance_tracker = PerformanceTracker()
        # AutoScaler receives the coordinator itself; it is constructed after the
        # attributes above, so keep this ordering if it reads them.
        self.auto_scaler = AutoScaler(self)
        
        # Configuration
        # Random 8-hex-character suffix identifies this coordinator instance.
        self.coordinator_id = f"coord_{uuid.uuid4().hex[:8]}"
        # Presumably used by _register_with_central_server (not visible here) — verify.
        self.central_server_url = "https://ailo.site"
        
        logger.info(f"🏗️ Scalable Coordinator {self.coordinator_id} initialized")

    def start(self):
        """Bind the listening socket, start background services and serve clients.

        Blocks in the accept loop while ``self.running`` is true. Records
        ``self.start_time``, which the system statistics use to compute uptime.
        On any failure the coordinator is marked as not running, the listening
        socket is closed and the exception is re-raised.
        """
        try:
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.socket.bind((self.host, self.port))
            # Backlog: how many not-yet-accepted connections the OS may queue.
            self.socket.listen(1000)
            # Short timeout so the accept loop can re-check self.running.
            self.socket.settimeout(1.0)
            self.running = True
            self.start_time = time.time()

            # Background services
            self._start_background_services()

            logger.info(f"🎯 Scalable Coordinator listening on {self.host}:{self.port}")
            logger.info("🚀 Ready for 1000+ GPU cluster")

            # Main acceptance loop
            self._accept_connections_loop()

        except Exception as e:
            logger.error(f"❌ Coordinator startup failed: {e}")
            self.running = False
            if self.socket is not None:
                try:
                    self.socket.close()
                except OSError:
                    pass
            raise

    def _start_background_services(self):
        """Launch the coordinator's long-running helpers, each on a daemon thread.

        Started in this order: cleanup loop, monitoring loop, auto-scaling loop
        and registration with the central server. Each target is a method of
        this class that is not visible in this view.
        """
        service_names = (
            '_cleanup_loop',
            '_monitoring_loop',
            '_auto_scaling_loop',
            '_register_with_central_server',
        )
        for service_name in service_names:
            # Look the method up lazily so each thread starts before the next lookup.
            threading.Thread(target=getattr(self, service_name), daemon=True).start()

    def _accept_connections_loop(self):
        """Accept client connections while ``self.running`` is true.

        Each connection gets its own daemon thread. ``_handle_client_connection``
        stays in its receive loop for the whole life of a connection, so
        running it on the bounded ``thread_pool`` (100 workers) would leave
        every connection past the 100th queued and never served.
        """
        while self.running:
            try:
                client_socket, address = self.socket.accept()
            except socket.timeout:
                continue  # periodic wake-up to re-check self.running
            except Exception as e:
                if self.running:
                    logger.error(f"❌ Connection acceptance error: {e}")
                continue

            logger.info(f"📡 New connection from {address}")

            try:
                threading.Thread(
                    target=self._handle_client_connection,
                    args=(client_socket, address),
                    daemon=True,
                ).start()
            except Exception as e:
                # Could not spawn a handler (e.g. thread limit reached): drop this client.
                logger.error(f"❌ Connection acceptance error: {e}")
                client_socket.close()


    def update_gpu_load(self, worker_id: str, new_load: float):
        """Set the load of ``worker_id``'s GPU, clamped to [0.0, 1.0].

        Unknown workers are ignored. The log line reports the value actually
        stored (after clamping), not the requested one.
        """
        resource = self.gpu_resources.get(worker_id)
        if resource is None:
            return
        resource.current_load = max(0.0, min(1.0, new_load))
        logger.info(f"📊 Updated {worker_id} load: {resource.current_load:.2f}")

    def update_gpu_load_from_task(self, worker_id: str, task_type: str, duration: float):
        """Raise a GPU's load estimate after it ran a task.

        The increment is ``base_load * time_factor``:
        - ``base_load`` depends on ``task_type`` (0.5 for unknown types).
        - ``time_factor`` is ``duration / 60`` clamped to [0, 1].

        The result is capped at 0.95. A task never lowers the load, so
        negative durations and loads already above the cap are left unchanged.
        Unknown workers are ignored.
        """
        resource = self.gpu_resources.get(worker_id)
        if resource is None:
            return

        # Load contribution per task type.
        load_factors = {
            'inference_task': 0.3,
            'model_shard_task': 0.6,
            'pipeline_task': 0.8
        }
        base_load = load_factors.get(task_type, 0.5)

        # Normalise to one minute; negative durations count as zero.
        time_factor = min(max(duration, 0.0) / 60.0, 1.0)
        current_load = resource.current_load
        new_load = max(current_load, min(0.95, current_load + base_load * time_factor))

        resource.current_load = new_load
        logger.info(f"📈 {worker_id} load updated to {new_load:.2f} after {task_type}")

    def _handle_client_connection(self, client_socket: socket.socket, address: tuple):
        """Read newline-delimited JSON messages from one client until it disconnects.

        Raw bytes are buffered and only complete lines are decoded, so a
        multi-byte UTF-8 character split across two ``recv`` calls no longer
        raises UnicodeDecodeError and drops the connection. A line that is not
        valid UTF-8 is logged and skipped. The socket is always passed to
        ``_cleanup_client_connection`` on exit.
        """
        pending = b""
        try:
            # Short timeout so the loop can re-check self.running.
            client_socket.settimeout(1.0)

            while self.running:
                try:
                    chunk = client_socket.recv(65536)  # up to 64KB per read
                    if not chunk:
                        break  # peer closed the connection

                    pending += chunk

                    # Everything before the last newline is complete; the tail stays buffered.
                    *complete_lines, pending = pending.split(b"\n")
                    for raw_line in complete_lines:
                        try:
                            line = raw_line.decode('utf-8').strip()
                        except UnicodeDecodeError as e:
                            logger.warning(f"❌ Invalid UTF-8 from {address}: {e}")
                            continue
                        if line:
                            self._process_message_async(client_socket, address, line)

                except socket.timeout:
                    continue
                except Exception as e:
                    logger.error(f"❌ Client {address} error: {e}")
                    break

        except Exception as e:
            logger.error(f"❌ Error handling client {address}: {e}")
        finally:
            self._cleanup_client_connection(client_socket, address)

    def _process_message_async(self, client_socket: socket.socket, address: tuple, message_str: str):
        """Parse one JSON message and dispatch it to the handler for its ``type``.

        Despite its name, the handler runs synchronously on the calling thread.
        Every handler is called as ``handler(client_socket, address, message)``.
        Invalid JSON, payloads that are not JSON objects and unknown message
        types are logged and ignored. Handler errors are logged with a
        traceback and are not re-raised.
        """
        try:
            message = json.loads(message_str)
        except json.JSONDecodeError as e:
            logger.warning(f"❌ Invalid JSON from {address}: {e}")
            return

        if not isinstance(message, dict):
            logger.warning(f"❌ Invalid message from {address}: expected a JSON object, "
                           f"got {type(message).__name__}")
            return

        try:
            message_type = message.get('type')

            # Every handler takes (client_socket, address, message).
            handlers = {
                'worker_register': self._handle_worker_registration,
                'inference_request': self._handle_inference_request,
                'pipeline_response': self._handle_pipeline_response,
                'model_shard_response': self._handle_model_shard_response,
                'worker_response': self._handle_worker_response,
                'heartbeat': self._handle_heartbeat,
                'get_system_stats': self._handle_system_stats_request
            }

            handler = handlers.get(message_type)
            if handler is None:
                logger.warning(f"❓ Unknown message type: {message_type}")
                return
            handler(client_socket, address, message)

        except Exception as e:
            logger.exception(f"❌ Message processing error: {e}")

    def _handle_worker_registration(self, client_socket: socket.socket, address: tuple, message: dict):
        """Register a worker, profile its GPU and send it a confirmation.

        The worker id comes from the message (a random ``worker_<hex>`` id is
        used if it is missing). The resulting GPUResource is stored in the
        registry, the load balancer and ``gpu_resources``. Errors are logged,
        not raised.
        """
        try:
            worker_id = message.get('worker_id', f'worker_{uuid.uuid4().hex[:8]}')
            gpu_info = message.get('gpu_info', {})

            resource = self._analyze_gpu_capabilities(worker_id, gpu_info)

            self.worker_registry[worker_id] = dict(
                socket=client_socket,
                address=address,
                gpu_info=gpu_info,
                last_heartbeat=time.time(),
                resource=resource,
            )

            self.load_balancer.add_gpu_resource(resource)
            self.gpu_resources[worker_id] = resource

            summary_lines = (
                f"✅ Worker registered: {worker_id}",
                f"🎯 GPU: {resource.gpu_name} | Score: {resource.performance_score:.2f}",
                f"💾 Memory: {resource.memory_gb}GB | Type: {resource.node_type.value}",
            )
            for summary in summary_lines:
                logger.info(summary)

            confirmation = dict(
                type='registration_confirmed',
                worker_id=worker_id,
                coordinator_id=self.coordinator_id,
                timestamp=time.time(),
            )
            self._send_to_client(client_socket, confirmation)

        except Exception as e:
            logger.error(f"❌ Worker registration error: {e}")

    def _analyze_gpu_capabilities(self, worker_id: str, gpu_info: dict) -> GPUResource:
        """Build a fresh GPUResource for ``worker_id`` from the reported ``gpu_info``.

        The score is computed from the lower-cased GPU name and the parsed
        memory size; the node tier is derived from score and memory. The
        resource starts with zero load, no assigned layers and a heartbeat of
        "now".
        """
        reported_name = gpu_info.get('gpu_name', 'Unknown')
        normalized_name = reported_name.lower()
        memory_gb = self._extract_gpu_memory(gpu_info)

        score = self._calculate_performance_score(normalized_name, memory_gb)
        tier = self._determine_node_type(score, memory_gb)

        return GPUResource(
            worker_id=worker_id,
            gpu_name=reported_name,
            memory_gb=memory_gb,
            performance_score=score,
            node_type=tier,
            current_load=0.0,
            assigned_layers=[],
            last_heartbeat=time.time(),
        )

    def _calculate_performance_score(self, gpu_name: str, memory_gb: float) -> float:
        """Return a heuristic performance score clamped to [0.5, 4.0].

        The score is the product of three factors:
        - a memory factor: >=16 GB 2.0, >=12 GB 1.7, >=8 GB 1.3, <=4 GB 0.6,
          otherwise 1.0;
        - a model factor: the first known model whose name is a substring of
          ``gpu_name`` (expected lower-case);
        - an architecture bonus: 1.4 for RTX 40 series, 1.2 for RTX 30.
        """
        if memory_gb >= 16:
            memory_factor = 2.0
        elif memory_gb >= 12:
            memory_factor = 1.7
        elif memory_gb >= 8:
            memory_factor = 1.3
        elif memory_gb <= 4:
            memory_factor = 0.6
        else:
            memory_factor = 1.0

        # Per-model multipliers, checked in this order; first match wins.
        model_multipliers = {
            'rtx 4090': 2.8, 'rtx 4080': 2.4, 'rtx 4070': 2.0,
            'rtx 3090': 2.5, 'rtx 3080': 2.2, 'rtx 3070': 1.8,
            'rtx 3060': 1.5, 'rtx 3050': 1.0,
            'a100': 3.5, 'h100': 4.0, 'v100': 3.0
        }
        model_factor = next(
            (multiplier for model, multiplier in model_multipliers.items() if model in gpu_name),
            1.0,
        )

        if 'rtx 40' in gpu_name:
            architecture_factor = 1.4
        elif 'rtx 30' in gpu_name:
            architecture_factor = 1.2
        else:
            architecture_factor = 1.0

        score = memory_factor
        score *= model_factor
        score *= architecture_factor
        return max(0.5, min(4.0, score))

    def _determine_node_type(self, performance_score: float, memory_gb: float) -> NodeType:
        """Map a performance score and memory size to a NodeType tier.

        Checked in order:
        - HIGH: score >= 2.5 and memory >= 12 GB;
        - MEDIUM: score >= 1.5 and memory >= 8 GB;
        - EDGE: score < 1.0 or memory < 4 GB;
        - LOW: anything else.
        """
        if performance_score >= 2.5 and memory_gb >= 12:
            return NodeType.HIGH_PERFORMANCE
        if performance_score >= 1.5 and memory_gb >= 8:
            return NodeType.MEDIUM_PERFORMANCE
        if performance_score < 1.0 or memory_gb < 4:
            return NodeType.EDGE_NODE
        return NodeType.LOW_PERFORMANCE

    def _extract_gpu_memory(self, gpu_info: dict) -> float:
        """Parse ``gpu_info['total_memory']`` into GB.

        Accepted forms:
        - a number with a GB/GiB or MB/MiB suffix, matched case-insensitively
          (MB is treated like MiB, i.e. divided by 1024);
        - a bare number, taken as bytes.

        Returns 4.0 (a conservative default) when the value is missing or
        cannot be parsed. Only parsing errors are caught, so
        KeyboardInterrupt and similar are no longer swallowed.
        """
        # (suffix, multiplier to GB); no suffix is a suffix of another, so order is irrelevant.
        unit_factors = (('gib', 1.0), ('gb', 1.0), ('mib', 1 / 1024), ('mb', 1 / 1024))
        try:
            memory_str = str(gpu_info.get('total_memory', '0')).strip().lower()

            for suffix, factor in unit_factors:
                if memory_str.endswith(suffix):
                    return float(memory_str[:-len(suffix)].strip()) * factor

            return float(memory_str) / (1024**3)  # assume bytes

        except (AttributeError, TypeError, ValueError):
            return 4.0  # conservative default

    def _handle_inference_request(self, client_socket: socket.socket, address: tuple, message: dict):
        """Run an inference request with the strategy chosen for the current cluster.

        TENSOR_PARALLEL (and any other unlisted strategy) goes to the adaptive
        executor. Any failure, including having no workers, is reported to the
        client via ``_send_error``, using the same request id that was logged.
        """
        # Resolve the id before the try so that, when the client omitted it,
        # the error reply carries the generated id instead of 'unknown'.
        request_id = message.get('request_id', f'req_{uuid.uuid4().hex[:8]}')
        try:
            prompt = message.get('prompt', '')
            session_id = message.get('session_id', 'default')

            logger.info(f"📥 Inference request {request_id}: {prompt[:50]}...")

            strategy = self._select_optimal_strategy(prompt)

            if strategy == PipelineStrategy.SINGLE_GPU:
                self._execute_single_gpu_inference(client_socket, request_id, prompt, session_id)
            elif strategy == PipelineStrategy.HYBRID_PARALLEL:
                self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)
            elif strategy == PipelineStrategy.PIPELINE_PARALLEL:
                self._execute_pipeline_inference(client_socket, request_id, prompt, session_id)
            elif strategy == PipelineStrategy.MASSIVE_SCALE:
                self._execute_massive_scale_inference(client_socket, request_id, prompt, session_id)
            else:
                self._execute_adaptive_inference(client_socket, request_id, prompt, session_id)

        except Exception as e:
            logger.error(f"❌ Inference request error: {e}")
            self._send_error(client_socket, request_id, str(e))

    def _select_optimal_strategy(self, prompt: str) -> PipelineStrategy:
        """Choose an execution strategy from the number of registered GPUs.

        Mapping: 1 -> SINGLE_GPU, 2 -> HYBRID_PARALLEL, 3-8 ->
        PIPELINE_PARALLEL, 9-99 -> TENSOR_PARALLEL, >=100 -> MASSIVE_SCALE.
        ``prompt`` is accepted for interface stability; it does not currently
        influence the choice. The unused complexity analysis was dropped.

        Raises:
            RuntimeError: "No available workers" when no GPU is registered.
        """
        available_gpus = len(self.gpu_resources)

        if available_gpus == 0:
            raise RuntimeError("No available workers")
        if available_gpus == 1:
            return PipelineStrategy.SINGLE_GPU
        if available_gpus == 2:
            return PipelineStrategy.HYBRID_PARALLEL
        if available_gpus <= 8:
            return PipelineStrategy.PIPELINE_PARALLEL
        if available_gpus >= 100:
            return PipelineStrategy.MASSIVE_SCALE
        return PipelineStrategy.TENSOR_PARALLEL

    def _analyze_prompt_complexity(self, prompt: str) -> str:
        """Classify a prompt as "low", "medium", "high" or "very_high".

        Uses the word count plus keyword cues, matched case-insensitively on
        word boundaries. "why", "how" and "what if" must be whole words, so
        "show", "somehow" and "whyever" no longer count as reasoning cues.
        "explain", "analyze", "compare", "describe in detail" and
        "consequence" only need to start a word, so inflections such as
        "explained" or "consequences" still match.
        """
        word_count = len(prompt.split())
        has_complex_questions = re.search(
            r"\b(?:explain|analyze|compare|describe in detail)", prompt, re.IGNORECASE
        ) is not None
        requires_reasoning = re.search(
            r"\b(?:why|how|what if)\b|\bconsequence", prompt, re.IGNORECASE
        ) is not None

        if word_count > 200 or (has_complex_questions and requires_reasoning):
            return "very_high"
        elif word_count > 100 or has_complex_questions:
            return "high"
        elif word_count > 50:
            return "medium"
        else:
            return "low"

    def _execute_massive_scale_inference(self, client_socket, request_id, prompt, session_id):
        """Run ``prompt`` as a distributed pipeline over the GPUs that can be allocated.

        GPUs are allocated first, and sharding and pipeline layout are sized
        to the GPUs actually allocated. GPUs the load balancer filters out as
        too busy are therefore not given stages with no worker behind them.
        Failures, including an empty allocation, are reported to the client
        via ``_send_error``.
        """
        try:
            available_gpus = len(self.gpu_resources)
            logger.info(f"🚀 Starting massive scale inference with {available_gpus} GPUs")

            # 1. Choose the GPUs that will actually run this request.
            gpu_allocation = self.load_balancer.get_optimal_gpu_allocation(available_gpus, "very_high")
            if not gpu_allocation:
                raise RuntimeError("No GPUs available below the load threshold")
            allocated_gpus = len(gpu_allocation)

            # 2. Size model sharding and pipeline layout to the allocation.
            model_shards = self.model_sharding.calculate_optimal_sharding(allocated_gpus)
            pipeline_stages = self.pipeline_manager.create_optimal_pipeline(allocated_gpus, "high")

            # 3. Distributed execution.
            self._execute_distributed_pipeline(
                client_socket, request_id, prompt, session_id,
                pipeline_stages, gpu_allocation, model_shards
            )

        except Exception as e:
            logger.error(f"❌ Massive scale inference failed: {e}")
            self._send_error(client_socket, request_id, f"Massive scale failed: {e}")

    def _execute_distributed_pipeline(self, client_socket, request_id, prompt, session_id,
                                    pipeline_stages, gpu_allocation, model_shards):
        """Register a pipeline in ``active_pipelines`` and dispatch all its stages.

        The prompt is stored under ``input_data``, the key that
        ``_create_pipeline_task`` forwards to workers. ``workers_used`` records
        the allocated worker ids for the final response.

        Note: every stage is dispatched immediately; ordering between stages is
        conveyed to workers only through each task's ``dependencies`` list.
        """
        pipeline_id = f"pipeline_{request_id}"

        # gpu_allocation is a list of GPUResource from the load balancer;
        # getattr also tolerates plain worker-id entries.
        workers_used = [getattr(gpu, 'worker_id', gpu) for gpu in (gpu_allocation or [])]

        self.active_pipelines[pipeline_id] = {
            'client_socket': client_socket,
            'request_id': request_id,
            'session_id': session_id,
            'stages': pipeline_stages,
            'completed_stages': set(),
            'start_time': time.time(),
            'results': {},
            'gpu_allocation': gpu_allocation,
            'input_data': prompt,
            'workers_used': workers_used,
            'model_shards': model_shards,
        }

        for stage in pipeline_stages:
            self._execute_pipeline_stage(pipeline_id, stage)

        logger.info(f"🏗️ Distributed pipeline started: {pipeline_id} with {len(pipeline_stages)} stages")

    def _execute_pipeline_stage(self, pipeline_id: str, stage: PipelineStage):
        """Send a ``pipeline_task`` for ``stage`` to every worker assigned to it.

        A stage slot is either a registered worker id or a ``gpu_<n>``
        placeholder, as produced by PipelineManager. A placeholder is mapped
        to the n-th entry of the pipeline's ``gpu_allocation``. Slots that
        cannot be mapped to a registered worker are logged and skipped.
        Errors are logged, not raised.
        """
        try:
            pipeline = self.active_pipelines[pipeline_id]

            for gpu_id in stage.assigned_gpus:
                worker_id = self._resolve_stage_worker(gpu_id, pipeline)
                if worker_id is None:
                    logger.warning(f"⚠️ No registered worker for {gpu_id} "
                                   f"(stage {stage.stage_id}, {pipeline_id})")
                    continue
                task = self._create_pipeline_task(pipeline_id, stage, worker_id, pipeline)
                self._send_to_worker(worker_id, task)

        except Exception as e:
            logger.error(f"❌ Pipeline stage execution failed: {e}")

    def _resolve_stage_worker(self, gpu_id: str, pipeline: dict) -> Optional[str]:
        """Map a stage slot label to a registered worker id, or return None."""
        if gpu_id in self.worker_registry:
            return gpu_id

        prefix, _, index_text = str(gpu_id).partition('_')
        if prefix != 'gpu' or not index_text.isdecimal():
            return None

        allocation = list(pipeline.get('gpu_allocation') or [])
        index = int(index_text)
        if index >= len(allocation):
            return None

        worker_id = getattr(allocation[index], 'worker_id', allocation[index])
        return worker_id if worker_id in self.worker_registry else None

    def _create_pipeline_task(self, pipeline_id: str, stage: PipelineStage, gpu_id: str, pipeline: dict) -> dict:
        """Build the ``pipeline_task`` message for one stage on one GPU/worker.

        The message carries the stage description, the pipeline's
        ``input_data`` (empty string when absent) and the layers returned by
        ``self._get_assigned_layers(gpu_id)``.
        """
        stage_input = pipeline.get('input_data', '')
        layers = self._get_assigned_layers(gpu_id)
        return dict(
            type='pipeline_task',
            pipeline_id=pipeline_id,
            stage_id=stage.stage_id,
            stage_type=stage.stage_type,
            input_data=stage_input,
            is_final_stage=stage.is_final_stage,
            assigned_layers=layers,
            dependencies=stage.dependencies,
            timestamp=time.time(),
        )

    def _handle_pipeline_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Record one stage result and finalize the pipeline once it is complete.

        Responses for unknown pipelines, and responses without a ``stage_id``,
        are logged and ignored. A missing ``stage_id`` previously counted as a
        completed stage. Errors are logged, not raised.
        """
        try:
            pipeline_id = message.get('pipeline_id')
            stage_id = message.get('stage_id')
            result = message.get('result')

            # Single lookup: the pipeline may be finalized by another thread meanwhile.
            pipeline = self.active_pipelines.get(pipeline_id)
            if pipeline is None:
                logger.warning(f"⚠️ Unknown pipeline: {pipeline_id}")
                return

            if stage_id is None:
                logger.warning(f"⚠️ Pipeline response without stage_id for {pipeline_id}")
                return

            pipeline['completed_stages'].add(stage_id)
            pipeline['results'][stage_id] = result

            if self._is_pipeline_completed(pipeline):
                self._finalize_pipeline(pipeline_id)

        except Exception as e:
            logger.error(f"❌ Pipeline response error: {e}")

    def _is_pipeline_completed(self, pipeline: dict) -> bool:
        """Return True once every stage of ``pipeline`` has reported a result.

        Compares stage ids rather than counts, so an unexpected id in
        ``completed_stages`` can no longer make a pipeline look finished early.
        Stages without a ``stage_id`` attribute are compared by value.
        """
        required = {getattr(stage, 'stage_id', stage) for stage in pipeline['stages']}
        return required.issubset(pipeline['completed_stages'])

    def _finalize_pipeline(self, pipeline_id: str):
        """Send the combined result of a finished pipeline to its client and drop it.

        The pipeline is removed from ``active_pipelines`` up front with
        ``pop``. A second completion for the same pipeline then finds nothing
        to finalize, and a failure below does not leave the entry behind.
        ``workers_used`` is derived from the allocation; the old
        ``gpu_allocation.keys()`` call failed on the list returned by the load
        balancer, so the client never got a reply. If building or sending the
        response fails, the client is sent an error instead of being left
        waiting.
        """
        pipeline = self.active_pipelines.pop(pipeline_id, None)
        if pipeline is None:
            logger.warning(f"⚠️ Unknown pipeline: {pipeline_id}")
            return

        try:
            final_result = self._combine_pipeline_results(pipeline)

            # 🔥 Release GPU load, only when the pipeline recorded initial loads.
            if 'initial_loads' in pipeline:
                for worker_id in pipeline['initial_loads']:
                    if worker_id in self.gpu_resources:
                        # Simulate task completion: subtract 0.3 (absolute), floored at 0.
                        current_load = self.gpu_resources[worker_id].current_load
                        new_load = max(0.0, current_load - 0.3)
                        self.update_gpu_load(worker_id, new_load)
                        logger.info(f"📉 Released load for {worker_id}: {current_load:.2f} -> {new_load:.2f}")

            workers_used = self._pipeline_worker_ids(pipeline)

            response = {
                'type': 'inference_response',
                'request_id': pipeline['request_id'],
                'result': final_result,
                'pipeline_id': pipeline_id,
                'workers_used': workers_used,
                'processing_time': time.time() - pipeline['start_time'],
                'final_loads': {  # 🔥 final per-worker load metrics
                    worker_id: self.gpu_resources[worker_id].current_load
                    for worker_id in workers_used
                    if worker_id in self.gpu_resources
                },
                'timestamp': time.time()
            }

            self._send_to_client(pipeline['client_socket'], response)
            logger.info("✅ Pipeline completed with load balancing")

        except Exception as e:
            logger.error(f"❌ Pipeline finalization error: {e}")
            try:
                self._send_error(pipeline.get('client_socket'),
                                 pipeline.get('request_id', 'unknown'),
                                 f"Pipeline finalization failed: {e}")
            except Exception as notify_error:
                logger.error(f"❌ Could not notify client of finalization failure: {notify_error}")

    @staticmethod
    def _pipeline_worker_ids(pipeline: dict) -> List[str]:
        """Return the worker ids that served ``pipeline``.

        Uses an explicit ``workers_used`` list when present. Otherwise the ids
        come from ``gpu_allocation``, which may be a list of GPUResource (as
        returned by the load balancer) or a dict keyed by worker id.
        """
        explicit = pipeline.get('workers_used')
        if explicit:
            return list(explicit)
        allocation = pipeline.get('gpu_allocation') or []
        if isinstance(allocation, dict):
            return list(allocation.keys())
        return [getattr(gpu, 'worker_id', gpu) for gpu in allocation]

    def _combine_pipeline_results(self, pipeline: dict) -> str:
        """Choose the result returned to the client for a finished pipeline.

        Preference order:
        1. the result of the first stage flagged ``is_final_stage``;
        2. otherwise the last result in insertion order;
        3. otherwise the text "Pipeline completed successfully".

        Any error is logged and "Processing completed" is returned instead.
        """
        try:
            final_stage = next(
                (candidate for candidate in pipeline['stages'] if candidate.is_final_stage),
                None,
            )
            stage_results = pipeline['results']

            if final_stage and final_stage.stage_id in stage_results:
                return stage_results[final_stage.stage_id]

            if stage_results:
                return [*stage_results.values()][-1]

            return "Pipeline completed successfully"

        except Exception as e:
            logger.error(f"❌ Results combination error: {e}")
            return "Processing completed"

    def _handle_heartbeat(self, client_socket: socket.socket, address: tuple, message: dict):
        """Gestisce heartbeat dei worker"""
        worker_id = message.get('worker_id')
        if worker_id in self.worker_registry:
            self.worker_registry[worker_id]['last_heartbeat'] = time.time()
            if worker_id in self.gpu_resources:
                self.gpu_resources[worker_id].last_heartbeat = time.time()

    def _handle_system_stats_request(self, client_socket: socket.socket, address: tuple, message: dict):
        """Fornisce statistiche di sistema dettagliate"""
        stats = self._get_detailed_system_stats()
        response = {
            'type': 'system_stats',
            'stats': stats,
            'timestamp': time.time()
        }
        self._send_to_client(client_socket, response)

    def _get_detailed_system_stats(self) -> Dict[str, Any]:
        """Restituisce statistiche di sistema dettagliate"""
        total_gpus = len(self.gpu_resources)
        active_pipelines = len(self.active_pipelines)
        
        # Calcola utilizzo medio
        avg_load = sum(gpu.current_load for gpu in self.gpu_resources.values()) / total_gpus if total_gpus > 0 else 0
        
        # Distribuzione tipi GPU
        gpu_types = {}
        for gpu in self.gpu_resources.values():
            gpu_type = gpu.node_type.value
            gpu_types[gpu_type] = gpu_types.get(gpu_type, 0) + 1
        
        return {
            'total_gpus': total_gpus,
            'active_pipelines': active_pipelines,
            'average_load': avg_load,
            'gpu_distribution': gpu_types,
            'performance_score': self._calculate_cluster_performance(),
            'coordinator_id': self.coordinator_id,
            'uptime': time.time() - getattr(self, 'start_time', time.time()),
            'memory_usage_mb': self._get_memory_usage(),
            'requests_processed': getattr(self, 'requests_processed', 0)
        }

    def _calculate_cluster_performance(self) -> float:
        """Calcola performance complessiva del cluster"""
        if not self.gpu_resources:
            return 0.0
        total_score = sum(gpu.performance_score for gpu in self.gpu_resources.values())
        return total_score / len(self.gpu_resources)

    def _get_memory_usage(self) -> int:
        """Stima uso memoria (semplificato)"""
        import psutil
        return psutil.Process().memory_info().rss // 1024 // 1024  # MB

    def _cleanup_loop(self):
        """Pulizia periodica delle risorse"""
        while self.running:
            time.sleep(60)  # Esegui ogni minuto
            try:
                self._cleanup_inactive_workers()
                self._cleanup_expired_sessions()
                self._cleanup_stalled_pipelines()
            except Exception as e:
                logger.error(f"❌ Cleanup error: {e}")

    def _cleanup_inactive_workers(self):
        """Rimuove worker inattivi"""
        current_time = time.time()
        inactive_workers = []
        
        for worker_id, worker_info in self.worker_registry.items():
            if current_time - worker_info['last_heartbeat'] > 120:  # 2 minuti
                inactive_workers.append(worker_id)
        
        for worker_id in inactive_workers:
            self._remove_worker(worker_id)
            logger.info(f"🗑️ Removed inactive worker: {worker_id}")

    def _remove_worker(self, worker_id: str):
        """Rimuove worker dal sistema"""
        if worker_id in self.worker_registry:
            try:
                # Chiudi socket
                socket = self.worker_registry[worker_id]['socket']
                socket.close()
            except:
                pass
            finally:
                # Cleanup risorse
                del self.worker_registry[worker_id]
                if worker_id in self.gpu_resources:
                    del self.gpu_resources[worker_id]

    def _cleanup_client_connection(self, client_socket: socket.socket, address: tuple):
        """Pulizia connessione client"""
        # Rimuovi dai worker registry
        workers_to_remove = []
        for worker_id, worker_info in self.worker_registry.items():
            if worker_info['socket'] == client_socket:
                workers_to_remove.append(worker_id)
        
        for worker_id in workers_to_remove:
            self._remove_worker(worker_id)
        
        # Chiudi socket
        try:
            client_socket.close()
        except:
            pass
            
        logger.info(f"🔌 Client disconnected: {address}")

    def _send_to_client(self, client_socket: socket.socket, message: dict):
        """Invia messaggio a client"""
        try:
            message_str = json.dumps(message) + '\n'
            client_socket.send(message_str.encode('utf-8'))
        except Exception as e:
            logger.error(f"❌ Send error: {e}")

    def _send_to_worker(self, worker_id: str, message: dict):
        """Invia messaggio a worker specifico"""
        if worker_id in self.worker_registry:
            self._send_to_client(self.worker_registry[worker_id]['socket'], message)
        else:
            logger.error(f"❌ Worker not found: {worker_id}")

    def _send_error(self, client_socket: socket.socket, request_id: str, error_msg: str):
        """Invia messaggio di errore"""
        error_response = {
            'type': 'error',
            'request_id': request_id,
            'error': error_msg,
            'timestamp': time.time()
        }
        self._send_to_client(client_socket, error_response)

    def _register_with_central_server(self):
        """Registrazione con server centrale"""
        try:
            import requests
            coord_data = {
                'coordinator_id': self.coordinator_id,
                'host': self.host,
                'port': self.port,
                'start_time': time.time(),
                'capabilities': {
                    'max_gpus': 1000,
                    'pipeline_parallelism': True,
                    'tensor_parallelism': True,
                    'auto_scaling': True
                }
            }
            
            response = requests.post(
                f"{self.central_server_url}/api/register_coordinator",
                json=coord_data,
                timeout=10
            )
            
            if response.status_code == 200:
                logger.info("✅ Registered with central server")
            else:
                logger.warning(f"⚠️ Central server registration failed: {response.status_code}")
                
        except Exception as e:
            logger.warning(f"⚠️ Central server not available: {e}")

    def _monitoring_loop(self):
        """Loop di monitoring con metriche di carico"""
        while self.running:
            time.sleep(30)  # Ogni 30 secondi
            try:
                stats = self._get_detailed_system_stats()
                
                # 🔥 METRICHE DI CARICO DETTAGLIATE
                load_distribution = {}
                for gpu in self.gpu_resources.values():
                    load_level = "high" if gpu.current_load > 0.7 else "medium" if gpu.current_load > 0.4 else "low"
                    load_distribution[load_level] = load_distribution.get(load_level, 0) + 1
                
                logger.info(f"📊 Cluster Stats: {stats['total_gpus']} GPUs, "
                        f"Avg Load: {stats['average_load']:.2f}, "
                        f"Load Dist: {load_distribution}")
                
                # 🔥 ALLARMI CARICO ELEVATO
                high_load_gpus = [gpu.worker_id for gpu in self.gpu_resources.values() 
                                if gpu.current_load > 0.85]
                if high_load_gpus:
                    logger.warning(f"🚨 High load GPUs: {high_load_gpus}")
                    
            except Exception as e:
                logger.error(f"❌ Monitoring error: {e}")

    def _auto_scaling_loop(self):
        """Loop di auto-scaling"""
        while self.running:
            time.sleep(60)  # Ogni minuto
            try:
                self.auto_scaler.check_and_scale()
            except Exception as e: 
                logger.error(f"❌ Auto-scaling error: {e}")

    def stop(self):
        """Shut the coordinator down.

        Steps:
        1. Clear ``self.running``; the background loops check this flag.
        2. Close every worker socket and the listening socket.
        3. Wait for the thread pool to finish.

        Worker sockets are closed from a snapshot of the registry. Closing a
        socket may cause the worker to be removed from ``worker_registry``
        (e.g. via ``_cleanup_client_connection``) while we iterate. With the
        live dict that would raise ``RuntimeError: dictionary changed size
        during iteration``. Close failures are logged at debug level only.
        """
        self.running = False

        for worker_info in list(self.worker_registry.values()):
            try:
                worker_info['socket'].close()
            except Exception as e:
                logger.debug(f"Ignoring worker socket close error: {e}")

        if self.socket:
            try:
                self.socket.close()
            except OSError as e:
                logger.debug(f"Ignoring listening socket close error: {e}")

        self.thread_pool.shutdown(wait=True)

        logger.info("🛑 Scalable Coordinator stopped gracefully")

    def _execute_hybrid_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui inferenza ibrida con gestione ID worker dinamica"""
        try:
            # PRIMA VERIFICA SE CI SONO WORKER DISPONIBILI
            available_workers = list(self.worker_registry.keys())
            
            if len(available_workers) < 2:
                logger.warning(f"⚠️ Not enough workers ({len(available_workers)}), using single GPU")
                self._execute_single_gpu_inference(client_socket, request_id, prompt, session_id)
                return
            
            # PRENDI I PRIMI 2 WORKER DISPONIBILI (non importa l'ID esatto)
            worker1, worker2 = available_workers[:2]
            
            logger.info(f"🎯 Using workers: {worker1} + {worker2}")
            
            # Task per input processing
            input_task = {
                'type': 'model_shard_task',
                'request_id': request_id,
                'shard_id': 'shard_0',
                'shard_config': {'type': 'input_processor', 'role': 'process_input'},
                'prompt': prompt,
                'total_shards': 2,
                'current_shard': 0,
                'is_final_shard': False,
                'timestamp': time.time()
            }
            
            # Task per output generation
            output_task = {
                'type': 'model_shard_task',
                'request_id': request_id,
                'shard_id': 'shard_1',
                'shard_config': {'type': 'output_generator', 'role': 'generate_output'},
                'prompt': prompt,
                'total_shards': 2,
                'current_shard': 1,
                'is_final_shard': True,
                'timestamp': time.time()
            }
            
            # INVIA TASK CON VERIFICA
            success_count = 0
            
            if worker1 in self.worker_registry:
                self._send_to_worker(worker1, input_task)
                success_count += 1
                logger.info(f"✅ Sent input task to {worker1}")
            else:
                logger.error(f"❌ Worker {worker1} not found in registry")
            
            if worker2 in self.worker_registry:
                self._send_to_worker(worker2, output_task)
                success_count += 1
                logger.info(f"✅ Sent output task to {worker2}")
            else:
                logger.error(f"❌ Worker {worker2} not found in registry")
            
            # SE NON RIESCI A INVIARE A ENTRAMBI, USA SINGLE GPU
            if success_count < 2:
                logger.error(f"❌ Only {success_count}/2 tasks sent, using single GPU")
                self._execute_single_gpu_inference(client_socket, request_id, prompt, session_id)
                return
            
            # REGISTRA LA PIPELINE
            self.active_pipelines[request_id] = {
                'client_socket': client_socket,
                'request_id': request_id,
                'session_id': session_id,
                'strategy': 'hybrid_parallel',
                'start_time': time.time(),
                'workers_used': [worker1, worker2],
                'received_responses': 0,
                'total_expected': 2,
                'results': {},
                'completed_stages': set(),
                'stages': [
                    {'stage_id': 'shard_0', 'is_final_stage': False},
                    {'stage_id': 'shard_1', 'is_final_stage': True}
                ]
            }
            
            logger.info(f"🚀 Hybrid inference started for {request_id}")

        except Exception as e:
            logger.error(f"❌ Hybrid inference failed: {e}")
            self._send_error(client_socket, request_id, f"Hybrid inference failed: {e}")

    def _execute_single_gpu_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui inferenza single GPU - Metodo di compatibilità"""
        try:
            available_gpus = list(self.gpu_resources.keys())
            if not available_gpus:
                self._send_error(client_socket, request_id, "No workers available")
                return

            gpu_id = available_gpus[0]
            
            task = {
                'type': 'inference_task',
                'request_id': request_id,
                'prompt': prompt,
                'session_id': session_id,
                'timestamp': time.time()
            }
            self._send_to_worker(gpu_id, task)

            # 🔥 CORREZIONE: Aggiungi struttura completa
            self.active_pipelines[request_id] = {
                'client_socket': client_socket,
                'request_id': request_id, 
                'session_id': session_id,
                'strategy': 'single',
                'start_time': time.time(),
                'workers_used': [gpu_id],
                'received_responses': 0,
                'total_expected': 1,
                'results': {},
                'completed_stages': set(),  # 🔥 AGGIUNTO
                'stages': [                 # 🔥 AGGIUNTO
                    {'stage_id': 'single_stage', 'is_final_stage': True}
                ]
            }

            logger.info(f"🎯 Single GPU inference started for {request_id}")

        except Exception as e:
            logger.error(f"❌ Single GPU inference failed: {e}")
            self._send_error(client_socket, request_id, f"Single GPU inference failed: {e}")

    def _execute_pipeline_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui pipeline inference - Metodo di compatibilità"""
        # Per ora usa hybrid come fallback
        self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)

    def _execute_massive_scale_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui massive scale inference - Metodo di compatibilità""" 
        # Per ora usa hybrid come fallback
        self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)

    def _execute_adaptive_inference(self, client_socket, request_id, prompt, session_id):
        """Esegui adaptive inference - Metodo di compatibilità"""
        # Usa hybrid come strategia predefinita
        self._execute_hybrid_inference(client_socket, request_id, prompt, session_id)

    def _cleanup_expired_sessions(self):
        """Pulizia sessioni scadute - Metodo di compatibilità"""
        try:
            if hasattr(self, 'session_manager'):
                current_time = time.time()
                sessions_to_remove = []
                
                for session_id, session_data in self.session_manager.sessions.items():
                    if current_time - session_data['last_activity'] > self.session_manager.session_timeout:
                        sessions_to_remove.append(session_id)
                
                for session_id in sessions_to_remove:
                    del self.session_manager.sessions[session_id]
                    logger.info(f"🧹 Expired session cleaned: {session_id}")
                    
        except Exception as e:
            logger.error(f"❌ Session cleanup error: {e}")

    def _cleanup_stalled_pipelines(self):
        """Pulizia pipeline bloccate - Metodo di compatibilità"""
        try:
            current_time = time.time()
            stalled_pipelines = []
            
            for pipeline_id, pipeline_data in self.active_pipelines.items():
                elapsed = current_time - pipeline_data['start_time']
                if elapsed > 300:  # 5 minuti timeout
                    stalled_pipelines.append(pipeline_id)
            
            for pipeline_id in stalled_pipelines:
                logger.warning(f"⏰ Removing stalled pipeline: {pipeline_id}")
                try:
                    # Invia errore al client
                    self._send_error(
                        self.active_pipelines[pipeline_id]['client_socket'],
                        pipeline_id,
                        "Pipeline timeout - processing took too long"
                    )
                except:
                    pass
                finally:
                    del self.active_pipelines[pipeline_id]
                    
        except Exception as e:
            logger.error(f"❌ Stalled pipeline cleanup error: {e}")


    def _combine_hybrid_results(self, pipeline):
        """Combina risultati hybrid - Versione MIGLIORATA per pipeline parallela"""
        try:
            results = list(pipeline['results'].values())
            
            if not results:
                return "Elaborazione completata ma nessun risultato ricevuto"
            
            logger.info(f"🔍 Combining {len(results)} results from: {list(pipeline['results'].keys())}")
            
            # 🔥 ANALIZZA TUTTI I RISULTATI IN DETTAGLIO
            meaningful_results = {}
            
            for worker_id, result in pipeline['results'].items():
                if result and isinstance(result, str):
                    result_clean = result.strip()
                    
                    # 🔥 MIGLIORA IL FILTRO: non scartare basandosi solo su keyword
                    is_system_message = any(msg in result_clean.upper() for msg in 
                                        ['INPUT_PROCESSED', 'PROCESSED', 'SHARD_COMPLETE', 'COMPLETED'])
                    
                    if is_system_message:
                        logger.info(f"🔍 System message from {worker_id}: '{result_clean[:50]}...'")
                        continue
                    
                    # 🔥 CONSIDERA RISULTATI PIÙ CORTI MA SIGNIFICATIVI
                    if len(result_clean) > 5:  # Riduci soglia minima
                        meaningful_results[worker_id] = {
                            'result': result_clean,
                            'length': len(result_clean),
                            'is_meaningful': len(result_clean) > 20 or any(char in result_clean for char in ['.', '!', '?'])
                        }
                        logger.info(f"🔍 Meaningful result from {worker_id}: {len(result_clean)} chars")
            
            # 🔥 STRATEGIA DI COMBINAZIONE INTELLIGENTE
            if len(meaningful_results) == 0:
                logger.warning("⚠️ No meaningful results found, using fallback")
                # Prendi l'ultimo risultato nonostante tutto
                for result in reversed(results):
                    if result and len(str(result).strip()) > 5:
                        return str(result).strip()
                return "Elaborazione completata - Nessun risultato significativo ricevuto"
            
            elif len(meaningful_results) == 1:
                # Solo un risultato significativo
                worker_id, data = next(iter(meaningful_results.items()))
                logger.info(f"✅ Using single meaningful result from {worker_id}: {data['length']} chars")
                return data['result']
            
            else:
                # 🔥 COMBINA MULTIPLI RISULTATI
                logger.info(f"🔄 Combining {len(meaningful_results)} meaningful results")
                
                # Strategia: prendi il risultato più lungo e completo
                best_result = ""
                max_length = 0
                
                for worker_id, data in meaningful_results.items():
                    if data['length'] > max_length and data['is_meaningful']:
                        best_result = data['result']
                        max_length = data['length']
                        logger.info(f"🔍 Best candidate: {worker_id} with {max_length} chars")
                
                if best_result:
                    logger.info(f"✅ Combined result: {max_length} chars")
                    return best_result
                
                # Fallback: concatenazione intelligente
                all_results = [data['result'] for data in meaningful_results.values()]
                combined = " ".join(all_results)
                logger.info(f"🔄 Fallback concatenation: {len(combined)} chars")
                return combined
                
        except Exception as e:
            logger.error(f"❌ Results combination error: {e}")
            return "Errore nella combinazione dei risultati - Riprova"
    

# Classi di supporto
class SessionManager:
    """In-memory store of per-session conversation history.

    Each session is a dict with ``'created'``, ``'history'`` (a list of
    ``(question, answer)`` pairs) and ``'last_activity'``. ``session_timeout``
    is the idle time in seconds after which the coordinator's session
    cleanup drops a session.
    """

    def __init__(self):
        self.sessions: Dict[str, Dict] = {}
        self.session_timeout = 1800  # 30 minutes, in seconds

    def get_session_context(self, session_id: str, prompt: str) -> str:
        """Touch (creating if needed) a session and return ``prompt`` wrapped in its recent history.

        Stored history is trimmed to the 10 most recent entries; only the
        last 3 appear in the returned context.
        """
        if session_id not in self.sessions:
            self.sessions[session_id] = dict(
                created=time.time(),
                history=[],
                last_activity=time.time(),
            )

        session = self.sessions[session_id]
        session['last_activity'] = time.time()

        history = session['history']
        if len(history) > 10:
            session['history'] = history[-10:]

        return self._build_context(session, prompt)

    def _build_context(self, session: Dict, prompt: str) -> str:
        """Format the last three Q/A turns plus the new question into one prompt string."""
        parts = ["Contesto conversazione:\n"]
        for turn, (question, answer) in enumerate(session['history'][-3:], start=1):
            parts.append(f"Q{turn}: {question}\nA{turn}: {answer}\n")
        parts.append(f"\nNuova domanda: {prompt}\nRisposta:")
        return "".join(parts)

class PerformanceTracker:
    """Keeps a bounded (last 1000 entries), timestamped history of metric snapshots."""

    def __init__(self):
        self.metrics_history = []
        self.performance_thresholds = {'high_load': 0.8, 'low_performance': 0.5}

    def track_metrics(self, metrics: Dict):
        """Append ``metrics`` with the current timestamp, keeping only the newest 1000 entries."""
        entry = {'timestamp': time.time(), 'metrics': metrics}
        self.metrics_history.append(entry)

        overflow = len(self.metrics_history) - 1000
        if overflow > 0:
            self.metrics_history = self.metrics_history[overflow:]

class AutoScaler:
    """Threshold-based scaling decisions driven by the coordinator's average GPU load.

    ``_scale_up`` and ``_scale_down`` are placeholders that currently only log.
    """

    def __init__(self, coordinator):
        self.coordinator = coordinator
        self.scaling_thresholds = {
            'scale_up_cpu': 0.8,
            'scale_down_cpu': 0.3,
            'scale_up_memory': 0.9
        }

    def check_and_scale(self):
        """Scale up above ``scale_up_cpu`` average load, down below ``scale_down_cpu``."""
        average_load = self.coordinator._get_detailed_system_stats()['average_load']
        limits = self.scaling_thresholds

        if average_load > limits['scale_up_cpu']:
            self._scale_up()
            return
        if average_load < limits['scale_down_cpu']:
            self._scale_down()

    def _scale_up(self):
        """Placeholder for adding cluster capacity; only logs for now."""
        logger.info("📈 Auto-scaling: Scaling up cluster")
        # TODO: implement scale-up logic.

    def _scale_down(self):
        """Placeholder for removing cluster capacity; only logs for now."""
        logger.info("📉 Auto-scaling: Scaling down cluster")
        # TODO: implement scale-down logic.

if __name__ == "__main__":
    # Console logging for standalone runs.
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )

    # NOTE(review): this guard sits above the Coordinator and MessageProcessor
    # definitions; when run as a script and start() blocks, those classes are
    # not defined yet -- consider moving this guard to the end of the file.
    coordinator = ScalableCoordinator(port=8765, host="0.0.0.0")
    try:
        coordinator.start()
    except KeyboardInterrupt:
        # Ctrl+C: close sockets and the thread pool cleanly.
        coordinator.stop()

class Coordinator(ScalableCoordinator):
    """
    Backward-compatible alias for ScalableCoordinator.

    Keeps the legacy constructor ``(host, port, blockchain)`` and exposes every
    legacy handler with the three-argument signature
    ``(client_socket, address, message)``. Incoming messages are dispatched
    through a ``MessageProcessor``.
    """

    def __init__(self, host: str, port: int, blockchain):
        super().__init__(host, port, blockchain)
        self.message_processor = MessageProcessor(self)
        logger.info("🔧 Legacy Coordinator alias initialized with full compatibility")

    # All handlers take (client_socket, address, message).

    def _handle_model_shard_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Forward a ``model_shard_response`` to the MessageProcessor compatibility handler."""
        try:
            logger.info(f"📦 Model shard response received from {address}")
            self.message_processor._handle_model_shard_response_compat(message)
        except Exception as e:
            logger.error(f"❌ Model shard response error: {e}")

    def _handle_worker_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Forward a ``worker_response`` to the MessageProcessor compatibility handler."""
        try:
            logger.info(f"📦 Worker response received from {address}")
            self.message_processor._handle_worker_response_compat(message)
        except Exception as e:
            logger.error(f"❌ Worker response error: {e}")

    def _handle_worker_error(self, client_socket: socket.socket, address: tuple, message: dict):
        """Report a worker failure to the client that issued the request.

        ``client_socket`` is the socket the error arrived on, i.e. the
        *worker's* connection. The error is therefore routed to the
        ``client_socket`` stored on the matching active pipeline. That
        pipeline is removed, since it can no longer complete; leaving it
        would only trigger a second, timeout error later.
        """
        try:
            worker_id = message.get('worker_id', 'unknown')
            error_msg = message.get('error', 'Unknown error')
            logger.warning(f"⚠️ Worker {worker_id} error: {error_msg}")

            request_id = message.get('request_id')
            if request_id is None:
                return

            pipeline = self.active_pipelines.pop(request_id, None)
            if pipeline is None:
                # Some strategies may key pipelines by a separate pipeline_id.
                for pipeline_id, candidate in list(self.active_pipelines.items()):
                    if candidate.get('request_id') == request_id:
                        pipeline = self.active_pipelines.pop(pipeline_id, None)
                        break

            if pipeline is None:
                logger.warning(f"⚠️ No active pipeline for request {request_id}; worker error not forwarded")
                return

            self._send_error(pipeline['client_socket'], request_id, f"Worker error: {error_msg}")

        except Exception as e:
            logger.error(f"❌ Worker error handling error: {e}")

    def _handle_pipeline_response(self, client_socket: socket.socket, address: tuple, message: dict):
        """Delegate to the parent handler.

        Legacy callers may pass ``None`` for ``client_socket`` and ``address``;
        they are forwarded unchanged either way.
        """
        try:
            super()._handle_pipeline_response(client_socket, address, message)
        except Exception as e:
            logger.error(f"❌ Pipeline response compatibility error: {e}")

    def _process_message(self, client_socket: socket.socket, address: tuple, message: dict):
        """Dispatch an incoming message through the MessageProcessor, logging any error."""
        try:
            self.message_processor.process_message(client_socket, address, message)
        except Exception as e:
            logger.error(f"❌ Message processing error: {e}")

    # Additional legacy helpers.

    def _send_to_worker(self, worker_id: str, message: dict):
        """Send ``message`` to a registered worker; log an error if it is unknown."""
        if worker_id in self.worker_registry:
            worker_info = self.worker_registry[worker_id]
            self._send_to_client(worker_info['socket'], message)
        else:
            logger.error(f"❌ Worker not found: {worker_id}")

    def _get_available_workers(self):
        """Return ``(worker_id, worker_info)`` pairs whose last heartbeat is under 30 s old."""
        available = []
        current_time = time.time()

        for worker_id, worker_info in self.worker_registry.items():
            if current_time - worker_info['last_heartbeat'] < 30:
                available.append((worker_id, worker_info))

        return available

    def _send_error(self, client_socket: socket.socket, request_id: str, error_msg: str):
        """Send an ``error`` message tagged with ``request_id`` to ``client_socket``."""
        error_response = {
            'type': 'error',
            'request_id': request_id,
            'error': error_msg,
            'timestamp': time.time()
        }
        self._send_to_client(client_socket, error_response)


class MessageProcessor:
    """Route coordinator messages by type, supporting legacy message formats.

    Holds a reference to a coordinator (``scalable_coordinator``) and forwards
    each incoming message dict to the coordinator handler that matches its
    ``'type'`` key. Legacy ``model_shard_response`` and ``worker_response``
    messages are handled in one of two ways:

    * completed here, for hybrid-parallel and single-GPU pipelines; or
    * converted into ``pipeline_response`` messages.
    """

    def __init__(self, scalable_coordinator):
        """Keep the coordinator whose handlers and ``active_pipelines`` are used.

        Args:
            scalable_coordinator: Object providing the ``_handle_*`` handlers,
                ``_send_to_client`` and the ``active_pipelines`` dict.
        """
        self.scalable_coordinator = scalable_coordinator

    def process_message(self, client_socket: socket.socket, address: tuple, message: dict):
        """Dispatch ``message`` to the handler for its ``'type'`` key.

        Unknown types are logged as a warning and otherwise ignored.
        """
        message_type = message.get('type')

        if message_type == 'worker_register':
            self.scalable_coordinator._handle_worker_registration(client_socket, address, message)
        elif message_type == 'inference_request':
            self.scalable_coordinator._handle_inference_request(client_socket, address, message)
        elif message_type == 'pipeline_response':
            self.scalable_coordinator._handle_pipeline_response(client_socket, address, message)
        elif message_type == 'model_shard_response':
            self._handle_model_shard_response_compat(message)
        elif message_type == 'worker_response':
            self._handle_worker_response_compat(message)
        elif message_type == 'worker_error':
            self.scalable_coordinator._handle_worker_error(client_socket, address, message)
        elif message_type == 'heartbeat':
            self.scalable_coordinator._handle_heartbeat(client_socket, address, message)
        elif message_type == 'get_system_stats':
            self.scalable_coordinator._handle_system_stats_request(client_socket, address, message)
        else:
            logger.warning(f"❓ Unknown message type: {message_type}")

    def _handle_model_shard_response_compat(self, message: dict):
        """Handle a legacy ``model_shard_response`` message, with debug logging.

        Hybrid case: ``request_id`` names an active pipeline whose ``strategy``
        is ``'hybrid_parallel'``.

        * The shard result is stored under its worker id.
        * Once ``received_responses`` reaches ``total_expected``, the results
          are combined and an ``inference_response`` is sent to the pipeline's
          ``client_socket``.
        * The pipeline is then removed from ``active_pipelines``, even if
          sending fails.

        Any other message is converted into a ``pipeline_response`` and handed
        to the coordinator.

        Exceptions are logged and not propagated.
        """
        try:
            request_id = message.get('request_id', 'unknown')
            worker_id = message.get('worker_id', 'unknown')
            shard_id = message.get('shard_id', 'unknown')
            is_final_shard = message.get('is_final_shard', False)

            logger.info(f"🔍 Model shard response: request={request_id}, worker={worker_id}, shard={shard_id}, final={is_final_shard}")

            # 🔥 Hybrid pipelines are aggregated here rather than via pipeline_response.
            if request_id in self.scalable_coordinator.active_pipelines:
                pipeline = self.scalable_coordinator.active_pipelines[request_id]

                logger.info(f"🔍 Pipeline found: strategy={pipeline.get('strategy')}, received={pipeline.get('received_responses', 0)}/{pipeline.get('total_expected', 0)}")

                if pipeline.get('strategy') == 'hybrid_parallel':
                    pipeline['received_responses'] += 1
                    # NOTE: results are keyed by worker id, so a second shard from
                    # the same worker overwrites the first one.
                    pipeline['results'][worker_id] = message.get('result', '')

                    logger.info(f"🔍 Updated pipeline: {pipeline['received_responses']}/{pipeline['total_expected']} responses")

                    if pipeline['received_responses'] >= pipeline['total_expected']:
                        logger.info(f"✅ All responses received for {request_id}")
                        try:
                            final_result = self._combine_hybrid_results(pipeline)

                            response = {
                                'type': 'inference_response',
                                'request_id': request_id,
                                'result': final_result,
                                'worker_id': 'hybrid_system',
                                'strategy': 'hybrid_parallel',
                                'workers_used': pipeline['workers_used'],
                                'session_id': pipeline.get('session_id', 'default'),
                                'timestamp': time.time()
                            }

                            self.scalable_coordinator._send_to_client(pipeline['client_socket'], response)
                            logger.info(f"✅ Hybrid inference completed for {request_id}")
                        finally:
                            # Always drop the finished pipeline. Otherwise a failed send
                            # leaves a stale entry that any late response would
                            # "complete" (and resend) again.
                            self.scalable_coordinator.active_pipelines.pop(request_id, None)
                    else:
                        logger.info(f"⏳ Waiting for more responses: {pipeline['received_responses']}/{pipeline['total_expected']}")
                    return

            # 🔥 Not a hybrid pipeline: forward it as a pipeline_response.
            logger.info(f"🔍 Converting to pipeline response for {request_id}")
            pipeline_response = {
                'type': 'pipeline_response',
                'pipeline_id': request_id,
                'stage_id': f"shard_{shard_id}",
                'worker_id': worker_id,
                'result': message.get('result', ''),
                'is_final_stage': is_final_shard,
                'timestamp': time.time()
            }

            self.scalable_coordinator._handle_pipeline_response(None, None, pipeline_response)

        except Exception as e:
            logger.error(f"❌ Model shard response compatibility error: {e}")

    def _combine_hybrid_results(self, pipeline):
        """Pick the final answer text for a hybrid pipeline.

        Selection order over ``pipeline['results']`` (worker id -> result):

        1. The longest string result with more than 20 stripped characters
           that contains none of the markers ``INPUT_PROCESSED``,
           ``PROCESSED`` or ``SHARD_COMPLETE``.
        2. Otherwise, the last string result with more than 20 stripped
           characters. At this point only marker-bearing strings can qualify.
        3. Otherwise, every truthy result is stringified, entries of 5 or
           fewer stripped characters are dropped, and the rest are joined
           with spaces.
        4. Otherwise, a fixed placeholder message.

        An empty result set and any unexpected exception each return their own
        fixed message.

        Args:
            pipeline: dict with a ``'results'`` mapping.

        Returns:
            The chosen text (str).
        """
        try:
            results = list(pipeline['results'].values())

            if not results:
                return "Elaborazione completata ma nessun risultato ricevuto"

            logger.info(f"🔍 Combining {len(results)} results: {list(pipeline['results'].keys())}")

            # 🔥 Step 1: the longest meaningful (non-system) string result.
            best_result = ""
            max_length = 0

            for worker_id, result in pipeline['results'].items():
                if result and isinstance(result, str):
                    # Worker status markers are not answer text.
                    if any(msg in result for msg in ['INPUT_PROCESSED', 'PROCESSED', 'SHARD_COMPLETE']):
                        logger.info(f"🔍 Skipping system message from {worker_id}")
                        continue

                    result_length = len(result.strip())
                    if result_length > max_length and result_length > 20:
                        best_result = result
                        max_length = result_length
                        logger.info(f"🔍 Better result found from {worker_id}: {result_length} chars")

            if best_result:
                logger.info(f"✅ Using result with {max_length} characters")
                return best_result

            # 🔥 Step 2: the last sufficiently long string result. Non-string
            # results are skipped here because calling .strip() on them would
            # raise and abort to the generic reply. Step 3 stringifies them.
            for result in reversed(results):
                if isinstance(result, str) and len(result.strip()) > 20:
                    logger.info(f"🔄 Using fallback result: {len(result.strip())} chars")
                    return result

            # 🔥 Step 3: concatenate everything that is not trivially short.
            final_fallback = " ".join([str(r) for r in results if r and len(str(r).strip()) > 5])
            if final_fallback:
                logger.info(f"⚠️ Using concatenated result: {len(final_fallback)} chars")
                return final_fallback

            return "Elaborazione completata con successo - Richiedi più dettagli se necessario"

        except Exception as e:
            logger.error(f"❌ Results combination error: {e}")
            return "Risposta generata dal sistema - Si prega di riformulare la domanda per maggiori dettagli"

    def _handle_worker_response_compat(self, message: dict):
        """Handle a legacy ``worker_response`` message.

        Single-GPU case: the active pipeline's ``strategy`` is ``'single'``.
        The result is sent straight back to the pipeline's ``client_socket``
        as an ``inference_response``. The pipeline is then removed, even if
        sending fails.

        Otherwise the message is converted into a final-stage
        ``pipeline_response`` with pipeline id ``single_<request_id>``.

        Exceptions are logged and not propagated.
        """
        try:
            request_id = message.get('request_id', 'unknown')

            # 🔥 Single-GPU requests are answered directly.
            if request_id in self.scalable_coordinator.active_pipelines:
                pipeline = self.scalable_coordinator.active_pipelines[request_id]

                if pipeline.get('strategy') == 'single':
                    response = {
                        'type': 'inference_response',
                        'request_id': request_id,
                        'result': message.get('result', ''),
                        'worker_id': message.get('worker_id', 'unknown'),
                        'strategy': 'single',
                        'timestamp': time.time()
                    }

                    try:
                        self.scalable_coordinator._send_to_client(pipeline['client_socket'], response)
                        logger.info(f"✅ Single GPU response sent for {request_id}")
                    finally:
                        # Remove the pipeline even if the send failed, so no stale entry remains.
                        self.scalable_coordinator.active_pipelines.pop(request_id, None)
                    return

            # 🔥 Not a single-GPU pipeline: forward it as a pipeline_response.
            pipeline_response = {
                'type': 'pipeline_response',
                'pipeline_id': f"single_{request_id}",
                'stage_id': 'single_stage',
                'worker_id': message.get('worker_id', 'unknown'),
                'result': message.get('result', ''),
                'is_final_stage': True,
                'timestamp': time.time()
            }

            self.scalable_coordinator._handle_pipeline_response(None, None, pipeline_response)

        except Exception as e:
            logger.error(f"❌ Worker response compatibility error: {e}")


class IntelligentLoadBalancer:
    """Select GPUs for a request by combining performance score and current load.

    Keeps the registered ``GPUResource`` objects keyed by ``worker_id`` and a
    bounded per-worker history of observed ``current_load`` values.
    """

    # Number of most recent load samples kept per worker by record_load_metrics.
    MAX_LOAD_HISTORY = 100

    def __init__(self):
        self.gpu_resources: Dict[str, GPUResource] = {}
        # worker_id -> recent current_load samples (oldest first).
        self.load_history: Dict[str, List[float]] = {}
        self.performance_metrics = {}

    def add_gpu_resource(self, gpu_resource: GPUResource):
        """Register (or replace) a GPU and reset its load history."""
        self.gpu_resources[gpu_resource.worker_id] = gpu_resource
        self.load_history[gpu_resource.worker_id] = []

    def get_optimal_gpu_allocation(self, total_gpus_needed: int, complexity: str) -> List[GPUResource]:
        """Return up to ``total_gpus_needed`` GPUs, best candidates first.

        Candidates: GPUs whose ``current_load`` is below their dynamic
        threshold (see ``_get_dynamic_threshold``). If none qualify, every
        registered GPU is considered.

        Ordering depends on ``complexity``:

        * ``"very_high"``: highest ``performance_score * (1 - current_load)``
          first; ties go to the lower load.
        * ``"high"``: highest ``0.7 * performance_score + 0.3 * (1 - current_load)``
          first.
        * anything else: lowest ``current_load`` first; ties go to the higher
          performance score.

        Args:
            total_gpus_needed: Maximum number of GPUs to return. Values <= 0
                yield an empty list.
            complexity: Label selecting one of the orderings above.

        Returns:
            The selected ``GPUResource`` objects, possibly fewer than requested.
        """
        # 🔥 Dynamic filter: leave out overloaded GPUs.
        available_gpus = [
            gpu for gpu in self.gpu_resources.values()
            if gpu.current_load < self._get_dynamic_threshold(gpu)
        ]

        if not available_gpus:
            logger.warning("⚠️ No available GPUs, using all with load consideration")
            available_gpus = list(self.gpu_resources.values())

        if complexity == "very_high":
            # High complexity: performance first, discounted by current load.
            sorted_gpus = sorted(available_gpus,
                                 key=lambda x: (-x.performance_score * (1 - x.current_load), x.current_load))
        elif complexity == "high":
            # Weighted performance/idle score where higher is better. The whole
            # sum is negated for the ascending sort. Negating only the
            # performance term would rank busier GPUs ahead of idle ones.
            sorted_gpus = sorted(available_gpus,
                                 key=lambda x: -(x.performance_score * 0.7 + (1 - x.current_load) * 0.3))
        else:
            # Light workloads: least-loaded first.
            sorted_gpus = sorted(available_gpus,
                                 key=lambda x: (x.current_load, -x.performance_score))

        # Clamp so a negative count cannot become a "drop the last N" slice.
        selected = sorted_gpus[:max(total_gpus_needed, 0)]

        logger.info(f"🎯 Load balancing: {len(selected)}/{total_gpus_needed} GPUs selected")
        for gpu in selected:
            logger.info(f"   - {gpu.worker_id}: score={gpu.performance_score:.2f}, load={gpu.current_load:.2f}")

        return selected

    def _get_dynamic_threshold(self, gpu: GPUResource) -> float:
        """Return the load limit below which ``gpu`` is eligible for selection.

        ``0.9`` when ``performance_score > 2.0``; otherwise ``0.6`` for
        ``NodeType.EDGE_NODE``; otherwise ``0.8``.
        """
        base_threshold = 0.8
        # Strong GPUs are allowed to run closer to full load.
        if gpu.performance_score > 2.0:
            return 0.9
        elif gpu.node_type == NodeType.EDGE_NODE:
            return 0.6  # Edge GPUs are treated more conservatively.
        return base_threshold

    def record_load_metrics(self, worker_id: str):
        """Append the worker's current load to its history.

        Keeps only the most recent ``MAX_LOAD_HISTORY`` samples. Unknown
        worker ids are ignored.
        """
        if worker_id not in self.gpu_resources:
            return
        # setdefault tolerates GPUs that were put into gpu_resources directly.
        history = self.load_history.setdefault(worker_id, [])
        history.append(self.gpu_resources[worker_id].current_load)

        if len(history) > self.MAX_LOAD_HISTORY:
            del history[:-self.MAX_LOAD_HISTORY]

