Initial commit: SpamLLM - DistilBERT spam classifier for RSpamd

Multilingual spam classifier (DE/EN) with language detection.
Non-DE/EN mails receive an additional spam score bonus.

- train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data
- server.py: FastAPI service with langdetect integration
- rspamd/: Lua plugin and config for RSpamd integration
- export_rspamd_data.py: Export Maildir folders to CSV training data
- test_classify.py: Local model validation with DE/EN/foreign test cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Carsten Abele 2026-03-19 22:27:05 +01:00
commit 38efd20b4d
7 changed files with 671 additions and 0 deletions

136
export_rspamd_data.py Normal file
View file

@ -0,0 +1,136 @@
"""
Exportiert Mails aus Maildir-Ordnern als Trainingsdaten für SpamLLM.
Erwartet eine typische Maildir-Struktur:
- Spam-Ordner: z.B. ~/.spam/ oder /var/vmail/user/.Junk/
- Ham-Ordner: z.B. ~/Maildir/ oder /var/vmail/user/.INBOX/
Erzeugt: data/train.csv mit Spalten: text, labels (0=ham, 1=spam)
"""
import argparse
import csv
import email
import email.policy
import os
import random
from pathlib import Path
def extract_text_from_email(filepath: str) -> dict | None:
"""Extrahiert Subject und Body aus einer E-Mail-Datei."""
try:
with open(filepath, "rb") as f:
msg = email.message_from_binary_file(f, policy=email.policy.default)
subject = msg.get("Subject", "")
from_addr = msg.get("From", "")
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
content = part.get_content()
if isinstance(content, str):
body += content
else:
content = msg.get_content()
if isinstance(content, str):
body = content
# Auf sinnvolle Länge begrenzen
body = body[:4096]
if not subject and not body:
return None
text = f"From: {from_addr}\nSubject: {subject}\n\n{body}"
return {"text": text}
except Exception as e:
print(f" Skipping {filepath}: {e}")
return None
def collect_mails(directory: str, label: int, max_count: int = 0) -> list[dict]:
    """Gather labelled training samples from one Maildir directory.

    label: 0 = ham, 1 = spam. max_count > 0 randomly caps the number of
    files processed; 0 means "take everything".
    """
    collected: list[dict] = []
    root = Path(directory)
    if not root.exists():
        print(f"WARNING: {directory} existiert nicht!")
        return collected

    # A Maildir usually keeps messages under cur/ and new/ (tmp/ is skipped).
    candidates = [root] + [root / name for name in ("cur", "new") if (root / name).exists()]

    files = [
        entry
        for folder in candidates
        for entry in folder.iterdir()
        if entry.is_file() and not entry.name.startswith(".")
    ]
    if 0 < max_count < len(files):
        random.shuffle(files)
        del files[max_count:]

    label_name = "spam" if label == 1 else "ham"
    print(f"Processing {len(files)} {label_name} mails from {directory}...")

    for entry in files:
        sample = extract_text_from_email(str(entry))
        if sample:
            sample["labels"] = label
            collected.append(sample)
    return collected
def main():
    """CLI entry point: export spam/ham Maildirs to train/test CSV files."""
    parser = argparse.ArgumentParser(description="Export Maildir to CSV training data")
    parser.add_argument("--spam-dir", required=True, help="Path to spam Maildir")
    parser.add_argument("--ham-dir", required=True, help="Path to ham Maildir")
    parser.add_argument("--output", default="data/train.csv", help="Output CSV path")
    parser.add_argument("--max-per-class", type=int, default=0, help="Max mails per class (0=all)")
    parser.add_argument("--test-split", type=float, default=0.2, help="Test set ratio")
    args = parser.parse_args()

    spam_mails = collect_mails(args.spam_dir, label=1, max_count=args.max_per_class)
    ham_mails = collect_mails(args.ham_dir, label=0, max_count=args.max_per_class)

    all_mails = spam_mails + ham_mails
    random.shuffle(all_mails)
    print(f"\nTotal: {len(all_mails)} mails ({len(spam_mails)} spam, {len(ham_mails)} ham)")

    # Train/test split on the shuffled list.
    split_idx = int(len(all_mails) * (1 - args.test_split))
    train_data, test_data = all_mails[:split_idx], all_mails[split_idx:]

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    def write_csv(path: Path, rows: list[dict]) -> None:
        # Shared writer for both splits; columns match the training pipeline.
        with open(path, "w", newline="", encoding="utf-8") as fh:
            writer = csv.DictWriter(fh, fieldnames=["text", "labels"])
            writer.writeheader()
            writer.writerows(rows)

    write_csv(output_path, train_data)
    print(f"Train set: {len(train_data)} mails -> {output_path}")

    test_path = output_path.parent / "test.csv"
    write_csv(test_path, test_data)
    print(f"Test set: {len(test_data)} mails -> {test_path}")


if __name__ == "__main__":
    main()

9
requirements.txt Normal file
View file

@ -0,0 +1,9 @@
torch>=2.0.0
transformers>=4.36.0
fastapi>=0.104.0
uvicorn>=0.24.0
pydantic>=2.0.0
datasets>=2.16.0
scikit-learn>=1.3.0
accelerate>=0.25.0
langdetect>=1.0.9

View file

@ -0,0 +1,26 @@
# RSpamd external-services configuration for SpamLLM
# Copy this file to /etc/rspamd/local.d/external_services.conf
spamllm {
    # Type: HTTP-based external service
    type = "http";
    # URL of the SpamLLM FastAPI service
    url = "http://127.0.0.1:8000/classify";
    # Timeout in seconds
    timeout = 5.0;
    # Maximum message size forwarded to the service (in bytes)
    max_size = 50k;
    # Symbol inserted when the service flags a message as spam
    symbol = "SPAMLLM_SPAM";
    # Score assigned to the symbol (set dynamically by the service)
    weight = 5.0;
    # Only check mails in the grey zone (score between 3 and 12).
    # This saves resources: obvious spam/ham is never sent to the LLM.
    condition = "not rspamd_config.is_local(task:get_from_ip()) and task:get_metric_score('default') > 3 and task:get_metric_score('default') < 12";
}

129
rspamd/lua/spamllm.lua Normal file
View file

@ -0,0 +1,129 @@
-- RSpamd Lua plugin for SpamLLM.
-- Copy to /etc/rspamd/plugins.d/spamllm.lua
--
-- This plugin sends mail data to the SpamLLM HTTP service
-- and sets the score based on the response.
local rspamd_http = require "rspamd_http"
local rspamd_logger = require "rspamd_logger"
local ucl = require "ucl"

-- Module name tag.
local N = "spamllm"

-- Static plugin settings; must match the FastAPI service configuration.
local settings = {
  url = "http://127.0.0.1:8000/classify",   -- classifier endpoint
  timeout = 5.0,                            -- HTTP timeout in seconds
  symbol_spam = "SPAMLLM_SPAM",             -- symbol inserted on spam verdicts
  symbol_ham = "SPAMLLM_HAM",               -- symbol inserted on ham verdicts
  symbol_foreign = "SPAMLLM_FOREIGN_LANG",  -- symbol for unexpected languages
  threshold = 0.5,                          -- NOTE(review): not read anywhere in this file
  max_body_length = 4096,                   -- cap on body bytes sent to the service
  enabled = true,                           -- NOTE(review): not read anywhere in this file
}
-- Send one task's sender/subject/body to the SpamLLM HTTP service and
-- insert result symbols from the JSON response (asynchronously).
local function check_spamllm(task)
  -- Extract sender address, subject and (truncated) text body.
  local from = task:get_from("smtp")
  local from_addr = ""
  if from and from[1] then
    from_addr = from[1].addr or ""
  end
  local subject = task:get_subject() or ""
  local text_parts = task:get_text_parts()
  local body = ""
  if text_parts then
    for _, part in ipairs(text_parts) do
      local content = part:get_content()
      if content then
        body = body .. tostring(content)
        if #body > settings.max_body_length then
          body = body:sub(1, settings.max_body_length)
          break
        end
      end
    end
  end

  -- Escape a value for use inside a JSON string literal.
  -- FIX: the previous hand-rolled escaping only handled double quotes
  -- (plus \n/\r in the body), so backslashes, newlines in headers and
  -- other control characters produced invalid JSON that the service
  -- could not parse. Backslashes must be escaped first.
  local function json_escape(s)
    s = s:gsub('\\', '\\\\')
    s = s:gsub('"', '\\"')
    s = s:gsub('\n', '\\n'):gsub('\r', '\\r'):gsub('\t', '\\t')
    -- Any remaining C0 control character becomes a \u00XX escape.
    s = s:gsub('%c', function(c)
      return string.format('\\u%04x', string.byte(c))
    end)
    return s
  end

  -- JSON request body matching server.py's ClassifyRequest model.
  local request_body = string.format(
    '{"from_addr":"%s","subject":"%s","body":"%s"}',
    json_escape(from_addr),
    json_escape(subject),
    json_escape(body)
  )

  local function callback(err, code, response_body)
    if err then
      rspamd_logger.errx(task, "SpamLLM request failed: %s", err)
      return
    end
    if code ~= 200 then
      rspamd_logger.errx(task, "SpamLLM returned HTTP %s", code)
      return
    end
    -- Parse the JSON response with RSpamd's bundled UCL parser.
    local parser = ucl.parser()
    local ok, parse_err = parser:parse_string(response_body)
    if not ok then
      rspamd_logger.errx(task, "SpamLLM JSON parse error: %s", parse_err)
      return
    end
    local result = parser:get_object()
    if result.is_spam then
      task:insert_result(settings.symbol_spam, result.confidence, "SpamLLM")
      rspamd_logger.infox(task, "SpamLLM: SPAM (confidence=%.2f, score=%.2f, lang=%s)",
        result.confidence, result.score, result.language or "?")
    else
      task:insert_result(settings.symbol_ham, -result.confidence, "SpamLLM")
    end
    -- Foreign-language bonus as a separate symbol. The symbol weight is
    -- 4.0, so dividing the absolute bonus by 4 turns it into a multiplier.
    if result.foreign_lang_bonus and result.foreign_lang_bonus > 0 then
      task:insert_result(settings.symbol_foreign, result.foreign_lang_bonus / 4.0,
        string.format("lang=%s", result.language or "unknown"))
      rspamd_logger.infox(task, "SpamLLM: Foreign language detected: %s (bonus=%.1f)",
        result.language, result.foreign_lang_bonus)
    end
  end

  rspamd_http.request({
    task = task,
    url = settings.url,
    body = request_body,
    callback = callback,
    headers = {
      ["Content-Type"] = "application/json",
    },
    timeout = settings.timeout,
  })
end
-- Register the main symbol plus two virtual symbols sharing its callback.
rspamd_config:register_symbol({
  name = settings.symbol_spam,
  weight = 5.0,
  callback = check_spamllm,
  type = "normal",
  description = "SpamLLM DistilBERT spam classifier",
})

-- Ham verdicts lower the score via a negative-weight virtual symbol.
rspamd_config:register_symbol({
  name = settings.symbol_ham,
  weight = -2.0,
  type = "virtual",
  parent = rspamd_config:get_symbol_id(settings.symbol_spam),
  description = "SpamLLM DistilBERT ham classification",
})

-- Extra penalty for mails in an unexpected language (not DE/EN).
rspamd_config:register_symbol({
  name = settings.symbol_foreign,
  weight = 4.0,
  type = "virtual",
  parent = rspamd_config:get_symbol_id(settings.symbol_spam),
  description = "Mail in unerwarteter Sprache (nicht DE/EN)",
})

120
server.py Normal file
View file

@ -0,0 +1,120 @@
"""
FastAPI Service für Spam-Klassifikation mit Spracherkennung.
Stellt einen HTTP-Endpunkt bereit, den RSpamd als external_service aufrufen kann.
Mails in nicht-erwarteten Sprachen (nicht DE/EN) bekommen einen Spam-Bonus.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
import torch
from fastapi import FastAPI
from langdetect import DetectorFactory, detect_langs
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Deterministic language detection (langdetect is otherwise randomly seeded).
DetectorFactory.seed = 0

# Expected languages - everything else gets a spam-score bonus.
EXPECTED_LANGUAGES = {"de", "en"}

# Score bonus for unexpected languages (0-5 extra points).
FOREIGN_LANG_BONUS = 4.0

logger = logging.getLogger("spamllm")
logging.basicConfig(level=logging.INFO)

# Directory the fine-tuned model is loaded from (written by train.py).
MODEL_PATH = Path("./model/final")

# Global model state, populated once by the lifespan handler below.
model = None
tokenizer = None
device = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load tokenizer/model once at startup and keep them in module globals."""
    global model, tokenizer, device
    # Prefer GPU when available; fall back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Loading model from {MODEL_PATH} on {device}")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")
    yield


app = FastAPI(title="SpamLLM Classifier", lifespan=lifespan)
class ClassifyRequest(BaseModel):
    """Request payload sent by the RSpamd Lua plugin (spamllm.lua)."""

    subject: str = ""
    body: str = ""
    from_addr: str = ""
class ClassifyResponse(BaseModel):
    """Classification result returned to RSpamd."""

    is_spam: bool
    confidence: float
    score: float  # RSpamd-compatible score (0-15)
    language: str  # detected language
    foreign_lang_bonus: float  # extra score for a foreign language
def detect_language(text: str) -> tuple[str, bool]:
    """Return (language code, is_foreign) for *text*.

    Texts that are empty or shorter than 20 characters are too short for
    reliable detection and are reported as ("unknown", False), as is any
    detection failure.
    """
    stripped = text.strip() if text else ""
    if len(stripped) < 20:
        return "unknown", False
    try:
        best = detect_langs(text)[0]
    except Exception:
        return "unknown", False
    return best.lang, best.lang not in EXPECTED_LANGUAGES
@app.post("/classify", response_model=ClassifyResponse)
async def classify(request: ClassifyRequest):
    """Classify one mail and return an RSpamd-compatible verdict.

    Combines From/Subject/body into the model input, runs language
    detection on the body (the subject alone is usually too short), and
    adds a fixed score bonus for mails not written in DE/EN.
    """
    model_input = f"From: {request.from_addr}\nSubject: {request.subject}\n\n{request.body}"
    language, is_foreign = detect_language(request.body or request.subject)

    encoded = tokenizer(model_input, return_tensors="pt", truncation=True, max_length=512, padding=True)
    encoded = {key: tensor.to(device) for key, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    spam_prob = torch.softmax(logits, dim=-1)[0][1].item()

    # Map the probability onto RSpamd's 0-15 score scale and add the
    # foreign-language bonus, capped at 15.
    lang_bonus = FOREIGN_LANG_BONUS if is_foreign else 0.0
    rspamd_score = min(spam_prob * 15.0 + lang_bonus, 15.0)

    # Foreign-language mails are flagged as spam at a lower probability.
    is_spam = spam_prob > 0.5 or (is_foreign and spam_prob > 0.3)

    return ClassifyResponse(
        is_spam=is_spam,
        confidence=spam_prob,
        score=round(rspamd_score, 2),
        language=language,
        foreign_lang_bonus=lang_bonus,
    )
@app.get("/health")
async def health():
    """Liveness probe: reports whether the model has been loaded."""
    loaded = model is not None
    return {"status": "ok", "model_loaded": loaded}

93
test_classify.py Normal file
View file

@ -0,0 +1,93 @@
"""
Teste den Classifier lokal ohne Server.
Nützlich um das Modell nach dem Training schnell zu validieren.
"""
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Directory containing the fine-tuned model (written by train.py).
MODEL_PATH = Path("./model/final")

# Test mails - DE, EN and foreign languages.
TEST_MESSAGES = [
    # German spam
    {
        "label": "spam",
        "subject": "Herzlichen Glückwunsch! Sie haben gewonnen!",
        "body": "Sie wurden als Gewinner ausgewählt! Klicken Sie hier um Ihren Preis von 50.000 Euro abzuholen. Senden Sie uns Ihre Bankdaten zur Überweisung.",
    },
    {
        "label": "spam",
        "subject": "Dringend: Ihr Konto wird gesperrt",
        "body": "Sehr geehrter Kunde, Ihr Konto wird in 24 Stunden gesperrt. Bestätigen Sie Ihre Identität sofort unter folgendem Link um den Zugang zu behalten.",
    },
    # German ham
    {
        "label": "ham",
        "subject": "Re: Besprechung morgen um 15 Uhr",
        "body": "Hallo zusammen, kurze Erinnerung an unser wöchentliches Meeting morgen um 15 Uhr. Bitte bereitet eure Status-Updates vor. Danke!",
    },
    {
        "label": "ham",
        "subject": "Rechnung Nr. 4521",
        "body": "Vielen Dank für die Zusendung der Rechnung. Ich habe sie an die Buchhaltung weitergeleitet. Die Zahlung erfolgt innerhalb von 30 Tagen.",
    },
    # English spam
    {
        "label": "spam",
        "subject": "You won $1,000,000!!!",
        "body": "Congratulations! You have been selected as our winner. Click here to claim your prize now! Send us your bank details immediately.",
    },
    # English ham
    {
        "label": "ham",
        "subject": "Re: Project timeline update",
        "body": "Hi team, just a quick update on the project timeline. We're on track for the Q2 release. Let me know if you have any blockers.",
    },
    # Foreign language (Russian) - should be treated as suspicious
    {
        "label": "spam",
        "subject": "Специальное предложение для вас",
        "body": "Поздравляем! Вы выиграли специальный приз. Нажмите здесь чтобы получить вашу награду прямо сейчас.",
    },
    # Foreign language (Chinese)
    {
        "label": "spam",
        "subject": "恭喜您中奖了",
        "body": "尊敬的用户,恭喜您被选为幸运用户,请点击链接领取您的奖品。",
    },
]
def main():
    """Run the fine-tuned model over TEST_MESSAGES and print accuracy."""
    print(f"Loading model from {MODEL_PATH}...")
    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))

    print(f"\n{'Label':<6} {'Prediction':<12} {'Confidence':>10} Subject")
    print("-" * 70)

    correct = 0
    for msg in TEST_MESSAGES:
        sample = f"Subject: {msg['subject']}\n\n{msg['body']}"
        encoded = tokenizer(sample, return_tensors="pt", truncation=True, max_length=512, padding=True)
        with torch.no_grad():
            logits = model(**encoded).logits
        spam_prob = torch.softmax(logits, dim=-1)[0][1].item()

        prediction = "spam" if spam_prob > 0.5 else "ham"
        hit = prediction == msg["label"]
        if hit:
            correct += 1
        marker = "+" if hit else "X"
        print(f"{msg['label']:<6} {prediction:<12} {spam_prob:>9.1%} [{marker}] {msg['subject']}")

    print(f"\nAccuracy: {correct}/{len(TEST_MESSAGES)} ({correct / len(TEST_MESSAGES):.0%})")


if __name__ == "__main__":
    main()

158
train.py Normal file
View file

@ -0,0 +1,158 @@
"""
Fine-Tune DistilBERT (multilingual) als Spam-Classifier.
Kombiniert englische und deutsche Spam-Datensätze.
Für Produktion: Eigene gelabelte Spam/Ham-Mails unter data/ ablegen.
"""
import argparse
from pathlib import Path
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
Trainer,
TrainingArguments,
)
def tokenize(batch, tokenizer):
    """Tokenize a batch of texts into fixed-length (512) model inputs."""
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
    )
def compute_metrics(pred):
    """Compute accuracy/precision/recall/F1 for a Trainer EvalPrediction."""
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
def load_public_datasets():
    """Load and combine public EN + optional custom DE spam datasets.

    Returns a single `datasets.Dataset` with columns "text" and "labels"
    (0 = ham, 1 = spam).
    """
    # English: SMS Spam Collection (~5,500 messages).
    print("Loading English SMS spam dataset...")
    en_sms = load_dataset("sms_spam", split="train")
    en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels")

    # German: there are few public DE spam corpora, so we optionally mix in
    # a user-provided CSV. For production, put your own DE mails in
    # data/train_de.csv (expected columns: "text" and "labels", 0=ham 1=spam).
    print("Checking for custom German dataset at data/train_de.csv...")
    de_path = Path("data/train_de.csv")
    if de_path.exists():
        print(f" Found {de_path}, loading...")
        de_data = load_dataset("csv", data_files=str(de_path), split="train")
        # FIX: concatenate_datasets requires identical features, but the CSV
        # loader yields a plain int64 "labels" column while the SMS dataset
        # uses a ClassLabel. Cast the CSV split to the SMS schema so the
        # concatenation does not raise a features-mismatch error.
        de_data = de_data.cast(en_data.features)
        combined = concatenate_datasets([en_data, de_data])
    else:
        print(" Not found. Using English-only dataset.")
        print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv")
        print(" Format: text,labels (labels: 0=ham, 1=spam)")
        combined = en_data
    return combined
def load_custom_dataset():
    """Load user-provided train/test CSVs from data/, if present.

    Returns a DatasetDict with "train" and "test" splits, or None when
    data/train.csv does not exist. When no test CSV is present, a
    stratified 80/20 split is derived from the training data.
    """
    train_path = Path("data/train.csv")
    test_path = Path("data/test.csv")
    if not train_path.exists():
        return None
    print(f"Loading custom dataset from {train_path}...")
    dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train")
    # Expected format: columns "text" and "labels" (0=ham, 1=spam).
    if test_path.exists():
        test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="train")
        from datasets import DatasetDict
        return DatasetDict({"train": dataset, "test": test_ds})
    # FIX: stratify_by_column requires a ClassLabel column, but CSV-loaded
    # labels are plain int64, which makes train_test_split raise. Encode
    # the column first so the stratified split actually works.
    dataset = dataset.class_encode_column("labels")
    return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
def main():
    """Train the spam classifier end-to-end: load data, fine-tune, evaluate, save."""
    parser = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier")
    parser.add_argument(
        "--model-name",
        default="distilbert-base-multilingual-cased",
        help="Base model (default: multilingual)",
    )
    parser.add_argument("--output-dir", default="./model", help="Where to save the trained model")
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--custom-data", action="store_true", help="Use only custom data from data/")
    args = parser.parse_args()

    print(f"Loading base model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Binary head: label 0 = ham, 1 = spam (matches the CSV export format).
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
        id2label={0: "ham", 1: "spam"},
        label2id={"ham": 0, "spam": 1},
    )

    # --- Load dataset ---
    if args.custom_data:
        dataset = load_custom_dataset()
        if dataset is None:
            print("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!")
            return
    else:
        combined = load_public_datasets()
        # NOTE(review): stratify_by_column requires a ClassLabel "labels"
        # column - true for the public SMS dataset, but confirm it still
        # holds after mixing in a CSV-loaded German dataset.
        dataset = combined.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")

    tokenized = dataset.map(lambda batch: tokenize(batch, tokenizer), batched=True)
    tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print(f"Train: {len(tokenized['train'])} samples, Test: {len(tokenized['test'])} samples")

    # --- Training ---
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        # NOTE(review): "eval_strategy" requires transformers >= 4.41 (it was
        # renamed from "evaluation_strategy"); requirements.txt allows >= 4.36.
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_steps=50,
        fp16=torch.cuda.is_available(),  # mixed precision only on GPU
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        compute_metrics=compute_metrics,
    )
    print("Starting training...")
    trainer.train()

    # --- Evaluation ---
    results = trainer.evaluate()
    print(f"\nResults: {results}")

    # --- Save (consumed by server.py and test_classify.py) ---
    output_path = Path(args.output_dir) / "final"
    trainer.save_model(str(output_path))
    tokenizer.save_pretrained(str(output_path))
    print(f"\nModel saved to {output_path}")


if __name__ == "__main__":
    main()