import socket
import json
import time
import threading
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class HeavyWorker:
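    """Worker that loads a causal LM locally and serves inference requests
    received from a coordinator over a plain TCP socket."""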
    def __init__(self, coordinator_host='localhost', coordinator_port=8765, model_path=None, gpu_id=0):
        self.coordinator_host = coordinator_host
        self.coordinator_port = coordinator_port
        self.model_path = model_path
        self.gpu_id = gpu_id
        self.worker_id = f"heavy_worker_{gpu_id}_{int(time.time())}"
        self.socket = None
        self.connected = False
        self.model = None
        self.tokenizer = None
        self.device = None
        
    def connect(self):
        try:
            # Load the model first
            self._load_model()
            
            # Then connect to the coordinator
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.connect((self.coordinator_host, self.coordinator_port))
            self.connected = True
            
            logger.info("✅ Modello caricato e connesso al coordinator!")
            
            # Register with the coordinator
            self._register_worker()
            
            # Start background threads
            threading.Thread(target=self._heartbeat_loop, daemon=True).start()
            threading.Thread(target=self._receive_loop, daemon=True).start()
            
            # Keep the main thread alive
            while self.connected:
                time.sleep(1)
                
        except Exception as e:
            logger.error(f"❌ Errore: {e}")
    
    def _load_model(self):
        """Carica il modello LLM pesante"""
        try:
            logger.info(f"🧠 Caricando modello da {self.model_path}...")
            
            if torch.cuda.is_available():
                self.device = torch.device(f"cuda:{self.gpu_id}")
                logger.info(f"🎮 Usando GPU: {torch.cuda.get_device_name(self.gpu_id)}")
            else:
                self.device = torch.device("cpu")
                logger.info("💻 Usando CPU")
            
            # Carica tokenizer e modello
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_path, 
                local_files_only=True,
                trust_remote_code=True
            )
            
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            self.model = AutoModelForCausalLM.from_pretrained(
                self.model_path,
                local_files_only=True,
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                # Pin the whole model to the requested GPU so --gpu-id is honoured
                device_map={"": self.gpu_id} if self.device.type == 'cuda' else None,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            self.model.eval()
            
            logger.info("✅ Modello caricato correttamente!")
            
        except Exception as e:
            logger.error(f"❌ Errore caricamento modello: {e}")
            raise
    
    def _register_worker(self):
        """Registra il worker"""
        gpu_memory = 0
        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(self.gpu_id).total_memory
            
        register_msg = {
            'type': 'worker_register',
            'worker_id': self.worker_id,
            'gpu_info': {
                'gpu_name': torch.cuda.get_device_name(self.gpu_id) if torch.cuda.is_available() else 'CPU',
                'total_memory': gpu_memory
            },
            'model_info': {
                'model_path': self.model_path,
                'parameters': sum(p.numel() for p in self.model.parameters()),
                'loaded': True
            },
            'performance_score': 0.9
        }
        self._send_message(register_msg)
    
    def _send_message(self, message):
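        """Serialize a message to JSON and push it over the socket.

        Note: messages are sent unframed (no newline or length prefix); the
        coordinator is assumed to receive one JSON object per recv().
        """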
        try:
            self.socket.sendall(json.dumps(message).encode('utf-8'))
        except Exception as e:
            logger.error(f"Send error: {e}")
            self.connected = False
    
    def _heartbeat_loop(self):
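        """Send a heartbeat to the coordinator every 10 seconds."""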
        while self.connected:
            try:
                self._send_message({
                    'type': 'heartbeat', 
                    'worker_id': self.worker_id,
                    'timestamp': time.time()
                })
                time.sleep(10)
            except Exception:
                break
    
    def _receive_loop(self):
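        """Read coordinator messages and dispatch inference tasks.

        Assumes each recv() delivers exactly one complete JSON message.
        """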
        while self.connected:
            try:
                data = self.socket.recv(4096).decode('utf-8')
                if not data:
                    break
                    
                message = json.loads(data)
                if message.get('type') == 'inference_task':
                    self._handle_inference(message)
                    
            except Exception as e:
                logger.error(f"Receive error: {e}")
                break
    
    def _handle_inference(self, message):
        """Run inference with the loaded model and send the result back."""
        try:
            prompt = message['prompt_chunk']
            request_id = message['request_id']
            
            logger.info(f"🎯 Processing request {request_id}: {prompt[:50]}...")
            
            # Keep the prompt format simple (works better with small models like Phi-2)
            full_prompt = f"{prompt}\n\nAnswer:"
            
            # Tokenize; the tokenizer also returns the attention mask
            encoded = self.tokenizer(full_prompt, return_tensors="pt").to(self.device)
            input_ids = encoded["input_ids"]
            attention_mask = encoded["attention_mask"]
            
            logger.info(f"📝 Input tokens: {input_ids.shape[1]}")
            
            # Generate a completion
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=150,
                    temperature=0.8,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.1
                )
            
            # Decode the full output for logging, then keep only the newly generated tokens
            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            logger.info(f"📄 Full response: {repr(full_response)}")
            
            response = self.tokenizer.decode(
                outputs[0][input_ids.shape[1]:], skip_special_tokens=True
            ).strip()
            
            logger.info(f"🎯 Clean response: {repr(response)}")
            logger.info(f"📤 Generated response: {len(response)} characters")
            
            if len(response) < 5:
                response = "🤖 Scusa, la risposta era troppo corta. Prova a fare una domanda più specifica!"
            
            # Send the response back to the coordinator
            self._send_message({
                'type': 'worker_response',
                'request_id': request_id,
                'worker_id': self.worker_id,
                'result': response
            })
            
        except Exception as e:
            logger.error(f"❌ Inference error: {e}")
            self._send_message({
                'type': 'worker_error',
                'request_id': message.get('request_id'),
                'error': str(e)
            })

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='Heavy worker for large LLM models')
    parser.add_argument('--model-path', required=True, help='Path to the downloaded model')
    parser.add_argument('--gpu-id', type=int, default=0, help='GPU ID to use')
    parser.add_argument('--coordinator-host', default='localhost', help='Coordinator host')
    parser.add_argument('--coordinator-port', type=int, default=8765, help='Coordinator port')
    
    args = parser.parse_args()
    
    print("=== HEAVY WORKER AVVIATO ===")
    print(f"Modello: {args.model_path}")
    print(f"GPU: {args.gpu_id}")
    print(f"Coordinator: {args.coordinator_host}:{args.coordinator_port}")
    print("=" * 30)
    
    worker = HeavyWorker(
        coordinator_host=args.coordinator_host,
        coordinator_port=args.coordinator_port,
        model_path=args.model_path,
        gpu_id=args.gpu_id
    )
    worker.connect()
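
# Example invocation (script name and model path are illustrative):
#   python heavy_worker.py --model-path ./models/phi-2 --gpu-id 0 \
#       --coordinator-host localhost --coordinator-port 8765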