import os
import sys
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import snapshot_download
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def download_mistral_7b():
    """Scarica Mistral 7B Instruct v0.2"""
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    local_path = "./models/mistral-7b-instruct"
    
    try:
        os.makedirs(local_path, exist_ok=True)
        
        logger.info(f"🎯 Scaricando {model_name}...")
        logger.info("📦 Questo modello è di ~15GB, potrebbe richiedere tempo...")
        
        # Scarica il modello
        snapshot_download(
            repo_id=model_name,
            local_dir=local_path,
            local_dir_use_symlinks=False,
            resume_download=True
        )
        
        logger.info("✅ Download completato!")
        logger.info(f"📁 Modello salvato in: {local_path}")
        
        return local_path
        
    except Exception as e:
        logger.error(f"❌ Errore nel download: {e}")
        return None

def test_model_loading(model_path):
    """Testa il caricamento del modello"""
    try:
        logger.info("🧪 Test caricamento modello...")
        
        # Carica solo il tokenizer per test (più veloce)
        tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
        
        logger.info("✅ Tokenizer caricato correttamente!")
        logger.info(f"🔤 Vocabolario: {len(tokenizer)} tokens")
        
        return True
    except Exception as e:
        logger.error(f"❌ Errore nel test: {e}")
        return False

if __name__ == "__main__":
    print("=== DOWNLOAD MISTRAL 7B ===")
    print("⚠️  Assicurati di avere almeno 30GB di spazio libero")
    print("⚠️  La connessione internet deve essere stabile")
    print("⏳ Il download potrebbe richiedere diverse ore...")
    
    input("Premi INVIO per iniziare il download...")
    
    model_path = download_mistral_7b()
    
    if model_path:
        print(f"\n🎉 Download completato!")
        print(f"📂 Percorso: {model_path}")
        print(f"\n🚀 Per usare il modello:")
        print(f'python main.py --mode worker --coordinator-host localhost --coordinator-port 8765 --model-path "{model_path}" --gpu-id 0')
        
        # Test rapido
        test_model_loading(model_path)
    else:
        print("\n❌ Download fallito")