spamBERT/test_classify.py
Carsten Abele 38efd20b4d Initial commit: SpamLLM - DistilBERT spam classifier for RSpamd
Multilingual spam classifier (DE/EN) with language detection.
Non-DE/EN mails receive an additional spam score bonus.

- train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data
- server.py: FastAPI service with langdetect integration
- rspamd/: Lua plugin and config for RSpamd integration
- export_rspamd_data.py: Export Maildir folders to CSV training data
- test_classify.py: Local model validation with DE/EN/foreign test cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 22:27:05 +01:00

93 lines
3.5 KiB
Python

"""
Teste den Classifier lokal ohne Server.
Nützlich um das Modell nach dem Training schnell zu validieren.
"""
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
MODEL_PATH = Path("./model/final")
# Testmails - DE, EN und Fremdsprachen
TEST_MESSAGES = [
# Deutscher Spam
{
"label": "spam",
"subject": "Herzlichen Glückwunsch! Sie haben gewonnen!",
"body": "Sie wurden als Gewinner ausgewählt! Klicken Sie hier um Ihren Preis von 50.000 Euro abzuholen. Senden Sie uns Ihre Bankdaten zur Überweisung.",
},
{
"label": "spam",
"subject": "Dringend: Ihr Konto wird gesperrt",
"body": "Sehr geehrter Kunde, Ihr Konto wird in 24 Stunden gesperrt. Bestätigen Sie Ihre Identität sofort unter folgendem Link um den Zugang zu behalten.",
},
# Deutscher Ham
{
"label": "ham",
"subject": "Re: Besprechung morgen um 15 Uhr",
"body": "Hallo zusammen, kurze Erinnerung an unser wöchentliches Meeting morgen um 15 Uhr. Bitte bereitet eure Status-Updates vor. Danke!",
},
{
"label": "ham",
"subject": "Rechnung Nr. 4521",
"body": "Vielen Dank für die Zusendung der Rechnung. Ich habe sie an die Buchhaltung weitergeleitet. Die Zahlung erfolgt innerhalb von 30 Tagen.",
},
# Englischer Spam
{
"label": "spam",
"subject": "You won $1,000,000!!!",
"body": "Congratulations! You have been selected as our winner. Click here to claim your prize now! Send us your bank details immediately.",
},
# Englischer Ham
{
"label": "ham",
"subject": "Re: Project timeline update",
"body": "Hi team, just a quick update on the project timeline. We're on track for the Q2 release. Let me know if you have any blockers.",
},
# Fremdsprache (Russisch) - sollte als verdächtig gelten
{
"label": "spam",
"subject": "Специальное предложение для вас",
"body": "Поздравляем! Вы выиграли специальный приз. Нажмите здесь чтобы получить вашу награду прямо сейчас.",
},
# Fremdsprache (Chinesisch)
{
"label": "spam",
"subject": "恭喜您中奖了",
"body": "尊敬的用户,恭喜您被选为幸运用户,请点击链接领取您的奖品。",
},
]
def main():
print(f"Loading model from {MODEL_PATH}...")
tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))
print(f"\n{'Label':<6} {'Prediction':<12} {'Confidence':>10} Subject")
print("-" * 70)
correct = 0
for msg in TEST_MESSAGES:
text = f"Subject: {msg['subject']}\n\n{msg['body']}"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
with torch.no_grad():
outputs = model(**inputs)
probs = torch.softmax(outputs.logits, dim=-1)
spam_prob = probs[0][1].item()
prediction = "spam" if spam_prob > 0.5 else "ham"
is_correct = prediction == msg["label"]
correct += int(is_correct)
marker = "+" if is_correct else "X"
print(f"{msg['label']:<6} {prediction:<12} {spam_prob:>9.1%} [{marker}] {msg['subject']}")
print(f"\nAccuracy: {correct}/{len(TEST_MESSAGES)} ({correct / len(TEST_MESSAGES):.0%})")
if __name__ == "__main__":
main()