import socket
import json
import time
import threading
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DistributedWorker:
    def __init__(self, coordinator_host='localhost', coordinator_port=8765, model_path=None, gpu_id=0):
        self.coordinator_host = coordinator_host
        self.coordinator_port = coordinator_port
        self.model_path = model_path
        self.gpu_id = gpu_id
        self.worker_id = f"dist_worker_{gpu_id}_{int(time.time())}"
        self.socket = None
        self.connected = False
        self.model = None
        self.tokenizer = None
        self.device = None
        
        # Statistics
        self.requests_processed = 0
        self.total_tokens_generated = 0
        self.start_time = time.time()
        
    def connect(self):
        try:
            # Load the model first
            self._load_model()
            
            # Then connect to the coordinator
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.connect((self.coordinator_host, self.coordinator_port))
            self.connected = True
            
            logger.info("🚀 Worker distribuito connesso!")
            
            # Register with the coordinator
            self._register_worker()
            
            # Start heartbeat and receive threads
            threading.Thread(target=self._heartbeat_loop, daemon=True).start()
            threading.Thread(target=self._receive_loop, daemon=True).start()
            
            # Keep the main thread alive
            while self.connected:
                time.sleep(1)
                
        except Exception as e:
            logger.error(f"❌ Errore: {e}")
    
    def _load_model(self):
        """Carica il modello LLM"""
        try:
            logger.info(f"🧠 Caricando modello da {self.model_path}...")
            
            if torch.cuda.is_available():
                self.device = torch.device(f"cuda:{self.gpu_id}")
                logger.info(f"🎮 GPU {self.gpu_id}: {torch.cuda.get_device_name(self.gpu_id)}")
            else:
                self.device = torch.device("cpu")
                logger.info("💻 Usando CPU")
            
            # Load tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path, 
                local_files_only=True,
                trust_remote_code=True
            )
            
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                local_files_only=True,
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                device_map="auto" if self.device.type == 'cuda' else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            
            logger.info("✅ Modello caricato correttamente!")
            
        except Exception as e:
            logger.error(f"❌ Errore caricamento modello: {e}")
            raise
    
    def _register_worker(self):
        """Registra il worker"""
        gpu_memory = 0
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(self.gpu_id).total_memory
            
        register_msg = {
            'type': 'worker_register',
            'worker_id': self.worker_id,
            'gpu_info': {
                'gpu_name': torch.cuda.get_device_name(self.gpu_id) if torch.cuda.is_available() else 'CPU',
                'total_memory': gpu_memory
            },
            'model_info': {
                'model_path': self.model_path,
                'parameters': sum(p.numel() for p in self.model.parameters()),
                'loaded': True
            },
            'performance_score': 0.9
        }
        self._send_single_message(register_msg)
    
    def _send_single_message(self, message):
        """Invia UN SOLO messaggio JSON con newline"""
        try:
            message_json = json.dumps(message) + '\n'
            self.socket.sendall(message_json.encode('utf-8'))  # sendall avoids partial writes
        except Exception as e:
            logger.error(f"❌ Errore invio: {e}")
            self.connected = False
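
    # Wire-format sketch (an assumption about the coordinator protocol, inferred from
    # this file rather than from a spec): every message in either direction is a single
    # JSON object terminated by '\n', e.g.
    #   {"type": "heartbeat", "worker_id": "dist_worker_0_1700000000", "timestamp": 1700000000.0}
    # _send_single_message appends the newline; _receive_loop below splits the TCP byte
    # stream on '\n' to recover message boundaries.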
    
    def _heartbeat_loop(self):
        """Invia heartbeat ogni 15 secondi"""
        while self.connected:
            try:
                heartbeat_msg = {
                    'type': 'heartbeat', 
                    'worker_id': self.worker_id,
                    'timestamp': time.time()
                }
                self._send_single_message(heartbeat_msg)
                time.sleep(15)
            except Exception:
                break
    
    def _receive_loop(self):
        """Loop di ricezione - VERSIONE DEBUG"""
        buffer = ""
        while self.connected:
            try:
                data = self.socket.recv(4096).decode('utf-8')
                if not data:
                    break
                    
                print(f"🔍 WORKER Ricevuti {len(data)} bytes: {data[:100]}...")
                buffer += data
                
                # Process complete messages (newline-delimited)
                while '\n' in buffer:
                    line, buffer = buffer.split('\n', 1)
                    if line.strip():
                        print(f"🔍 WORKER Processing: {line[:100]}...")
                        try:
                            message = json.loads(line)
                            self._handle_received_message(message)
                        except json.JSONDecodeError:
                            logger.warning("❌ WORKER: invalid JSON message, skipping")
                            continue
                        except Exception as e:
                            logger.error(f"❌ WORKER: message handling error: {e}")
                            
            except Exception as e:
                logger.error(f"❌ WORKER: receive error: {e}")
                break
    
    def _handle_received_message(self, message):
        """Gestisce i messaggi ricevuti"""
        msg_type = message.get('type')
        
        if msg_type == 'inference_task':
            threading.Thread(target=self._handle_inference_task, args=(message,), daemon=True).start()
        elif msg_type == 'registration_confirmed':
            logger.info("✅ Registrazione confermata dal coordinator!")
        else:
            logger.info(f"📨 Messaggio ricevuto: {msg_type}")
    
    def _handle_inference_task(self, message):
        """Gestisce task di inferenza - VERSIONE CORRETTA"""
        try:
            prompt = message['prompt_chunk']
            request_id = message['request_id']
            
            logger.info(f"🎯 Processing {request_id}: {prompt[:50]}...")
            
            # Prompt template for Phi-2
            full_prompt = f"{prompt}\n\nRisposta:"
            
            # Tokenize (the tokenizer call also returns the matching attention mask)
            encoded = self.tokenizer(full_prompt, return_tensors="pt").to(self.device)
            
            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    encoded.input_ids,
                    attention_mask=encoded.attention_mask,
                    max_new_tokens=150,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )
            
            # Decode
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract only the response (the text after the prompt)
            if full_response.startswith(full_prompt):
                response = full_response[len(full_prompt):].strip()
            else:
                response = full_response.strip()
            
            logger.info(f"✅ Generated {len(response)} chars for {request_id}")
            
            # Send the response back to the coordinator
            response_message = {
                'type': 'worker_response',
                'request_id': request_id,
                'worker_id': self.worker_id,
                'result': response
            }
            self._send_single_message(response_message)
            
            # Update statistics
            self.requests_processed += 1
            self.total_tokens_generated += len(response.split())
            
        except Exception as e:
            logger.error(f"❌ Inference error: {e}")
            error_message = {
                'type': 'worker_error', 
                'request_id': message.get('request_id', 'unknown'),
                'error': str(e)
            }
            self._send_single_message(error_message)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Distributed LLM worker')
    parser.add_argument('--model-path', required=True, help='Path to the model')
    parser.add_argument('--gpu-id', type=int, default=0, help='GPU ID')
    parser.add_argument('--coordinator-host', default='localhost')
    parser.add_argument('--coordinator-port', type=int, default=8765)
    
    args = parser.parse_args()
    
    print("🚀 DISTRIBUTED WORKER AVVIATO")
    print(f"📁 Modello: {args.model_path}")
    print(f"🎮 GPU: {args.gpu_id}")
    print(f"🔗 Coordinator: {args.coordinator_host}:{args.coordinator_port}")
    print("=" * 50)
    
    worker = DistributedWorker(
        coordinator_host=args.coordinator_host,
        coordinator_port=args.coordinator_port,
        model_path=args.model_path,
        gpu_id=args.gpu_id
    )
    worker.connect()
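
# Example launch (the script name, model path, and host below are placeholders; adjust
# them to your deployment):
#   python distributed_worker.py --model-path ./models/phi-2 --gpu-id 0 \
#       --coordinator-host 192.168.1.10 --coordinator-port 8765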