import logging
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

logger = logging.getLogger(__name__)

class LLMManager:
    def __init__(self, model_path: str, gpu_id: int = 0):
        self.model_path = model_path
        self.gpu_id = gpu_id
        self.model = None
        self.tokenizer = None
        self.device = None
        self._initialize_model()
        
    def _initialize_model(self):
        """Inizializza il modello LLM"""
        try:
            logger.info(f"Loading model from {self.model_path} on GPU {self.gpu_id}")
            
            # Determine the compute device
            if torch.cuda.is_available() and self.gpu_id < torch.cuda.device_count():
                self.device = torch.device(f"cuda:{self.gpu_id}")
                logger.info(f"Using GPU: {torch.cuda.get_device_name(self.gpu_id)}")
            else:
                self.device = torch.device("cpu")
                logger.info("Using CPU")
            
            # Check whether the path is a local directory or a Hugging Face Hub repository
            if os.path.exists(self.model_path):
                logger.info(f"Loading from local path: {self.model_path}")
                # Load from a local path
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, local_files_only=True)
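                # fp16 halves memory on GPU (fp32 on CPU); device_map="auto" lets Accelerate
                # place the weights, and low_cpu_mem_usage reduces peak RAM while loading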
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    local_files_only=True,
                    torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                    device_map="auto" if self.device.type == 'cuda' else None,
                    low_cpu_mem_usage=True
                )
            else:
                logger.info(f"Loading from HuggingFace Hub: {self.model_path}")
                # Load from the Hugging Face Hub
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_path,
                    torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                    device_map="auto" if self.device.type == 'cuda' else None,
                    low_cpu_mem_usage=True
                )
            
            # Some models ship without a pad token; fall back to the EOS token
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            # device_map="auto" already places the model on GPU; move it manually only on CPU
            if self.device.type == 'cpu':
                self.model = self.model.to(self.device)
                
            logger.info("Model loaded successfully")
            
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            # Fallback: load a very small model for testing
            self._load_fallback_model()
    
    def _load_fallback_model(self):
        """Carica un modello di fallback molto piccolo per test"""
        try:
            logger.info("Attempting to load fallback model for testing...")
            
            # Use a very small, fast model
            fallback_model = "microsoft/DialoGPT-small"
            
            self.tokenizer = AutoTokenizer.from_pretrained(fallback_model)
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
            self.model = AutoModelForCausalLM.from_pretrained(
                fallback_model,
                torch_dtype=torch.float16 if self.device.type == 'cuda' else torch.float32,
                device_map="auto" if self.device.type == 'cuda' else None,
            )
            
            if self.device.type == 'cpu':
                self.model = self.model.to(self.device)
                
            logger.info("Fallback model loaded successfully")
            
        except Exception as e:
            logger.error(f"Error loading fallback model: {e}")
            raise
    
    def get_model_info(self) -> dict:
        """Restituisce informazioni sul modello"""
        if self.model is None:
            return {"error": "Model not loaded"}
            
        return {
            "model_path": self.model_path,
            "model_type": type(self.model).__name__,
            "device": str(self.device),
            "parameters": sum(p.numel() for p in self.model.parameters()),
            "loaded": True
        }
    
    def generate(self, prompt: str, max_length: int = 1000) -> str:
        """Genera testo dal prompt"""
        if self.model is None or self.tokenizer is None:
            return "Error: Model not loaded"
            
        try:
            # Tokenize the prompt; the tokenizer also builds the matching attention mask
            encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate text, sampling with a moderate temperature
            with torch.no_grad():
                outputs = self.model.generate(
                    **encoded,
                    max_new_tokens=max_length,
                    num_return_sequences=1,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )
            
            # Decode only the newly generated tokens, excluding the prompt
            prompt_length = encoded["input_ids"].shape[1]
            result = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()
                
            logger.info(f"Generated {len(result)} characters")
            return result
            
        except Exception as e:
            logger.error(f"Error during generation: {e}")
            return f"Error during text generation: {str(e)}"
    
    def cleanup(self):
        """Release the model and free any cached GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
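

# Minimal usage sketch. The model id "gpt2" below is purely illustrative and not
# part of this module's configuration; any local path or Hub repo id works.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = LLMManager("gpt2", gpu_id=0)
    print(manager.get_model_info())
    print(manager.generate("Hello, world!", max_length=50))
    manager.cleanup()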