import torch
import transformers
import logging
import json
import socket
import time
from typing import Dict, Any

logger = logging.getLogger(__name__)
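
# Wire protocol (inferred from the handlers below): the worker and coordinator
# exchange newline-delimited JSON messages over a plain TCP socket. A
# registration message, for example, looks like:
#
#   {"type": "worker_register", "worker_id": "worker_0_0_1700000000",
#    "gpu_info": {...}, "model_info": {...}, "performance_score": 1.0,
#    "is_sharded": false, "shard_id": 0, "total_shards": 1, "timestamp": ...}
#
# The coordinator is expected to answer with {"type": "registration_confirmed"}
# and later push {"type": "inference_task", "request_id": ..., "prompt_chunk": ...}
# messages, to which the worker replies with "worker_response" or "error".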

class ShardedGPUWorker:
    def __init__(self, coordinator_host: str, coordinator_port: int, model_path: str, 
                 gpu_id: int = 0, shard_id: int = 0, total_shards: int = 1):
        self.coordinator_host = coordinator_host
        self.coordinator_port = coordinator_port
        self.model_path = model_path
        self.gpu_id = gpu_id
        self.shard_id = shard_id
        self.total_shards = total_shards
        self.is_sharded = total_shards > 1
        
        self.model = None
        self.tokenizer = None
        self.device = None
        self.socket = None
        self.connected = False
        self.worker_id = f"worker_{gpu_id}_{shard_id}_{int(time.time())}"
        
    def connect(self):
        """Connette al coordinator e carica il modello shardato"""
        try:
            # Setup device
            self.device = torch.device(f"cuda:{self.gpu_id}" if torch.cuda.is_available() else "cpu")
            logger.info(f"🚀 Using device: {self.device}")
            
            # Load the sharded model
            self._load_sharded_model()
            
            # Connect to the coordinator
            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.socket.connect((self.coordinator_host, self.coordinator_port))
            self.connected = True
            
            # Register as a sharded worker
            self._register_with_coordinator()
            
            # Start the work loop
            self._work_loop()
            
        except Exception as e:
            logger.error(f"❌ Worker connection failed: {e}")
            raise
    
    def _load_sharded_model(self):
        """Carica una porzione shardata del modello"""
        try:
            logger.info(f"🔧 Loading sharded model {self.shard_id}/{self.total_shards} from {self.model_path}")
            
            # Load the tokenizer
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            
            # Simple "sharding": each worker loads the full model onto its own
            # GPU (replication rather than true tensor/layer sharding)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                self.model_path,
                torch_dtype=torch.float16,
                device_map=f"cuda:{self.gpu_id}",
                trust_remote_code=True
            )
            
            logger.info(f"✅ Model shard {self.shard_id} loaded successfully")
            
        except Exception as e:
            logger.error(f"❌ Model loading failed: {e}")
            raise
    
    def _register_with_coordinator(self):
        """Registra il worker con il coordinator"""
        gpu_info = self._get_gpu_info()
        model_info = self._get_model_info()
        
        registration_data = {
            'type': 'worker_register',
            'worker_id': self.worker_id,
            'gpu_info': gpu_info,
            'model_info': model_info,
            'performance_score': 1.0,
            'is_sharded': self.is_sharded,
            'shard_id': self.shard_id,
            'total_shards': self.total_shards,
            'timestamp': time.time()
        }
        
        self._send_message(registration_data)
        logger.info(f"✅ Worker {self.worker_id} registered with coordinator")
    
    def _get_gpu_info(self) -> Dict[str, Any]:
        """Ottiene informazioni GPU"""
        if torch.cuda.is_available():
            gpu_props = torch.cuda.get_device_properties(self.gpu_id)
            return {
                'gpu_name': gpu_props.name,
                'total_memory': gpu_props.total_memory,
                'multi_processor_count': gpu_props.multi_processor_count,
                'major': gpu_props.major,
                'minor': gpu_props.minor
            }
        else:
            return {
                'gpu_name': 'CPU',
                'total_memory': 0,
                'multi_processor_count': 0
            }
    
    def _get_model_info(self) -> Dict[str, Any]:
        """Ottiene informazioni modello"""
        if self.model is None:
            return {}
            
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        
        return {
            'model_name': self.model_path.split('/')[-1],
            'model_path': self.model_path,
            'total_parameters': total_params,
            'trainable_parameters': trainable_params,
            'is_sharded': self.is_sharded,
            'shard_id': self.shard_id,
            'total_shards': self.total_shards
        }
    
    def _work_loop(self):
        """Loop principale del worker"""
        buffer = ""
        
        while self.connected:
            try:
                self.socket.settimeout(1.0)
                data = self.socket.recv(16384).decode('utf-8')
                
                if not data:
                    logger.warning("📡 Coordinator disconnected")
                    break
                    
                buffer += data
                
                while '\n' in buffer:
                    message_str, buffer = buffer.split('\n', 1)
                    if message_str.strip():
                        try:
                            message = json.loads(message_str)
                            self._process_message(message)
                        except json.JSONDecodeError as e:
                            logger.warning(f"❌ Invalid JSON received: {e}")
                            
            except socket.timeout:
                # Invia heartbeat periodico
                if time.time() % 10 < 0.1:  # Ogni ~10 secondi
                    self._send_heartbeat()
                continue
            except Exception as e:
                logger.error(f"❌ Work loop error: {e}")
                break
    
    def _process_message(self, message: Dict[str, Any]):
        """Processa un messaggio dal coordinator"""
        msg_type = message.get('type')
        
        if msg_type == 'inference_task':
            self._handle_inference_task(message)
        elif msg_type == 'registration_confirmed':
            logger.info("✅ Registration confirmed by coordinator")
        else:
            logger.warning(f"❌ Unknown message type: {msg_type}")
    
    def _handle_inference_task(self, message: Dict[str, Any]):
        """Gestisce task di inferenza"""
        try:
            request_id = message['request_id']
            prompt_chunk = message['prompt_chunk']
            is_sharded = message.get('is_sharded', False)
            shard_id = message.get('shard_id', 0)
            
            logger.info(f"🧠 Processing inference {request_id} (shard {shard_id})")
            
            # Run inference
            result = self._run_inference(prompt_chunk)
            
            # Send the response
            response = {
                'type': 'worker_response',
                'request_id': request_id,
                'worker_id': self.worker_id,
                'result': result,
                'is_sharded': is_sharded,
                'shard_id': shard_id,
                'timestamp': time.time()
            }
            
            self._send_message(response)
            logger.info(f"✅ Inference completed for {request_id}")
            
        except Exception as e:
            logger.error(f"❌ Inference failed: {e}")
            
            # Report the error to the coordinator
            error_response = {
                'type': 'error',
                'request_id': message.get('request_id', 'unknown'),
                'worker_id': self.worker_id,
                'error': str(e),
                'timestamp': time.time()
            }
            self._send_message(error_response)
    
    def _run_inference(self, prompt: str, max_new_tokens: int = 500) -> str:
        """Run inference on the prompt and return the generated text"""
        try:
            if not prompt.strip():
                return "Please provide a prompt."
            
            # Tokenize the input
            inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            
            # Generate a response, capping total length at a 2048-token context window
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_length=min(inputs.shape[1] + max_new_tokens, 2048),
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    attention_mask=inputs.new_ones(inputs.shape)
                )
            
            # Decode the generated tokens
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Strip the echoed prompt from the response
            if response.startswith(prompt):
                response = response[len(prompt):].strip()
            
            return response
            
        except Exception as e:
            logger.error(f"❌ Inference error: {e}")
            return f"Error during inference: {str(e)}"
    
    def _send_heartbeat(self):
        """Invia heartbeat al coordinator"""
        heartbeat = {
            'type': 'heartbeat',
            'worker_id': self.worker_id,
            'timestamp': time.time()
        }
        self._send_message(heartbeat)
    
    def _send_message(self, message_dict: Dict[str, Any]):
        """Invia messaggio al coordinator"""
        try:
            message_str = json.dumps(message_dict) + '\n'
            self.socket.send(message_str.encode('utf-8'))
        except Exception as e:
            logger.error(f"❌ Send error: {e}")
            self.connected = False
    
    def disconnect(self):
        """Disconnette il worker"""
        self.connected = False
        if self.socket:
            self.socket.close()
        logger.info("🔌 Worker disconnected")