Initial commit: SpamLLM - DistilBERT spam classifier for RSpamd
Multilingual spam classifier (DE/EN) with language detection. Non-DE/EN mails receive an additional spam score bonus. - train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data - server.py: FastAPI service with langdetect integration - rspamd/: Lua plugin and config for RSpamd integration - export_rspamd_data.py: Export Maildir folders to CSV training data - test_classify.py: Local model validation with DE/EN/foreign test cases Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
commit
38efd20b4d
7 changed files with 671 additions and 0 deletions
136
export_rspamd_data.py
Normal file
136
export_rspamd_data.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
Exportiert Mails aus Maildir-Ordnern als Trainingsdaten für SpamLLM.
|
||||
|
||||
Erwartet eine typische Maildir-Struktur:
|
||||
- Spam-Ordner: z.B. ~/.spam/ oder /var/vmail/user/.Junk/
|
||||
- Ham-Ordner: z.B. ~/Maildir/ oder /var/vmail/user/.INBOX/
|
||||
|
||||
Erzeugt: data/train.csv mit Spalten: text, labels (0=ham, 1=spam)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import email
|
||||
import email.policy
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def extract_text_from_email(filepath: str) -> dict | None:
|
||||
"""Extrahiert Subject und Body aus einer E-Mail-Datei."""
|
||||
try:
|
||||
with open(filepath, "rb") as f:
|
||||
msg = email.message_from_binary_file(f, policy=email.policy.default)
|
||||
|
||||
subject = msg.get("Subject", "")
|
||||
from_addr = msg.get("From", "")
|
||||
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
content = part.get_content()
|
||||
if isinstance(content, str):
|
||||
body += content
|
||||
else:
|
||||
content = msg.get_content()
|
||||
if isinstance(content, str):
|
||||
body = content
|
||||
|
||||
# Auf sinnvolle Länge begrenzen
|
||||
body = body[:4096]
|
||||
|
||||
if not subject and not body:
|
||||
return None
|
||||
|
||||
text = f"From: {from_addr}\nSubject: {subject}\n\n{body}"
|
||||
return {"text": text}
|
||||
|
||||
except Exception as e:
|
||||
print(f" Skipping {filepath}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def collect_mails(directory: str, label: int, max_count: int = 0) -> list[dict]:
    """Gather mails from a Maildir directory and tag each with `label`.

    label: 1 = spam, 0 = ham. max_count of 0 means "take everything";
    otherwise a random sample of max_count files is used.
    """
    results: list[dict] = []
    mail_dir = Path(directory)

    if not mail_dir.exists():
        print(f"WARNING: {directory} existiert nicht!")
        return results

    # A Maildir typically has cur/, new/, tmp/ subfolders; tmp/ is skipped
    # on purpose since it may contain partially delivered mail.
    search_dirs = [mail_dir]
    search_dirs += [mail_dir / name for name in ("cur", "new") if (mail_dir / name).exists()]

    files = [
        entry
        for folder in search_dirs
        for entry in folder.iterdir()
        if entry.is_file() and not entry.name.startswith(".")
    ]

    # Random subsample when the caller asked for a cap
    if 0 < max_count < len(files):
        random.shuffle(files)
        files = files[:max_count]

    label_name = "spam" if label == 1 else "ham"
    print(f"Processing {len(files)} {label_name} mails from {directory}...")

    for path in files:
        item = extract_text_from_email(str(path))
        if item:
            item["labels"] = label
            results.append(item)

    return results
|
||||
|
||||
|
||||
def _write_csv(path, rows):
    """Write dict rows with columns text/labels to `path` as UTF-8 CSV."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["text", "labels"])
        writer.writeheader()
        writer.writerows(rows)


def main():
    """Export spam/ham Maildirs into shuffled train/test CSV files."""
    parser = argparse.ArgumentParser(description="Export Maildir to CSV training data")
    parser.add_argument("--spam-dir", required=True, help="Path to spam Maildir")
    parser.add_argument("--ham-dir", required=True, help="Path to ham Maildir")
    parser.add_argument("--output", default="data/train.csv", help="Output CSV path")
    parser.add_argument("--max-per-class", type=int, default=0, help="Max mails per class (0=all)")
    parser.add_argument("--test-split", type=float, default=0.2, help="Test set ratio")
    args = parser.parse_args()

    spam_mails = collect_mails(args.spam_dir, label=1, max_count=args.max_per_class)
    ham_mails = collect_mails(args.ham_dir, label=0, max_count=args.max_per_class)

    all_mails = spam_mails + ham_mails
    random.shuffle(all_mails)

    print(f"\nTotal: {len(all_mails)} mails ({len(spam_mails)} spam, {len(ham_mails)} ham)")

    # Train/test split: tail of the shuffled list becomes the test set
    split_idx = int(len(all_mails) * (1 - args.test_split))
    train_data = all_mails[:split_idx]
    test_data = all_mails[split_idx:]

    # Make sure the output directory exists
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    _write_csv(output_path, train_data)
    print(f"Train set: {len(train_data)} mails -> {output_path}")

    test_path = output_path.parent / "test.csv"
    _write_csv(test_path, test_data)
    print(f"Test set: {len(test_data)} mails -> {test_path}")


if __name__ == "__main__":
    main()
|
||||
9
requirements.txt
Normal file
9
requirements.txt
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
torch>=2.0.0
|
||||
transformers>=4.36.0
|
||||
fastapi>=0.104.0
|
||||
uvicorn>=0.24.0
|
||||
pydantic>=2.0.0
|
||||
datasets>=2.16.0
|
||||
scikit-learn>=1.3.0
|
||||
accelerate>=0.25.0
|
||||
langdetect>=1.0.9
|
||||
26
rspamd/local.d/external_services.conf
Normal file
26
rspamd/local.d/external_services.conf
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
# RSpamd external service configuration for SpamLLM
# Copy this file to /etc/rspamd/local.d/external_services.conf

spamllm {
    # Type: HTTP-based external service
    type = "http";

    # URL of the SpamLLM FastAPI service
    url = "http://127.0.0.1:8000/classify";

    # Timeout in seconds
    timeout = 5.0;

    # Maximum message size sent to the service (in bytes)
    max_size = 50k;

    # Symbol inserted when spam is detected
    symbol = "SPAMLLM_SPAM";

    # Score assigned to the symbol (set dynamically by the service)
    weight = 5.0;

    # Only check mails in the grey zone (score between 3 and 12).
    # This saves resources: obvious spam/ham is never sent to the LLM.
    condition = "not rspamd_config.is_local(task:get_from_ip()) and task:get_metric_score('default') > 3 and task:get_metric_score('default') < 12";
}
|
||||
129
rspamd/lua/spamllm.lua
Normal file
129
rspamd/lua/spamllm.lua
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
-- RSpamd Lua Plugin für SpamLLM
|
||||
-- Kopiere nach /etc/rspamd/plugins.d/spamllm.lua
|
||||
--
|
||||
-- Dieser Plugin sendet Mail-Daten an den SpamLLM HTTP Service
|
||||
-- und setzt den Score basierend auf der Antwort.
|
||||
|
||||
local rspamd_http = require "rspamd_http"
|
||||
local rspamd_logger = require "rspamd_logger"
|
||||
local ucl = require "ucl"
|
||||
|
||||
-- Module name (used for logging/config lookup)
local N = "spamllm"

-- Plugin defaults; keep in sync with local.d/external_services.conf
local settings = {
    url = "http://127.0.0.1:8000/classify",  -- SpamLLM FastAPI endpoint
    timeout = 5.0,                           -- HTTP timeout in seconds
    symbol_spam = "SPAMLLM_SPAM",
    symbol_ham = "SPAMLLM_HAM",
    symbol_foreign = "SPAMLLM_FOREIGN_LANG",
    threshold = 0.5,        -- NOTE(review): not referenced in this file -- confirm intent
    max_body_length = 4096, -- cap on body bytes sent to the service
    enabled = true,         -- NOTE(review): not checked before registering -- confirm intent
}
|
||||
|
||||
-- Send the task's From/Subject/body to the SpamLLM HTTP service and map the
-- JSON reply onto the spam/ham/foreign-language symbols.
local function check_spamllm(task)
    -- Envelope (SMTP) sender address
    local from = task:get_from("smtp")
    local from_addr = ""
    if from and from[1] then
        from_addr = from[1].addr or ""
    end

    local subject = task:get_subject() or ""

    -- Concatenate text parts, capped at settings.max_body_length
    local text_parts = task:get_text_parts()
    local body = ""
    if text_parts then
        for _, part in ipairs(text_parts) do
            local content = part:get_content()
            if content then
                body = body .. tostring(content)
                if #body > settings.max_body_length then
                    body = body:sub(1, settings.max_body_length)
                    break
                end
            end
        end
    end

    -- Escape a value for embedding inside a JSON string literal.
    -- FIX: the previous code escaped only double quotes (and newlines only in
    -- the body), so a backslash or control character in any field produced
    -- invalid JSON that the service could not parse.
    local function json_escape(s)
        s = s:gsub('\\', '\\\\')
        s = s:gsub('"', '\\"')
        s = s:gsub('\n', '\\n')
        s = s:gsub('\r', '\\r')
        s = s:gsub('\t', '\\t')
        return s
    end

    -- JSON request body
    local request_body = string.format(
        '{"from_addr":"%s","subject":"%s","body":"%s"}',
        json_escape(from_addr),
        json_escape(subject),
        json_escape(body)
    )

    -- Handle the classifier's HTTP reply
    local function callback(err, code, response_body)
        if err then
            rspamd_logger.errx(task, "SpamLLM request failed: %s", err)
            return
        end

        if code ~= 200 then
            rspamd_logger.errx(task, "SpamLLM returned HTTP %s", code)
            return
        end

        local parser = ucl.parser()
        local ok, parse_err = parser:parse_string(response_body)
        if not ok then
            rspamd_logger.errx(task, "SpamLLM JSON parse error: %s", parse_err)
            return
        end

        local result = parser:get_object()

        if result.is_spam then
            task:insert_result(settings.symbol_spam, result.confidence, "SpamLLM")
            rspamd_logger.infox(task, "SpamLLM: SPAM (confidence=%.2f, score=%.2f, lang=%s)",
                result.confidence, result.score, result.language or "?")
        else
            -- Ham: negative weight via the virtual ham symbol
            task:insert_result(settings.symbol_ham, -result.confidence, "SpamLLM")
        end

        -- Foreign-language bonus as a separate symbol
        if result.foreign_lang_bonus and result.foreign_lang_bonus > 0 then
            task:insert_result(settings.symbol_foreign, result.foreign_lang_bonus / 4.0,
                string.format("lang=%s", result.language or "unknown"))
            rspamd_logger.infox(task, "SpamLLM: Foreign language detected: %s (bonus=%.1f)",
                result.language, result.foreign_lang_bonus)
        end
    end

    rspamd_http.request({
        task = task,
        url = settings.url,
        body = request_body,
        callback = callback,
        headers = {
            ["Content-Type"] = "application/json",
        },
        timeout = settings.timeout,
    })
end
|
||||
|
||||
-- Register the main (callback) symbol plus two virtual child symbols that
-- the callback inserts.
rspamd_config:register_symbol({
    name = settings.symbol_spam,
    weight = 5.0,
    callback = check_spamllm,
    type = "normal",
    description = "SpamLLM DistilBERT spam classifier",
})

-- Ham verdict (inserted with negative confidence by the callback)
rspamd_config:register_symbol({
    name = settings.symbol_ham,
    weight = -2.0,
    type = "virtual",
    parent = rspamd_config:get_symbol_id(settings.symbol_spam),
    description = "SpamLLM DistilBERT ham classification",
})

-- Foreign-language verdict (mail not in DE/EN)
rspamd_config:register_symbol({
    name = settings.symbol_foreign,
    weight = 4.0,
    type = "virtual",
    parent = rspamd_config:get_symbol_id(settings.symbol_spam),
    description = "Mail in unerwarteter Sprache (nicht DE/EN)",
})
|
||||
120
server.py
Normal file
120
server.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
FastAPI Service für Spam-Klassifikation mit Spracherkennung.
|
||||
|
||||
Stellt einen HTTP-Endpunkt bereit, den RSpamd als external_service aufrufen kann.
|
||||
Mails in nicht-erwarteten Sprachen (nicht DE/EN) bekommen einen Spam-Bonus.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from fastapi import FastAPI
|
||||
from langdetect import DetectorFactory, detect_langs
|
||||
from pydantic import BaseModel
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
# Deterministic language detection (langdetect is otherwise randomized)
DetectorFactory.seed = 0

# Expected languages - anything else gets a spam-score bonus
EXPECTED_LANGUAGES = {"de", "en"}
# Score bonus for unexpected languages (0-5 extra points)
FOREIGN_LANG_BONUS = 4.0

logger = logging.getLogger("spamllm")
logging.basicConfig(level=logging.INFO)

# Directory produced by train.py (trainer.save_model to <output-dir>/final)
MODEL_PATH = Path("./model/final")

# Global model state, populated once by `lifespan` at startup
model = None
tokenizer = None
device = None
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: load tokenizer and model once, then serve until shutdown."""
    global model, tokenizer, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Loading model from {MODEL_PATH} on {device}")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
    loaded = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))
    loaded.to(device)
    loaded.eval()
    model = loaded
    logger.info("Model loaded successfully")
    yield
|
||||
|
||||
|
||||
# Application instance; `lifespan` loads the model before requests are served
app = FastAPI(title="SpamLLM Classifier", lifespan=lifespan)
|
||||
|
||||
|
||||
class ClassifyRequest(BaseModel):
    """Request payload sent by the RSpamd Lua plugin for one mail."""

    subject: str = ""
    body: str = ""
    from_addr: str = ""
|
||||
|
||||
|
||||
class ClassifyResponse(BaseModel):
    """Classification verdict returned to RSpamd."""

    is_spam: bool
    confidence: float  # P(spam) from the model, even for ham verdicts
    score: float  # RSpamd-compatible score (0-15)
    language: str  # Detected language
    foreign_lang_bonus: float  # Extra score for a foreign language
|
||||
|
||||
|
||||
def detect_language(text: str) -> tuple[str, bool]:
    """Identify the dominant language of `text`.

    Returns (language_code, is_foreign). Texts that are empty, shorter than
    20 characters after stripping, or undetectable yield ("unknown", False)
    so they never trigger the foreign-language penalty.
    """
    candidate = text.strip() if text else ""
    if len(candidate) < 20:
        return "unknown", False

    try:
        best_guess = detect_langs(text)[0]
        code = best_guess.lang
        return code, code not in EXPECTED_LANGUAGES
    except Exception:
        return "unknown", False
|
||||
|
||||
|
||||
@app.post("/classify", response_model=ClassifyResponse)
async def classify(request: ClassifyRequest):
    """Classify one mail and return an RSpamd-compatible verdict.

    Runs the DistilBERT classifier over From/Subject/body, detects the
    language, and adds FOREIGN_LANG_BONUS to the score for non-DE/EN mails.
    """
    # Combine the mail fields into a single classifier input
    text = f"From: {request.from_addr}\nSubject: {request.subject}\n\n{request.body}"

    # Language detection on the body (the subject is often too short)
    lang_text = request.body if request.body else request.subject
    language, is_foreign = detect_language(lang_text)

    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        spam_prob = probs[0][1].item()  # P(label==1) == P(spam)

    # Convert probability to an RSpamd score (0-15 scale)
    rspamd_score = spam_prob * 15.0

    # Foreign-language bonus: non-DE/EN gets extra points
    lang_bonus = FOREIGN_LANG_BONUS if is_foreign else 0.0
    rspamd_score = min(rspamd_score + lang_bonus, 15.0)

    # Re-evaluate the spam threshold after the bonus: foreign-language mails
    # are flagged at a lower probability (0.3 instead of 0.5)
    effective_spam = spam_prob > 0.5 or (is_foreign and spam_prob > 0.3)

    # NOTE(review): `confidence` is always P(spam), even for ham verdicts;
    # the Lua plugin negates it when inserting the ham symbol.
    return ClassifyResponse(
        is_spam=effective_spam,
        confidence=spam_prob,
        score=round(rspamd_score, 2),
        language=language,
        foreign_lang_bonus=lang_bonus,
    )
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Health probe: reports liveness and whether the model is loaded."""
    model_loaded = model is not None
    return {"status": "ok", "model_loaded": model_loaded}
|
||||
93
test_classify.py
Normal file
93
test_classify.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
"""
|
||||
Teste den Classifier lokal ohne Server.
|
||||
Nützlich um das Modell nach dem Training schnell zu validieren.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
# Saved model directory written by train.py
MODEL_PATH = Path("./model/final")

# Test mails - German, English and foreign languages
TEST_MESSAGES = [
    # German spam
    {
        "label": "spam",
        "subject": "Herzlichen Glückwunsch! Sie haben gewonnen!",
        "body": "Sie wurden als Gewinner ausgewählt! Klicken Sie hier um Ihren Preis von 50.000 Euro abzuholen. Senden Sie uns Ihre Bankdaten zur Überweisung.",
    },
    {
        "label": "spam",
        "subject": "Dringend: Ihr Konto wird gesperrt",
        "body": "Sehr geehrter Kunde, Ihr Konto wird in 24 Stunden gesperrt. Bestätigen Sie Ihre Identität sofort unter folgendem Link um den Zugang zu behalten.",
    },
    # German ham
    {
        "label": "ham",
        "subject": "Re: Besprechung morgen um 15 Uhr",
        "body": "Hallo zusammen, kurze Erinnerung an unser wöchentliches Meeting morgen um 15 Uhr. Bitte bereitet eure Status-Updates vor. Danke!",
    },
    {
        "label": "ham",
        "subject": "Rechnung Nr. 4521",
        "body": "Vielen Dank für die Zusendung der Rechnung. Ich habe sie an die Buchhaltung weitergeleitet. Die Zahlung erfolgt innerhalb von 30 Tagen.",
    },
    # English spam
    {
        "label": "spam",
        "subject": "You won $1,000,000!!!",
        "body": "Congratulations! You have been selected as our winner. Click here to claim your prize now! Send us your bank details immediately.",
    },
    # English ham
    {
        "label": "ham",
        "subject": "Re: Project timeline update",
        "body": "Hi team, just a quick update on the project timeline. We're on track for the Q2 release. Let me know if you have any blockers.",
    },
    # Foreign language (Russian) - should count as suspicious
    {
        "label": "spam",
        "subject": "Специальное предложение для вас",
        "body": "Поздравляем! Вы выиграли специальный приз. Нажмите здесь чтобы получить вашу награду прямо сейчас.",
    },
    # Foreign language (Chinese)
    {
        "label": "spam",
        "subject": "恭喜您中奖了",
        "body": "尊敬的用户,恭喜您被选为幸运用户,请点击链接领取您的奖品。",
    },
]
|
||||
|
||||
|
||||
def main():
    """Run the fine-tuned classifier over the built-in test mails and report accuracy."""
    print(f"Loading model from {MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))

    print(f"\n{'Label':<6} {'Prediction':<12} {'Confidence':>10} Subject")
    print("-" * 70)

    hits = 0
    for msg in TEST_MESSAGES:
        text = f"Subject: {msg['subject']}\n\n{msg['body']}"
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)

        with torch.no_grad():
            logits = model(**inputs).logits
            spam_prob = torch.softmax(logits, dim=-1)[0][1].item()

        prediction = "spam" if spam_prob > 0.5 else "ham"
        matched = prediction == msg["label"]
        if matched:
            hits += 1
        marker = "+" if matched else "X"

        print(f"{msg['label']:<6} {prediction:<12} {spam_prob:>9.1%} [{marker}] {msg['subject']}")

    print(f"\nAccuracy: {hits}/{len(TEST_MESSAGES)} ({hits / len(TEST_MESSAGES):.0%})")


if __name__ == "__main__":
    main()
|
||||
158
train.py
Normal file
158
train.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""
|
||||
Fine-Tune DistilBERT (multilingual) als Spam-Classifier.
|
||||
|
||||
Kombiniert englische und deutsche Spam-Datensätze.
|
||||
Für Produktion: Eigene gelabelte Spam/Ham-Mails unter data/ ablegen.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
|
||||
def tokenize(batch, tokenizer):
    """Tokenize a batch of texts into fixed-length (512) padded sequences."""
    encoded = tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    return encoded
|
||||
|
||||
|
||||
def compute_metrics(pred):
    """Compute accuracy/precision/recall/F1 for a binary classification eval."""
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1,
        "precision": prec,
        "recall": rec,
    }
|
||||
|
||||
|
||||
def load_public_datasets():
    """Load and combine public EN + (optional) custom DE spam datasets.

    Returns a single `datasets.Dataset` with columns "text" and "labels"
    (0=ham, 1=spam). The "labels" column keeps the sms_spam ClassLabel
    feature so stratified splitting works downstream.
    """
    # English: SMS Spam Collection (~5,500 messages)
    print("Loading English SMS spam dataset...")
    en_sms = load_dataset("sms_spam", split="train")
    en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels")

    # German: there are few public DE spam datasets, so optionally mix in the
    # user's own labelled mails. For production, place German mails at
    # data/train_de.csv (expected columns: "text", "labels"; 0=ham, 1=spam).
    print("Checking for custom German dataset at data/train_de.csv...")
    de_path = Path("data/train_de.csv")
    if de_path.exists():
        print(f" Found {de_path}, loading...")
        de_data = load_dataset("csv", data_files=str(de_path), split="train")
        # FIX: CSV loading yields a plain int64 "labels" column while sms_spam
        # uses a ClassLabel feature; concatenate_datasets requires matching
        # features, so cast the CSV side to the sms_spam schema first.
        de_data = de_data.cast(en_data.features)
        combined = concatenate_datasets([en_data, de_data])
    else:
        print(" Not found. Using English-only dataset.")
        print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv")
        print(" Format: text,labels (labels: 0=ham, 1=spam)")
        combined = en_data

    return combined
|
||||
|
||||
|
||||
def load_custom_dataset():
    """Load user-provided training data from data/.

    Returns None when data/train.csv is missing; otherwise a DatasetDict
    with "train" and "test" splits (the test split comes from data/test.csv
    if present, else from a stratified 80/20 split).
    """
    train_path = Path("data/train.csv")
    test_path = Path("data/test.csv")

    if not train_path.exists():
        return None

    print(f"Loading custom dataset from {train_path}...")
    dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train")
    # CSV files load "labels" as plain int64; encode to ClassLabel so the
    # stratified split below works (stratify_by_column requires ClassLabel).
    dataset = dataset.class_encode_column("labels")

    if test_path.exists():
        # FIX: the original requested split="train" from a data_files dict
        # whose only key was "test", which raises at load time.
        test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="test")
        test_ds = test_ds.class_encode_column("labels")
        from datasets import DatasetDict

        return DatasetDict({"train": dataset, "test": test_ds})

    return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: fine-tune DistilBERT as a binary spam classifier.

    Loads either the public datasets (default) or user-supplied CSVs
    (--custom-data), trains with the HF Trainer, evaluates, and saves the
    best checkpoint plus tokenizer to <output-dir>/final.
    """
    parser = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier")
    parser.add_argument(
        "--model-name",
        default="distilbert-base-multilingual-cased",
        help="Base model (default: multilingual)",
    )
    parser.add_argument("--output-dir", default="./model", help="Where to save the trained model")
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--custom-data", action="store_true", help="Use only custom data from data/")
    args = parser.parse_args()

    print(f"Loading base model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
        id2label={0: "ham", 1: "spam"},
        label2id={"ham": 0, "spam": 1},
    )

    # --- Load dataset ---
    if args.custom_data:
        dataset = load_custom_dataset()
        if dataset is None:
            print("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!")
            return
    else:
        combined = load_public_datasets()
        # NOTE(review): stratify_by_column requires "labels" to be a ClassLabel
        # feature - holds for sms_spam; verify for any dataset added later.
        dataset = combined.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")

    tokenized = dataset.map(lambda batch: tokenize(batch, tokenizer), batched=True)
    tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    print(f"Train: {len(tokenized['train'])} samples, Test: {len(tokenized['test'])} samples")

    # --- Training ---
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,  # restore the epoch with the best F1
        metric_for_best_model="f1",
        logging_steps=50,
        fp16=torch.cuda.is_available(),  # mixed precision only on GPU
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    # --- Evaluation ---
    results = trainer.evaluate()
    print(f"\nResults: {results}")

    # --- Save ---
    # server.py and test_classify.py load the model from <output-dir>/final
    output_path = Path(args.output_dir) / "final"
    trainer.save_model(str(output_path))
    tokenizer.save_pretrained(str(output_path))
    print(f"\nModel saved to {output_path}")


if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue