# Inference script for the Mini-Mistral model.

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
import math
import os

# ====================================================================
# A. CONFIGURATION MINI-MISTRAL
# (Nécessaire pour que Hugging Face sache comment reconstruire le modèle)
# ====================================================================

class MiniMistralConfig(PretrainedConfig):
    """Hyper-parameters describing a Mini-Mistral decoder-only model.

    Hugging Face looks this class up through ``model_type`` so that a saved
    ``config.json`` can be turned back into the model architecture before the
    weights are loaded.

    Args:
        vocab_size: Number of tokens in the vocabulary.
        hidden_size: Width of the residual stream.
        num_hidden_layers: Number of decoder layers.
        num_attention_heads: Number of query heads per attention layer.
        num_key_value_heads: Number of key/value heads (grouped-query attention).
        intermediate_size: Inner width of the feed-forward block.
        max_position_embeddings: Number of positions covered by the RoPE tables.
        rope_theta: Base of the RoPE frequency progression.
        sliding_window: Stored on the config; no attention code in this module
            reads it.
        rms_norm_eps: Epsilon added inside RMSNorm.
        initializer_range: Standard deviation of the normal weight init.
        **kwargs: Forwarded unchanged to ``PretrainedConfig`` (e.g. token ids).
    """

    model_type = "mini_mistral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=512,
        num_hidden_layers=8,
        num_attention_heads=8,
        num_key_value_heads=2,
        intermediate_size=2048,
        max_position_embeddings=4096,
        rope_theta=10000.0,
        sliding_window=4096,
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        **kwargs,
    ):
        # Let the base class consume the generic kwargs first, then record
        # the architecture-specific fields.
        super().__init__(**kwargs)

        # Model dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers

        # Attention layout.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        # Feed-forward width.
        self.intermediate_size = intermediate_size

        # Positional encoding.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        # Numerics and initialisation.
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range

# ====================================================================
# B. COMPOSANTS DE L'ARCHITECTURE MINI-MISTRAL
# ====================================================================

# RMS Norm
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation (no mean centring, no bias).

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``. The statistics
    and the division are done in float32 for numerical stability. The
    normalised activations are cast back to the input's dtype before the
    learned scale is applied.

    Args:
        hidden_size: Size of the last input dimension (length of ``weight``).
        eps: Constant added to the mean square to avoid division by zero.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states):
        # Capture the caller's dtype *before* upcasting; reading it after the
        # float32 multiply would make the final cast a no-op and leak float32
        # into fp16/bf16 models.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)

# RoPE Helpers
def rotate_half(x):
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last dimension is split at ``size // 2``; for an odd size the second
    part is the longer one.
    """
    size = x.shape[-1]
    half = size // 2
    front = x.narrow(-1, 0, half)
    back = x.narrow(-1, half, size - half)
    return torch.cat([back.neg(), front], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to queries and keys.

    Each input ``t`` becomes ``t * cos + rotate_half(t) * sin``; ``cos`` and
    ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        Tuple ``(rotated_q, rotated_k)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

def get_rope_freqs(head_dim, max_seq_len, rope_theta=10000.0):
    """Build the RoPE cosine and sine lookup tables.

    Args:
        head_dim: Per-head feature size; one frequency per even index.
        max_seq_len: Number of positions (table rows).
        rope_theta: Base of the geometric frequency progression.

    Returns:
        ``(cos, sin)`` float32 tensors with one row per position; for an even
        ``head_dim`` each has shape ``(max_seq_len, head_dim)``.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (rope_theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)

    # angle[p, i] = p * inv_freq[i], duplicated so both halves share frequencies.
    angles = torch.outer(positions, inv_freq).to(torch.float32)
    angles = torch.cat([angles, angles], dim=-1)

    return angles.cos(), angles.sin()

# Multi-Head/Grouped-Query Attention (MHA/GQA)
class MiniMistralAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings (RoPE).

    Queries use ``num_attention_heads`` heads. Keys and values use
    ``num_key_value_heads`` heads, and each K/V head is repeated
    ``num_attention_heads // num_key_value_heads`` times so that every query
    head has a partner. No key/value cache is kept: each call attends only
    over the ``seq_len`` positions it is given.

    Args:
        config: Object exposing ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``max_position_embeddings`` and
            (optionally, default 10000.0) ``rope_theta``.

    Raises:
        ValueError: If ``num_attention_heads`` is not a multiple of
            ``num_key_value_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads

        # GQA needs an integer number of query heads per K/V head; otherwise
        # the repeated K/V tensors cannot line up with Q in forward().
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be a multiple of "
                f"num_key_value_heads ({self.num_kv_heads})"
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        # Lazily built float32 (cos, sin) RoPE tables, stored on the device of
        # the most recent input and rebuilt if the module moves devices.
        self._cos_sin = None

    def _prepare_for_gqa(self, hidden_states, seq_len):
        """Project to Q/K/V and expand K/V heads to the number of query heads.

        Returns:
            ``(query, key, value)``, each shaped
            ``(batch, num_heads, seq_len, head_dim)``.
        """
        bsz = hidden_states.shape[0]
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Explicit head counts (not -1) so a config mismatch fails loudly.
        query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        num_kv_groups = self.num_heads // self.num_kv_heads
        if num_kv_groups > 1:
            key_states = torch.repeat_interleave(key_states, num_kv_groups, dim=1)
            value_states = torch.repeat_interleave(value_states, num_kv_groups, dim=1)

        return query_states, key_states, value_states

    def _rope_tables(self, device):
        """Return the cached float32 ``(cos, sin)`` tables, located on ``device``."""
        if self._cos_sin is None or self._cos_sin[0].device != device:
            cos, sin = get_rope_freqs(
                self.head_dim,
                self.config.max_position_embeddings,
                rope_theta=getattr(self.config, "rope_theta", 10000.0),
            )
            self._cos_sin = (cos.to(device), sin.to(device))
        return self._cos_sin

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """Run causal self-attention over ``hidden_states``.

        Args:
            hidden_states: ``(batch, seq_len, hidden_size)`` activations.
            attention_mask: Optional additive mask that broadcasts to the
                ``(batch, num_heads, seq_len, seq_len)`` attention scores.
            position_ids: Optional integer positions, ``(batch, seq_len)`` or
                ``(1, seq_len)``, used to index the RoPE tables. Defaults to
                ``0 .. seq_len - 1``.

        Returns:
            ``(batch, seq_len, hidden_size)`` attention output.
        """
        bsz, seq_len, _ = hidden_states.shape
        query_states, key_states, value_states = self._prepare_for_gqa(hidden_states, seq_len)

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0)

        cos, sin = self._rope_tables(hidden_states.device)

        # Gather rows per position -> (batch, seq_len, head_dim), add a head axis
        # for broadcasting -> (batch, 1, seq_len, head_dim), and match the
        # activations' dtype so fp16/bf16 Q/K are not promoted to float32
        # (which would then mismatch V in the value matmul).
        cos_emb = cos[position_ids].unsqueeze(1).to(query_states.dtype)
        sin_emb = sin[position_ids].unsqueeze(1).to(query_states.dtype)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos_emb, sin_emb)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        attn_output = torch.matmul(attn_weights, value_states)
        # Merge heads using the projected width, which is what o_proj expects
        # (it can differ from hidden_size when hidden_size % num_heads != 0).
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(bsz, seq_len, self.num_heads * self.head_dim)
        )

        return self.o_proj(attn_output)

# Gated MLP (SwiGLU)
class MiniMistralMLP(nn.Module):
    """SwiGLU feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    Expects ``config`` to expose ``hidden_size`` and ``intermediate_size``.
    All three projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.hidden_size
        d_inner = config.intermediate_size
        # Creation order (gate, up, down) fixes parameter registration order,
        # and therefore state_dict key order; keep it stable.
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)

# Decoder Layer
class MiniMistralDecoderLayer(nn.Module):
    """One pre-norm transformer block: self-attention followed by an MLP.

    Each sub-layer normalises its input with RMSNorm and adds its output back
    onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Registration order (attention, MLP, then the two norms) determines
        # state_dict key order and construction-time RNG draws; keep it.
        self.self_attn = MiniMistralAttention(config=config)
        self.mlp = MiniMistralMLP(config)

        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        # Attention sub-layer (pre-norm + residual).
        attn_out = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-layer (pre-norm + residual).
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out

# Modèle Complet
class MiniMistralModel(PreTrainedModel, GenerationMixin):
    """Mini-Mistral decoder-only language model with a tied-free LM head.

    Usable with ``model.generate()`` as long as ``use_cache=False``: no
    key/value cache is implemented, so ``past_key_values`` is always returned
    as ``None`` and every step re-encodes the full sequence.
    """

    config_class = MiniMistralConfig
    base_model_prefix = "model"
    _no_split_modules = ["MiniMistralDecoderLayer"]

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [MiniMistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.is_decoder = True
        self.model_parallel = False

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights.

        Linear biases are zeroed and the padding row of an Embedding (if any)
        is zeroed.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()

    # The signature must stay compatible with model.generate().
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        labels=None,
    ):
        """Compute next-token logits (and optionally a loss).

        Args:
            input_ids: ``(batch, seq_len)`` token ids; ignored when
                ``inputs_embeds`` is given.
            attention_mask: Optional ``(batch, seq_len)`` padding mask where
                1 keeps a key position and 0 masks it out. Int, float or bool.
            position_ids: Optional RoPE positions; defaults to
                ``0 .. seq_len - 1`` shared across the batch.
            inputs_embeds: Optional ``(batch, seq_len, hidden_size)``
                embeddings used instead of ``input_ids``.
            past_key_values, use_cache, output_attentions, cache_position:
                Accepted for ``generate()`` compatibility; not used.
            output_hidden_states: If true, return a tuple with the input to
                every decoder layer followed by the final normalised states.
            return_dict: If falsy, return a tuple instead of a
                ``CausalLMOutputWithPast``.
            labels: Optional ``(batch, seq_len)`` target ids. Logits at
                position ``t`` are scored against ``labels[t + 1]``;
                ``-100`` entries are ignored.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` is truthy, else
            ``(logits,)`` or ``(loss, logits)`` when ``labels`` is given.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        elif input_ids is not None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            raise ValueError("You must specify either input_ids or inputs_embeds")

        # Causal mask (autoregressive): -inf above the diagonal.
        seq_len = hidden_states.shape[1]
        causal_mask = torch.full((seq_len, seq_len), -torch.inf, device=hidden_states.device)
        causal_mask = torch.triu(causal_mask, diagonal=1)
        final_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        if attention_mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq). Cast before subtracting so
            # bool masks work (``1.0 - bool_tensor`` is rejected by PyTorch).
            pad_mask = attention_mask[:, None, None, :].to(final_mask.dtype)
            final_mask = final_mask + (1.0 - pad_mask) * -10000.0

        if position_ids is None:
            position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).unsqueeze(0)

        all_hidden_states = () if output_hidden_states else None
        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=final_mask,
                position_ids=position_ids,
            )

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < t predict token t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)).float(),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,  # No KV cache implemented.
            hidden_states=all_hidden_states,
            attentions=None,  # Attention weights are not collected.
        )

# ====================================================================
# C. FONCTION PRINCIPALE D'INFÉRENCE
# ====================================================================

def run_inference():
    """Load Mini-Mistral from the Hugging Face Hub and sample a continuation.

    Prints progress messages and the generated text. If loading the tokenizer
    or the model fails, prints troubleshooting hints and returns without
    generating.
    """
    # Repository id to load (must match the published repository).
    MODEL_ID = "Clemylia/Mini-Mistral"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"🌍 Chargement du modèle {MODEL_ID} sur {device}...")

    try:
        # 1. Tokenizer. Fall back to EOS as the padding token.
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # 2. Model and weights (requires the class definitions in this file).
        model = MiniMistralModel.from_pretrained(MODEL_ID)
        model.to(device)
        model.eval()

        print("✅ Modèle Mini-Mistral chargé avec succès.")

    except Exception as e:
        # Top-level boundary: report the failure with hints instead of a traceback.
        print(f"❌ Erreur lors du chargement ou du lancement : {e}")
        print("\n**Vérifications :**")
        print("1. Avez-vous publié le modèle sur Hugging Face ?")
        print(f"2. L'ID du dépôt ('{MODEL_ID}') est-il correct ?")
        print("3. Êtes-vous connecté via 'huggingface-cli login' si le dépôt est privé ?")
        return

    # --- 3. Generation test ---
    prompt = "The quick brown fox jumps over the lazy dog, and then he decided to"

    print(f"\n🧠 Prompt : **{prompt}**")

    # Tokenise and keep the attention mask: pad_token may equal eos_token, in
    # which case generate() cannot infer the mask from input_ids alone.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask.to(device)

    # Generation parameters.
    generation_args = {
        "max_length": 60,  # Includes the prompt tokens.
        "temperature": 0.8,
        "do_sample": True,  # Stochastic sampling.
        "top_k": 50,
        "num_return_sequences": 1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "return_dict_in_generate": True,  # Result exposes ``.sequences``.
        "output_hidden_states": False,
        "output_attentions": False,  # The model does not collect attentions.
        "use_cache": False,  # The model has no KV cache.
    }

    with torch.no_grad():
        output_ids = model.generate(input_ids, attention_mask=attention_mask, **generation_args)

    # Decode the first (only) returned sequence.
    generated_text = tokenizer.decode(output_ids.sequences[0], skip_special_tokens=True)

    print(f"\n✨ Résultat généré : **{generated_text}**")
    print("-" * 50)
    print("-" * 50)


if __name__ == "__main__":
    # Script entry point. Requires: pip install torch transformers
    run_inference()
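# --------------------------------------------------------------------
# Hugging Face model-page metadata (copied from the Hub listing):
#   Downloads last month: 46
#   Safetensors — model size: 61.7M params, tensor type: F32
#   Inference Providers: this model isn't deployed by any Inference
#   Provider (provider support can be requested on the Hub).
# --------------------------------------------------------------------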