Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- app.py +139 -0
- data_processing.py +102 -0
- generate_response.py +173 -0
- requirements.txt +7 -0
app.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Application Gradio pour l'analyse de sentiment d'avis Amazon
|
| 3 |
+
avec génération automatique de réponses pour le service client
|
| 4 |
+
|
| 5 |
+
VERSION avec CroissantLLMChat - Modèle français bilingue 1.3B
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from data_processing import clean_text, label_to_sentiment
|
| 10 |
+
from generate_response import generer_reponse, load_model
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
# Précharger le modèle CroissantLLMChat au démarrage
|
| 14 |
+
print("🥐 Préchargement de CroissantLLMChat (modèle français 1.3B)...")
|
| 15 |
+
load_model()
|
| 16 |
+
print("✅ Application prête !")
|
| 17 |
+
|
| 18 |
+
def analyze_review(review_text: str, sentiment_choice: str = "auto") -> tuple:
|
| 19 |
+
"""
|
| 20 |
+
Analyse un avis client et génère une réponse si négatif
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
review_text (str): Texte de l'avis client
|
| 24 |
+
sentiment_choice (str): "auto" pour détection auto, ou "positif"/"negatif"
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
tuple: (texte_nettoye, sentiment_affichage, reponse_affichage, temps)
|
| 28 |
+
"""
|
| 29 |
+
start_time = time.time()
|
| 30 |
+
|
| 31 |
+
# 1. Nettoyage du texte
|
| 32 |
+
cleaned = clean_text(review_text)
|
| 33 |
+
|
| 34 |
+
# 2. Détection du sentiment
|
| 35 |
+
if sentiment_choice == "auto":
|
| 36 |
+
# Détection automatique basique
|
| 37 |
+
mots_negatifs = ["mauvais", "nul", "déçu", "cassé", "retard", "problème",
|
| 38 |
+
"défectueux", "horrible", "arnaque", "pas", "ne", "aucun",
|
| 39 |
+
"inacceptable", "mécontentent", "insatisfait"]
|
| 40 |
+
|
| 41 |
+
mots_avis = cleaned.lower().split()
|
| 42 |
+
count_negatif = sum(1 for mot in mots_avis if any(neg in mot for neg in mots_negatifs))
|
| 43 |
+
|
| 44 |
+
sentiment = "negatif" if count_negatif >= 1 else "positif"
|
| 45 |
+
else:
|
| 46 |
+
sentiment = sentiment_choice.lower()
|
| 47 |
+
|
| 48 |
+
# Affichage du sentiment
|
| 49 |
+
if sentiment == "negatif":
|
| 50 |
+
sentiment_display = "🔴 **NÉGATIF**"
|
| 51 |
+
else:
|
| 52 |
+
sentiment_display = "🟢 **POSITIF**"
|
| 53 |
+
|
| 54 |
+
# 3. Génération de réponse (uniquement si négatif)
|
| 55 |
+
if sentiment == "negatif":
|
| 56 |
+
try:
|
| 57 |
+
response = generer_reponse(cleaned, max_tokens=120, temperature=0.7)
|
| 58 |
+
response_display = f"📧 **Réponse générée (CroissantLLMChat) :**\n\n{response}"
|
| 59 |
+
except Exception as e:
|
| 60 |
+
response = f"[Erreur : {e}]"
|
| 61 |
+
response_display = f"❌ Erreur lors de la génération : {e}"
|
| 62 |
+
else:
|
| 63 |
+
response = ""
|
| 64 |
+
response_display = "✅ Avis positif - Aucune réponse nécessaire"
|
| 65 |
+
|
| 66 |
+
# Temps d'exécution
|
| 67 |
+
elapsed_time = time.time() - start_time
|
| 68 |
+
|
| 69 |
+
return (
|
| 70 |
+
f"**Texte nettoyé :** {cleaned}",
|
| 71 |
+
f"**Sentiment détecté :** {sentiment_display}",
|
| 72 |
+
response_display,
|
| 73 |
+
f"⏱️ **Analyse terminée en {elapsed_time:.2f}s**"
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Interface Gradio
|
| 77 |
+
with gr.Blocks(title="Analyse de Sentiment Amazon + Réponses Auto") as demo:
|
| 78 |
+
|
| 79 |
+
gr.Markdown("""
|
| 80 |
+
# 🛍️ Analyse de Sentiment d'Avis Amazon
|
| 81 |
+
## Pipeline IA complet : Nettoyage + Sentiment + Génération de réponses
|
| 82 |
+
|
| 83 |
+
**Projet Master IA** - Coralie | **Modèle** : CroissantLLMChat (1.3B, bilingue FR/EN)
|
| 84 |
+
|
| 85 |
+
🥐 **Version avec CroissantLLM** - Modèle français développé par CentraleSupélec.
|
| 86 |
+
Ce projet utilise un **pipeline CI/CD automatique** via Hugging Face Spaces.
|
| 87 |
+
""")
|
| 88 |
+
|
| 89 |
+
with gr.Row():
|
| 90 |
+
with gr.Column():
|
| 91 |
+
gr.Markdown("### 📝 Avis client Amazon")
|
| 92 |
+
|
| 93 |
+
review_input = gr.Textbox(
|
| 94 |
+
label="Avis client Amazon",
|
| 95 |
+
placeholder="Le produit est arrivé cassé et le service client ne répond pas. Très déçu !",
|
| 96 |
+
lines=5
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
sentiment_radio = gr.Radio(
|
| 100 |
+
choices=["auto", "positif", "negatif"],
|
| 101 |
+
value="auto",
|
| 102 |
+
label="🎯 Sentiment (optionnel - sinon détection auto)",
|
| 103 |
+
info="Laissez sur 'auto' pour détection automatique"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
analyze_btn = gr.Button("🔍 Analyser l'avis", variant="primary")
|
| 107 |
+
|
| 108 |
+
with gr.Column():
|
| 109 |
+
gr.Markdown("### 📊 Résultats de l'analyse")
|
| 110 |
+
|
| 111 |
+
timing_output = gr.Markdown(label="Temps d'exécution")
|
| 112 |
+
cleaned_output = gr.Markdown(label="Texte nettoyé")
|
| 113 |
+
sentiment_output = gr.Markdown(label="Sentiment")
|
| 114 |
+
response_output = gr.Markdown(label="Réponse générée")
|
| 115 |
+
|
| 116 |
+
# Bouton d'analyse
|
| 117 |
+
analyze_btn.click(
|
| 118 |
+
fn=analyze_review,
|
| 119 |
+
inputs=[review_input, sentiment_radio],
|
| 120 |
+
outputs=[cleaned_output, sentiment_output, response_output, timing_output]
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
# Section d'exemples
|
| 124 |
+
gr.Markdown("""
|
| 125 |
+
---
|
| 126 |
+
### 💡 Exemples d'avis à tester :
|
| 127 |
+
|
| 128 |
+
**Avis négatif :** "Le produit est arrivé cassé et le service client ne répond pas. Très déçu !"
|
| 129 |
+
|
| 130 |
+
**Avis positif :** "Excellent produit, livraison rapide. Je recommande !"
|
| 131 |
+
""")
|
| 132 |
+
|
| 133 |
+
# Debug toggle
|
| 134 |
+
with gr.Accordion("🔧 Texte nettoyé (debug)", open=False):
|
| 135 |
+
gr.Markdown("Affiche le texte après nettoyage")
|
| 136 |
+
|
| 137 |
+
# Lancer l'application
|
| 138 |
+
if __name__ == "__main__":
|
| 139 |
+
demo.launch(server_name="0.0.0.0", server_port=7860)
|
data_processing.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module de traitement des données pour l'analyse de sentiment
|
| 3 |
+
Nettoyage des textes et labellisation
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import re
|
| 7 |
+
import string
|
| 8 |
+
|
| 9 |
+
# Liste des stopwords français
|
| 10 |
+
FRENCH_STOPWORDS = {
|
| 11 |
+
"a", "à", "ai", "aie", "aient", "aies", "ait", "alors", "as", "au", "aucun", "aura",
|
| 12 |
+
"aurai", "auraient", "aurais", "aurait", "auve", "avec", "avez", "aviez", "avions",
|
| 13 |
+
"avoir", "avons", "bon", "car", "ce", "cela", "ces", "cet", "cette", "ceux", "chaque",
|
| 14 |
+
"comme", "d", "dans", "de", "des", "du", "elle", "en", "encore", "est", "et", "eu",
|
| 15 |
+
"fait", "faites", "fois", "ici", "il", "ils", "je", "la", "le", "les", "leur", "lui",
|
| 16 |
+
"mais", "me", "mes", "moi", "mon", "ne", "nos", "notre", "nous", "on", "ou", "par",
|
| 17 |
+
"pas", "pour", "plus", "qu", "que", "qui", "sa", "se", "ses", "son", "sur",
|
| 18 |
+
"ta", "te", "tes", "toi", "ton", "toujours", "tout", "tous", "très", "tu",
|
| 19 |
+
"un", "une", "vos", "votre", "vous", "y"
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
# Table de traduction pour remplacer la ponctuation par des espaces
|
| 23 |
+
PUNCT_TABLE = str.maketrans({c: " " for c in string.punctuation})
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def clean_text(text: str) -> str:
|
| 27 |
+
"""
|
| 28 |
+
Nettoie un texte d'avis client :
|
| 29 |
+
- Conversion en minuscules
|
| 30 |
+
- Suppression de la ponctuation
|
| 31 |
+
- Suppression des chiffres
|
| 32 |
+
- Suppression des stopwords français
|
| 33 |
+
- Normalisation des espaces
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
text (str): Texte brut à nettoyer
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
str: Texte nettoyé
|
| 40 |
+
"""
|
| 41 |
+
if not isinstance(text, str):
|
| 42 |
+
return ""
|
| 43 |
+
|
| 44 |
+
# 1. Minuscules
|
| 45 |
+
text = text.lower()
|
| 46 |
+
|
| 47 |
+
# 2. Suppression de la ponctuation
|
| 48 |
+
text = text.translate(PUNCT_TABLE)
|
| 49 |
+
|
| 50 |
+
# 3. Suppression des chiffres
|
| 51 |
+
text = re.sub(r"\d+", " ", text)
|
| 52 |
+
|
| 53 |
+
# 4. Normalisation des espaces
|
| 54 |
+
text = re.sub(r"\s+", " ", text).strip()
|
| 55 |
+
|
| 56 |
+
# 5. Suppression des stopwords
|
| 57 |
+
tokens = [tok for tok in text.split() if tok not in FRENCH_STOPWORDS]
|
| 58 |
+
|
| 59 |
+
return " ".join(tokens)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def label_to_sentiment(label_value: int) -> str:
|
| 63 |
+
"""
|
| 64 |
+
Convertit un label numérique (1-5 étoiles) en sentiment positif/négatif
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
label_value (int): Note de 1 à 5 étoiles
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
str: "positif" si >= 3 étoiles, "negatif" sinon
|
| 71 |
+
"""
|
| 72 |
+
try:
|
| 73 |
+
v = int(label_value)
|
| 74 |
+
except Exception:
|
| 75 |
+
v = 0
|
| 76 |
+
|
| 77 |
+
return "positif" if v >= 3 else "negatif"
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def make_fake_email(index: int) -> str:
|
| 81 |
+
"""
|
| 82 |
+
Génère un email factice pour un client
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
index (int): Numéro du client
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
str: Email au format [email protected]
|
| 89 |
+
"""
|
| 90 |
+
return f"client{index:05d}@example.com"
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
# Tests
|
| 95 |
+
test_text = "Je suis TRÈS déçu de ce produit ! Il est arrivé cassé et le service client ne répond pas..."
|
| 96 |
+
print(f"Original : {test_text}")
|
| 97 |
+
print(f"Nettoyé : {clean_text(test_text)}")
|
| 98 |
+
|
| 99 |
+
print(f"\nLabel 1 → {label_to_sentiment(1)}")
|
| 100 |
+
print(f"Label 5 → {label_to_sentiment(5)}")
|
| 101 |
+
|
| 102 |
+
print(f"\nEmail test : {make_fake_email(42)}")
|
generate_response.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Module de génération de réponses pour le service client Amazon
|
| 3 |
+
Utilise CroissantLLMChat - Modèle bilingue français-anglais optimisé
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 8 |
+
|
| 9 |
+
# Modèle bilingue français-anglais spécialement conçu pour le français
|
| 10 |
+
MODEL_NAME = "croissantllm/CroissantLLMChat-v0.1"
|
| 11 |
+
|
| 12 |
+
# Variables globales pour le modèle
|
| 13 |
+
model = None
|
| 14 |
+
tokenizer = None
|
| 15 |
+
|
| 16 |
+
def load_model():
|
| 17 |
+
"""
|
| 18 |
+
Charge le modèle CroissantLLMChat et son tokenizer
|
| 19 |
+
CroissantLLM est un modèle 1.3B VRAIMENT bilingue (50% FR / 50% EN)
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
tuple: (model, tokenizer) chargés
|
| 23 |
+
"""
|
| 24 |
+
global model, tokenizer
|
| 25 |
+
|
| 26 |
+
print(f"🔄 Chargement du modèle {MODEL_NAME}...")
|
| 27 |
+
print("⏳ CroissantLLM est un modèle français de 1.3B paramètres (~2-3 GB)")
|
| 28 |
+
|
| 29 |
+
# Charger le tokenizer
|
| 30 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
| 31 |
+
|
| 32 |
+
# Charger le modèle en float32 pour CPU
|
| 33 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 34 |
+
MODEL_NAME,
|
| 35 |
+
torch_dtype=torch.float32,
|
| 36 |
+
device_map="cpu",
|
| 37 |
+
low_cpu_mem_usage=True
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
print("✅ Modèle CroissantLLMChat chargé avec succès !")
|
| 41 |
+
print("🥐 Modèle français bilingue prêt !")
|
| 42 |
+
|
| 43 |
+
return model, tokenizer
|
| 44 |
+
|
| 45 |
+
def build_chat_messages(review_text: str) -> list:
|
| 46 |
+
"""
|
| 47 |
+
Construit les messages pour CroissantLLMChat
|
| 48 |
+
Format officiel avec apply_chat_template
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
review_text (str): Texte de l'avis client négatif
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
list: Messages formatés pour CroissantLLMChat
|
| 55 |
+
"""
|
| 56 |
+
# CroissantLLMChat utilise un format chat officiel
|
| 57 |
+
# Avec un message utilisateur clair
|
| 58 |
+
chat_messages = [
|
| 59 |
+
{
|
| 60 |
+
"role": "user",
|
| 61 |
+
"content": f"""Tu es un agent du service client Amazon. Réponds en français à cet avis négatif avec empathie et professionnalisme :
|
| 62 |
+
|
| 63 |
+
"{review_text}"
|
| 64 |
+
|
| 65 |
+
Réponds en présentant des excuses, en reconnaissant le problème, et en proposant une solution concrète (remboursement ou échange)."""
|
| 66 |
+
}
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
return chat_messages
|
| 70 |
+
|
| 71 |
+
def generer_reponse(review_text: str, max_tokens: int = 120, temperature: float = 0.7) -> str:
|
| 72 |
+
"""
|
| 73 |
+
Génère une réponse au service client pour un avis négatif
|
| 74 |
+
Utilise CroissantLLMChat avec apply_chat_template (méthode officielle)
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
review_text (str): Texte de l'avis client négatif
|
| 78 |
+
max_tokens (int): Nombre maximum de tokens à générer
|
| 79 |
+
temperature (float): Température de génération (0.7 = équilibré)
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
str: Réponse générée par le modèle EN FRANÇAIS
|
| 83 |
+
"""
|
| 84 |
+
global model, tokenizer
|
| 85 |
+
|
| 86 |
+
# Charger le modèle si pas encore fait
|
| 87 |
+
if model is None or tokenizer is None:
|
| 88 |
+
load_model()
|
| 89 |
+
|
| 90 |
+
# Construire les messages au format chat
|
| 91 |
+
chat_messages = build_chat_messages(review_text)
|
| 92 |
+
|
| 93 |
+
# Appliquer le template officiel de CroissantLLMChat
|
| 94 |
+
# C'est la méthode recommandée dans la documentation
|
| 95 |
+
chat_input = tokenizer.apply_chat_template(
|
| 96 |
+
chat_messages,
|
| 97 |
+
tokenize=False,
|
| 98 |
+
add_generation_prompt=True
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Tokeniser le chat formaté
|
| 102 |
+
inputs = tokenizer(
|
| 103 |
+
chat_input,
|
| 104 |
+
return_tensors="pt",
|
| 105 |
+
max_length=512,
|
| 106 |
+
truncation=True
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# Générer avec CroissantLLMChat
|
| 110 |
+
# Température 0.7 recommandée (doc dit 0.3+ minimum)
|
| 111 |
+
with torch.no_grad():
|
| 112 |
+
outputs = model.generate(
|
| 113 |
+
inputs.input_ids,
|
| 114 |
+
attention_mask=inputs.attention_mask,
|
| 115 |
+
max_new_tokens=max_tokens,
|
| 116 |
+
temperature=temperature,
|
| 117 |
+
do_sample=True,
|
| 118 |
+
top_p=0.9,
|
| 119 |
+
top_k=50,
|
| 120 |
+
repetition_penalty=1.2,
|
| 121 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 122 |
+
eos_token_id=tokenizer.eos_token_id
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# MÉTHODE AMÉLIORÉE : Décoder UNIQUEMENT les nouveaux tokens
|
| 126 |
+
# On ne décode PAS le prompt d'entrée
|
| 127 |
+
input_length = inputs.input_ids.shape[1]
|
| 128 |
+
generated_tokens = outputs[0][input_length:] # Prendre uniquement les tokens générés
|
| 129 |
+
|
| 130 |
+
# Décoder uniquement la réponse générée
|
| 131 |
+
answer = tokenizer.decode(generated_tokens, skip_special_tokens=True)
|
| 132 |
+
|
| 133 |
+
# Nettoyer les tokens spéciaux qui pourraient rester
|
| 134 |
+
special_tokens = ["<|im_start|>", "<|im_end|>", "assistant", "user", "system"]
|
| 135 |
+
for token in special_tokens:
|
| 136 |
+
answer = answer.replace(token, "")
|
| 137 |
+
|
| 138 |
+
# Nettoyer espaces multiples
|
| 139 |
+
answer = ' '.join(answer.split())
|
| 140 |
+
answer = answer.strip()
|
| 141 |
+
|
| 142 |
+
# Limiter à 3-4 phrases maximum
|
| 143 |
+
sentences = answer.split('.')
|
| 144 |
+
clean_sentences = [s.strip() for s in sentences if s.strip()]
|
| 145 |
+
if len(clean_sentences) > 4:
|
| 146 |
+
answer = '. '.join(clean_sentences[:4]) + '.'
|
| 147 |
+
else:
|
| 148 |
+
answer = '. '.join(clean_sentences)
|
| 149 |
+
if not answer.endswith('.'):
|
| 150 |
+
answer += '.'
|
| 151 |
+
|
| 152 |
+
return answer
|
| 153 |
+
|
| 154 |
+
# Test du module
|
| 155 |
+
if __name__ == "__main__":
|
| 156 |
+
print("🧪 Test du module de génération avec CroissantLLMChat\n")
|
| 157 |
+
|
| 158 |
+
# Charger le modèle
|
| 159 |
+
load_model()
|
| 160 |
+
|
| 161 |
+
# Test 1
|
| 162 |
+
avis_test_1 = "Le produit est arrivé cassé et le service client ne répond pas. Très déçu !"
|
| 163 |
+
print(f"📝 Avis test 1: {avis_test_1}")
|
| 164 |
+
reponse_1 = generer_reponse(avis_test_1)
|
| 165 |
+
print(f"💬 Réponse: {reponse_1}\n")
|
| 166 |
+
|
| 167 |
+
# Test 2
|
| 168 |
+
avis_test_2 = "Livraison en retard de 2 semaines, produit endommagé."
|
| 169 |
+
print(f"📝 Avis test 2: {avis_test_2}")
|
| 170 |
+
reponse_2 = generer_reponse(avis_test_2)
|
| 171 |
+
print(f"💬 Réponse: {reponse_2}\n")
|
| 172 |
+
|
| 173 |
+
print("✅ Tests terminés !")
|
requirements.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.44.0
|
| 2 |
+
transformers>=4.35.0
|
| 3 |
+
torch>=2.0.0
|
| 4 |
+
accelerate>=0.24.0
|
| 5 |
+
datasets>=2.14.0
|
| 6 |
+
pandas>=2.0.0
|
| 7 |
+
sentencepiece>=0.1.99
|