Inference:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
import math
import os
# ====================================================================
# A. MINI-MISTRAL CONFIGURATION
# (Needed so that Hugging Face knows how to rebuild the model)
# ====================================================================
class MiniMistralConfig(PretrainedConfig):
model_type = "mini_mistral"
def __init__(self,
vocab_size=32000,
hidden_size=512,
num_hidden_layers=8,
num_attention_heads=8,
num_key_value_heads=2,
intermediate_size=2048,
max_position_embeddings=4096,
rope_theta=10000.0,
sliding_window=4096,
rms_norm_eps=1e-5,
initializer_range=0.02,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.intermediate_size = intermediate_size
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.rms_norm_eps = rms_norm_eps
self.initializer_range = initializer_range
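# Worked example of what the attention layer derives from this config
# (numbers assume the default values above):
#   head_dim  = hidden_size // num_attention_heads          = 512 // 8 = 64
#   kv_groups = num_attention_heads // num_key_value_heads  = 8 // 2  = 4
# i.e. with the defaults, each key/value head is shared by 4 query heads (GQA).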
# ====================================================================
# B. MINI-MISTRAL ARCHITECTURE COMPONENTS
# ====================================================================
# RMS Norm
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, hidden_states):
        # Compute the norm in float32 for numerical stability, then cast back
        # to the input dtype (the original cast was a no-op).
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states.to(input_dtype)
# RoPE Helpers
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def get_rope_freqs(head_dim, max_seq_len, rope_theta=10000.0):
inv_freq = 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq).to(torch.float32)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos(), emb.sin()
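# Minimal shape sketch for the RoPE helpers above (illustrative only; it is
# not called by the inference path, and the toy sizes below are assumptions:
# head_dim=64, 8 heads, a 10-token sequence).
def _rope_shape_check():
    cos, sin = get_rope_freqs(head_dim=64, max_seq_len=4096)   # both (4096, 64)
    position_ids = torch.arange(10).unsqueeze(0)               # (1, 10)
    q = torch.randn(1, 8, 10, 64)                              # (batch, heads, seq, head_dim)
    k = torch.randn(1, 8, 10, 64)
    cos_emb = cos[position_ids].unsqueeze(1)                   # (1, 1, 10, 64)
    sin_emb = sin[position_ids].unsqueeze(1)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape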
# Multi-Head/Grouped-Query Attention (MHA/GQA)
class MiniMistralAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = self.hidden_size // self.num_heads
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._cos_sin = None
def _prepare_for_gqa(self, hidden_states, seq_len):
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Explicitly reshape with num_heads/num_kv_heads instead of -1
query_states = query_states.view(hidden_states.shape[0], seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(hidden_states.shape[0], seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(hidden_states.shape[0], seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
num_kv_groups = self.num_heads // self.num_kv_heads
if num_kv_groups > 1:
key_states = torch.repeat_interleave(key_states, num_kv_groups, dim=1)
value_states = torch.repeat_interleave(value_states, num_kv_groups, dim=1)
return query_states, key_states, value_states
def forward(self, hidden_states, attention_mask=None, position_ids=None):
B, seq_len, H = hidden_states.shape
query_states, key_states, value_states = self._prepare_for_gqa(hidden_states, seq_len)
if self._cos_sin is None:
cos, sin = get_rope_freqs(self.head_dim, self.config.max_position_embeddings)
self._cos_sin = (cos.to(hidden_states.device), sin.to(hidden_states.device))
cos, sin = self._cos_sin
# position_ids is expected to be (batch_size, seq_len) from model.generate()
# cos and sin are (max_position_embeddings, head_dim)
# Index cos/sin with position_ids to get (batch_size, seq_len, head_dim)
# Then unsqueeze at dim=1 to allow broadcasting with num_heads: (batch_size, 1, seq_len, head_dim)
cos_emb = cos[position_ids].unsqueeze(1)
sin_emb = sin[position_ids].unsqueeze(1)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos_emb, sin_emb)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous().view(B, seq_len, H)
attn_output = self.o_proj(attn_output)
return attn_output
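# Worked example of the K/V head sharing in _prepare_for_gqa (shapes assume the
# default config: 8 query heads, 2 key/value heads, head_dim 64):
#   q_proj : (B, S, 512) -> view/transpose -> (B, 8, S, 64)
#   k_proj : (B, S, 128) -> view/transpose -> (B, 2, S, 64)
#   repeat_interleave(..., 8 // 2, dim=1)  -> (B, 8, S, 64)
# Each of the 2 key/value heads is therefore reused by 4 query heads.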
# Gated MLP (SwiGLU)
class MiniMistralMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.act_fn = nn.SiLU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
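# The forward above is the SwiGLU pattern: down_proj(SiLU(gate_proj(x)) * up_proj(x)).
# With the default config, x goes (B, S, 512) -> two (B, S, 2048) branches gated
# element-wise -> back down to (B, S, 512).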
# Decoder Layer
class MiniMistralDecoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MiniMistralAttention(config=config)
self.mlp = MiniMistralMLP(config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, hidden_states, attention_mask=None, position_ids=None):
# Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
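# The block above is the standard pre-norm residual layout:
#   x = x + SelfAttention(RMSNorm(x))
#   x = x + MLP(RMSNorm(x))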
# Full model
class MiniMistralModel(PreTrainedModel, GenerationMixin):
config_class = MiniMistralConfig
base_model_prefix = "model"
_no_split_modules = ["MiniMistralDecoderLayer"]
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([MiniMistralDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.is_decoder = True
self.model_parallel = False
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
    # The forward pass must be compatible with model.generate()
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
inputs_embeds=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
cache_position=None,
):
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
        # Causal mask (for autoregressive decoding)
seq_len = hidden_states.shape[1]
causal_mask = torch.full((seq_len, seq_len), -torch.inf, device=hidden_states.device)
causal_mask = torch.triu(causal_mask, diagonal=1)
final_mask = causal_mask.unsqueeze(0).unsqueeze(0)
if attention_mask is not None:
final_mask = final_mask + (1.0 - attention_mask.unsqueeze(1).unsqueeze(2)) * -10000.0
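        # Illustration: for seq_len = 4, the causal part of final_mask is
        # (0 = may attend, -inf = blocked; row = query position, column = key position):
        #   [[0, -inf, -inf, -inf],
        #    [0,    0, -inf, -inf],
        #    [0,    0,    0, -inf],
        #    [0,    0,    0,    0]]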
if position_ids is None:
position_ids = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).unsqueeze(0)
for decoder_layer in self.layers:
hidden_states = decoder_layer(
hidden_states,
attention_mask=final_mask,
position_ids=position_ids,
)
hidden_states = self.norm(hidden_states)
logits = self.lm_head(hidden_states)
if not return_dict:
return (logits,)
return CausalLMOutputWithPast(
logits=logits,
past_key_values=None, # Implement actual past_key_values if using cache
hidden_states=hidden_states if output_hidden_states else None,
attentions=None # Implement actual attentions if output_attentions
)
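# Optional local smoke test (a sketch; nothing below calls it). It builds an
# untrained model from the default config and checks the forward-pass shapes.
# The ~63M parameter figure assumes the defaults above and untied
# embedding / lm_head weights.
def _smoke_test():
    config = MiniMistralConfig()
    model = MiniMistralModel(config)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Parameters: {n_params / 1e6:.1f}M")  # roughly 63M with the defaults
    dummy_ids = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(input_ids=dummy_ids, return_dict=True)
    assert out.logits.shape == (1, 16, config.vocab_size)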
# ====================================================================
# C. MAIN INFERENCE FUNCTION
# ====================================================================
def run_inference():
"""
    Loads the Mini-Mistral model from Hugging Face and runs text generation.
"""
    # Repository ID to load (MUST MATCH THE PUBLISHED REPOSITORY)
MODEL_ID = "Clemylia/Mini-Mistral"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🌍 Chargement du modèle {MODEL_ID} sur {device}...")
try:
        # 1. Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
        # 2. Load the model and its weights
        # The script needs to know the classes (hence the definitions above)
model = MiniMistralModel.from_pretrained(MODEL_ID)
model.to(device)
model.eval()
print("✅ Modèle Mini-Mistral chargé avec succès.")
except Exception as e:
print(f"❌ Erreur lors du chargement ou du lancement : {e}")
print("\n**Vérifications :**")
print("1. Avez-vous publié le modèle sur Hugging Face ?")
print(f"2. L'ID du dépôt ('{MODEL_ID}') est-il correct ?")
print("3. Êtes-vous connecté via 'huggingface-cli login' si le dépôt est privé ?")
return
    # --- 3. Generation test ---
prompt = "The quick brown fox jumps over the lazy dog, and then he decided to"
print(f"\n🧠 Prompt : **{prompt}**")
# Tokenisation
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Generation parameters
generation_args = {
"max_length": 60,
"temperature": 0.8,
"do_sample": True, # Active l'échantillonnage stochastique
"top_k": 50,
"num_return_sequences": 1,
"pad_token_id": tokenizer.pad_token_id,
"eos_token_id": tokenizer.eos_token_id,
"return_dict_in_generate": True, # Ensure generate returns a dictionary-like object
"output_hidden_states": False, # Not implemented yet
"output_attentions": False, # Not implemented yet
"use_cache": False # Temporarily disable cache as past_key_values is not implemented
}
    # Generate the text
with torch.no_grad():
output_ids = model.generate(input_ids, **generation_args)
    # Decode the result
generated_text = tokenizer.decode(output_ids.sequences[0], skip_special_tokens=True)
print(f"\n✨ Résultat généré : **{generated_text}**")
print("-" * 50)
if __name__ == "__main__":
    # Make sure the dependencies are installed: pip install torch transformers
run_inference()