commit 38efd20b4d8a50e59c8871ffdaa4e0fa4c617344
Author: Carsten Abele
Date:   Thu Mar 19 22:27:05 2026 +0100

    Initial commit: SpamLLM - DistilBERT spam classifier for RSpamd

    Multilingual spam classifier (DE/EN) with language detection.
    Non-DE/EN mails receive an additional spam score bonus.

    - train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data
    - server.py: FastAPI service with langdetect integration
    - rspamd/: Lua plugin and config for RSpamd integration
    - export_rspamd_data.py: Export Maildir folders to CSV training data
    - test_classify.py: Local model validation with DE/EN/foreign test cases

    Co-Authored-By: Claude Opus 4.6 (1M context)

diff --git a/export_rspamd_data.py b/export_rspamd_data.py
new file mode 100644
index 0000000..afe960f
--- /dev/null
+++ b/export_rspamd_data.py
@@ -0,0 +1,136 @@
+"""
+Exports mails from Maildir folders as training data for SpamLLM.
+
+Expects a typical Maildir structure:
+  - spam folder: e.g. ~/.spam/ or /var/vmail/user/.Junk/
+  - ham folder:  e.g. ~/Maildir/ or /var/vmail/user/.INBOX/
+
+Produces: data/train.csv with columns: text, labels (0=ham, 1=spam)
+"""
+
+import argparse
+import csv
+import email
+import email.policy
+import random
+from pathlib import Path
+
+
+def extract_text_from_email(filepath: str) -> dict | None:
+    """Extracts subject and body from an e-mail file."""
+    try:
+        with open(filepath, "rb") as f:
+            msg = email.message_from_binary_file(f, policy=email.policy.default)
+
+        subject = msg.get("Subject", "")
+        from_addr = msg.get("From", "")
+
+        body = ""
+        if msg.is_multipart():
+            for part in msg.walk():
+                if part.get_content_type() == "text/plain":
+                    content = part.get_content()
+                    if isinstance(content, str):
+                        body += content
+        else:
+            content = msg.get_content()
+            if isinstance(content, str):
+                body = content
+
+        # Cap the body at a reasonable length
+        body = body[:4096]
+
+        if not subject and not body:
+            return None
+
+        text = f"From: {from_addr}\nSubject: {subject}\n\n{body}"
+        return {"text": text}
+
+    except Exception as e:
+        print(f"  Skipping {filepath}: {e}")
+        return None
+
+
+def collect_mails(directory: str, label: int, max_count: int = 0) -> list[dict]:
+    """Collects mails from a Maildir directory."""
+    results = []
+    mail_dir = Path(directory)
+
+    if not mail_dir.exists():
+        print(f"WARNING: {directory} does not exist!")
+        return results
+
+    # A Maildir typically has cur/, new/, tmp/ subfolders
+    search_dirs = [mail_dir]
+    for subdir in ["cur", "new"]:
+        sub = mail_dir / subdir
+        if sub.exists():
+            search_dirs.append(sub)
+
+    files = []
+    for search_dir in search_dirs:
+        for f in search_dir.iterdir():
+            if f.is_file() and not f.name.startswith("."):
+                files.append(f)
+
+    if max_count > 0 and len(files) > max_count:
+        random.shuffle(files)
+        files = files[:max_count]
+
+    label_name = "spam" if label == 1 else "ham"
+    print(f"Processing {len(files)} {label_name} mails from {directory}...")
+
+    for filepath in files:
+        extracted = extract_text_from_email(str(filepath))
+        if extracted:
+            extracted["labels"] = label
+            results.append(extracted)
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Export Maildir to CSV training data")
+    parser.add_argument("--spam-dir", required=True, help="Path to spam Maildir")
+    parser.add_argument("--ham-dir", required=True, help="Path to ham Maildir")
+    parser.add_argument("--output", default="data/train.csv", help="Output CSV path")
+    parser.add_argument("--max-per-class", type=int, default=0, help="Max mails per class (0=all)")
+    parser.add_argument("--test-split", type=float, default=0.2, help="Test set ratio")
+    args = parser.parse_args()
+
+    spam_mails = collect_mails(args.spam_dir, label=1, max_count=args.max_per_class)
+    ham_mails = collect_mails(args.ham_dir, label=0, max_count=args.max_per_class)
+
+    all_mails = spam_mails + ham_mails
+    random.shuffle(all_mails)
+
+    print(f"\nTotal: {len(all_mails)} mails ({len(spam_mails)} spam, {len(ham_mails)} ham)")
+
+    # Train/test split
+    split_idx = int(len(all_mails) * (1 - args.test_split))
+    train_data = all_mails[:split_idx]
+    test_data = all_mails[split_idx:]
+
+    # Create the output directory
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Train CSV
+    with open(output_path, "w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=["text", "labels"])
+        writer.writeheader()
+        writer.writerows(train_data)
+    print(f"Train set: {len(train_data)} mails -> {output_path}")
+
+    # Test CSV
+    test_path = output_path.parent / "test.csv"
+    with open(test_path, "w", newline="", encoding="utf-8") as f:
+        writer = csv.DictWriter(f, fieldnames=["text", "labels"])
+        writer.writeheader()
+        writer.writerows(test_data)
+    print(f"Test set: {len(test_data)} mails -> {test_path}")
+
+
+if __name__ == "__main__":
+    main()
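+
+# Usage sketch (paths are examples taken from the docstring above; adjust to
+# your Maildir layout):
+#   python export_rspamd_data.py --spam-dir /var/vmail/user/.Junk \
+#       --ham-dir /var/vmail/user/.INBOX --max-per-class 5000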
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..8d457b8
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,9 @@
+torch>=2.0.0
+transformers>=4.41.0  # train.py uses TrainingArguments(eval_strategy=...), which needs a newer release than 4.36
+fastapi>=0.104.0
+uvicorn>=0.24.0
+pydantic>=2.0.0
+datasets>=2.16.0
+scikit-learn>=1.3.0
+accelerate>=0.25.0
+langdetect>=1.0.9
diff --git a/rspamd/local.d/external_services.conf b/rspamd/local.d/external_services.conf
new file mode 100644
index 0000000..7b434ae
--- /dev/null
+++ b/rspamd/local.d/external_services.conf
@@ -0,0 +1,26 @@
+# RSpamd external service configuration for SpamLLM
+# Copy this file to /etc/rspamd/local.d/external_services.conf
+
+spamllm {
+  # Type: HTTP-based external service
+  type = "http";
+
+  # URL of the SpamLLM FastAPI service
+  url = "http://127.0.0.1:8000/classify";
+
+  # Timeout in seconds
+  timeout = 5.0;
+
+  # Maximum message size sent to the service (in bytes)
+  max_size = 50k;
+
+  # Symbol set when spam is detected
+  symbol = "SPAMLLM_SPAM";
+
+  # Score assigned to the symbol (set dynamically by the service)
+  weight = 5.0;
+
+  # Only check mails in the grey zone (score between 3 and 12).
+  # This saves resources: obvious spam/ham is never sent to the LLM.
+  condition = "not rspamd_config.is_local(task:get_from_ip()) and task:get_metric_score('default') > 3 and task:get_metric_score('default') < 12";
+}
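+
+# Quick manual check that the service answers (example request matching the
+# ClassifyRequest fields in server.py; run after starting the service):
+#   curl -s http://127.0.0.1:8000/classify -H 'Content-Type: application/json' \
+#     -d '{"subject":"Test","body":"Hello, this is a test.","from_addr":"user@example.com"}'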
diff --git a/rspamd/lua/spamllm.lua b/rspamd/lua/spamllm.lua
new file mode 100644
index 0000000..c8313bd
--- /dev/null
+++ b/rspamd/lua/spamllm.lua
@@ -0,0 +1,129 @@
+-- RSpamd Lua plugin for SpamLLM
+-- Copy to /etc/rspamd/plugins.d/spamllm.lua
+--
+-- This plugin sends mail data to the SpamLLM HTTP service
+-- and sets the score based on the response.
+
+local rspamd_http = require "rspamd_http"
+local rspamd_logger = require "rspamd_logger"
+local ucl = require "ucl"
+
+local N = "spamllm"
+
+local settings = {
+  url = "http://127.0.0.1:8000/classify",
+  timeout = 5.0,
+  symbol_spam = "SPAMLLM_SPAM",
+  symbol_ham = "SPAMLLM_HAM",
+  symbol_foreign = "SPAMLLM_FOREIGN_LANG",
+  threshold = 0.5,
+  max_body_length = 4096,
+  enabled = true,
+}
+
+-- Escape a string for embedding in a JSON string literal.
+-- Backslashes must be escaped first, before any other escape sequences are added.
+local function json_escape(s)
+  local escaped = s:gsub('\\', '\\\\'):gsub('"', '\\"'):gsub('\n', '\\n'):gsub('\r', '\\r'):gsub('\t', '\\t')
+  return escaped
+end
+
+local function check_spamllm(task)
+  -- Extract mail data
+  local from = task:get_from("smtp")
+  local from_addr = ""
+  if from and from[1] then
+    from_addr = from[1].addr or ""
+  end
+
+  local subject = task:get_subject() or ""
+
+  local text_parts = task:get_text_parts()
+  local body = ""
+  if text_parts then
+    for _, part in ipairs(text_parts) do
+      local content = part:get_content()
+      if content then
+        body = body .. tostring(content)
+        if #body > settings.max_body_length then
+          body = body:sub(1, settings.max_body_length)
+          break
+        end
+      end
+    end
+  end
+
+  -- JSON request body
+  local request_body = string.format(
+    '{"from_addr":"%s","subject":"%s","body":"%s"}',
+    json_escape(from_addr),
+    json_escape(subject),
+    json_escape(body)
+  )
+
+  local function callback(err, code, response_body)
+    if err then
+      rspamd_logger.errx(task, "SpamLLM request failed: %s", err)
+      return
+    end
+
+    if code ~= 200 then
+      rspamd_logger.errx(task, "SpamLLM returned HTTP %s", code)
+      return
+    end
+
+    local parser = ucl.parser()
+    local ok, parse_err = parser:parse_string(response_body)
+    if not ok then
+      rspamd_logger.errx(task, "SpamLLM JSON parse error: %s", parse_err)
+      return
+    end
+
+    local result = parser:get_object()
+
+    if result.is_spam then
+      task:insert_result(settings.symbol_spam, result.confidence, "SpamLLM")
+      rspamd_logger.infox(task, "SpamLLM: SPAM (confidence=%.2f, score=%.2f, lang=%s)",
+        result.confidence, result.score, result.language or "?")
+    else
+      -- Pass a positive multiplier: the symbol's own negative weight (-2.0)
+      -- already makes the contribution negative. A negative multiplier here
+      -- would flip the sign and penalize ham.
+      task:insert_result(settings.symbol_ham, result.confidence, "SpamLLM")
+    end
+
+    -- Foreign-language bonus as a separate symbol
+    if result.foreign_lang_bonus and result.foreign_lang_bonus > 0 then
+      task:insert_result(settings.symbol_foreign, result.foreign_lang_bonus / 4.0,
+        string.format("lang=%s", result.language or "unknown"))
+      rspamd_logger.infox(task, "SpamLLM: Foreign language detected: %s (bonus=%.1f)",
+        result.language, result.foreign_lang_bonus)
+    end
+  end
+
+  rspamd_http.request({
+    task = task,
+    url = settings.url,
+    body = request_body,
+    callback = callback,
+    headers = {
+      ["Content-Type"] = "application/json",
+    },
+    timeout = settings.timeout,
+  })
+end
+
+-- Register symbols
+rspamd_config:register_symbol({
+  name = settings.symbol_spam,
+  weight = 5.0,
+  callback = check_spamllm,
+  type = "normal",
+  description = "SpamLLM DistilBERT spam classifier",
+})
+
+rspamd_config:register_symbol({
+  name = settings.symbol_ham,
+  weight = -2.0,
+  type = "virtual",
+  parent = rspamd_config:get_symbol_id(settings.symbol_spam),
+  description = "SpamLLM DistilBERT ham classification",
+})
+
+rspamd_config:register_symbol({
+  name = settings.symbol_foreign,
+  weight = 4.0,
+  type = "virtual",
+  parent = rspamd_config:get_symbol_id(settings.symbol_spam),
+  description = "Mail in an unexpected language (not DE/EN)",
+})
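+
+-- Note: the string.format JSON building above is a minimal sketch. If the
+-- installed RSpamd's ucl module exposes to_format, serializing a table is
+-- more robust than manual escaping:
+--   local request_body = ucl.to_format({
+--     from_addr = from_addr, subject = subject, body = body,
+--   }, "json")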
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..4525a43
--- /dev/null
+++ b/server.py
@@ -0,0 +1,120 @@
+"""
+FastAPI service for spam classification with language detection.
+
+Provides an HTTP endpoint that RSpamd can call as an external_service.
+Mails in unexpected languages (not DE/EN) receive a spam bonus.
+"""
+
+import logging
+from contextlib import asynccontextmanager
+from pathlib import Path
+
+import torch
+from fastapi import FastAPI
+from langdetect import DetectorFactory, detect_langs
+from pydantic import BaseModel
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+# Deterministic language detection
+DetectorFactory.seed = 0
+
+# Expected languages - anything else gets a spam score bonus
+EXPECTED_LANGUAGES = {"de", "en"}
+# Score bonus for unexpected languages (added on top of the model score)
+FOREIGN_LANG_BONUS = 4.0
+
+logger = logging.getLogger("spamllm")
+logging.basicConfig(level=logging.INFO)
+
+MODEL_PATH = Path("./model/final")
+
+# Global model state
+model = None
+tokenizer = None
+device = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global model, tokenizer, device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    logger.info(f"Loading model from {MODEL_PATH} on {device}")
+    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
+    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))
+    model.to(device)
+    model.eval()
+    logger.info("Model loaded successfully")
+    yield
+
+
+app = FastAPI(title="SpamLLM Classifier", lifespan=lifespan)
+
+
+class ClassifyRequest(BaseModel):
+    subject: str = ""
+    body: str = ""
+    from_addr: str = ""
+
+
+class ClassifyResponse(BaseModel):
+    is_spam: bool
+    confidence: float
+    score: float  # RSpamd-compatible score (0-15)
+    language: str  # Detected language
+    foreign_lang_bonus: float  # Extra score for a foreign language
+
+
+def detect_language(text: str) -> tuple[str, bool]:
+    """Detects the language and whether it is an expected one."""
+    if not text or len(text.strip()) < 20:
+        return "unknown", False
+
+    try:
+        langs = detect_langs(text)
+        top_lang = langs[0]
+        lang_code = top_lang.lang
+        is_foreign = lang_code not in EXPECTED_LANGUAGES
+        return lang_code, is_foreign
+    except Exception:
+        return "unknown", False
+
+
+@app.post("/classify", response_model=ClassifyResponse)
+def classify(request: ClassifyRequest):
+    # Sync endpoint on purpose: FastAPI runs it in a thread pool, so the
+    # blocking model inference does not stall the event loop.
+
+    # Combine the mail fields into a single text
+    text = f"From: {request.from_addr}\nSubject: {request.subject}\n\n{request.body}"
+
+    # Language detection on the body (the subject is usually too short)
+    lang_text = request.body if request.body else request.subject
+    language, is_foreign = detect_language(lang_text)
+
+    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model(**inputs)
+        probs = torch.softmax(outputs.logits, dim=-1)
+        spam_prob = probs[0][1].item()
+
+    # Convert the probability to an RSpamd score (0-15 scale)
+    rspamd_score = spam_prob * 15.0
+
+    # Foreign-language bonus: non-DE/EN gets extra points
+    lang_bonus = FOREIGN_LANG_BONUS if is_foreign else 0.0
+    rspamd_score = min(rspamd_score + lang_bonus, 15.0)
+
+    # Re-evaluate the spam threshold after the bonus
+    effective_spam = spam_prob > 0.5 or (is_foreign and spam_prob > 0.3)
+
+    return ClassifyResponse(
+        is_spam=effective_spam,
+        confidence=spam_prob,
+        score=round(rspamd_score, 2),
+        language=language,
+        foreign_lang_bonus=lang_bonus,
+    )
+
+
+@app.get("/health")
+async def health():
+    return {"status": "ok", "model_loaded": model is not None}
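+
+# Run sketch (uvicorn is listed in requirements.txt; the port must match
+# rspamd/local.d/external_services.conf):
+#   uvicorn server:app --host 127.0.0.1 --port 8000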
diff --git a/test_classify.py b/test_classify.py
new file mode 100644
index 0000000..20fcc7a
--- /dev/null
+++ b/test_classify.py
@@ -0,0 +1,93 @@
+"""
+Test the classifier locally, without the server.
+Useful for quickly validating the model after training.
+"""
+
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+MODEL_PATH = Path("./model/final")
+
+# Test mails - DE, EN and foreign languages
+TEST_MESSAGES = [
+    # German spam
+    {
+        "label": "spam",
+        "subject": "Herzlichen Glückwunsch! Sie haben gewonnen!",
+        "body": "Sie wurden als Gewinner ausgewählt! Klicken Sie hier um Ihren Preis von 50.000 Euro abzuholen. Senden Sie uns Ihre Bankdaten zur Überweisung.",
+    },
+    {
+        "label": "spam",
+        "subject": "Dringend: Ihr Konto wird gesperrt",
+        "body": "Sehr geehrter Kunde, Ihr Konto wird in 24 Stunden gesperrt. Bestätigen Sie Ihre Identität sofort unter folgendem Link um den Zugang zu behalten.",
+    },
+    # German ham
+    {
+        "label": "ham",
+        "subject": "Re: Besprechung morgen um 15 Uhr",
+        "body": "Hallo zusammen, kurze Erinnerung an unser wöchentliches Meeting morgen um 15 Uhr. Bitte bereitet eure Status-Updates vor. Danke!",
+    },
+    {
+        "label": "ham",
+        "subject": "Rechnung Nr. 4521",
+        "body": "Vielen Dank für die Zusendung der Rechnung. Ich habe sie an die Buchhaltung weitergeleitet. Die Zahlung erfolgt innerhalb von 30 Tagen.",
+    },
+    # English spam
+    {
+        "label": "spam",
+        "subject": "You won $1,000,000!!!",
+        "body": "Congratulations! You have been selected as our winner. Click here to claim your prize now! Send us your bank details immediately.",
+    },
+    # English ham
+    {
+        "label": "ham",
+        "subject": "Re: Project timeline update",
+        "body": "Hi team, just a quick update on the project timeline. We're on track for the Q2 release. Let me know if you have any blockers.",
+    },
+    # Foreign language (Russian) - should be treated as suspicious
+    {
+        "label": "spam",
+        "subject": "Специальное предложение для вас",
+        "body": "Поздравляем! Вы выиграли специальный приз. Нажмите здесь чтобы получить вашу награду прямо сейчас.",
+    },
+    # Foreign language (Chinese)
+    {
+        "label": "spam",
+        "subject": "恭喜您中奖了",
+        "body": "尊敬的用户,恭喜您被选为幸运用户,请点击链接领取您的奖品。",
+    },
+]
+
+
+def main():
+    print(f"Loading model from {MODEL_PATH}...")
+    tokenizer = AutoTokenizer.from_pretrained(str(MODEL_PATH))
+    model = AutoModelForSequenceClassification.from_pretrained(str(MODEL_PATH))
+    model.eval()  # disable dropout for deterministic inference
+
+    print(f"\n{'Label':<6} {'Prediction':<12} {'Confidence':>10} Subject")
+    print("-" * 70)
+
+    correct = 0
+    for msg in TEST_MESSAGES:
+        text = f"Subject: {msg['subject']}\n\n{msg['body']}"
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding=True)
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+            probs = torch.softmax(outputs.logits, dim=-1)
+            spam_prob = probs[0][1].item()
+
+        prediction = "spam" if spam_prob > 0.5 else "ham"
+        is_correct = prediction == msg["label"]
+        correct += int(is_correct)
+        marker = "+" if is_correct else "X"
+
+        print(f"{msg['label']:<6} {prediction:<12} {spam_prob:>9.1%} [{marker}] {msg['subject']}")
+
+    print(f"\nAccuracy: {correct}/{len(TEST_MESSAGES)} ({correct / len(TEST_MESSAGES):.0%})")
+
+
+if __name__ == "__main__":
+    main()
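+
+# Run sketch (expects a trained model under ./model/final, as written by train.py):
+#   python test_classify.py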
+""" + +import argparse +from pathlib import Path + +import torch +from datasets import Dataset, concatenate_datasets, load_dataset +from sklearn.metrics import accuracy_score, precision_recall_fscore_support +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + Trainer, + TrainingArguments, +) + + +def tokenize(batch, tokenizer): + return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512) + + +def compute_metrics(pred): + labels = pred.label_ids + preds = pred.predictions.argmax(-1) + precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary") + acc = accuracy_score(labels, preds) + return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall} + + +def load_public_datasets(): + """Lädt und kombiniert öffentliche EN + DE Spam-Datensätze.""" + + # Englisch: SMS Spam Collection (~5.500 Nachrichten) + print("Loading English SMS spam dataset...") + en_sms = load_dataset("sms_spam", split="train") + en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels") + + # Deutsch: Es gibt wenige öffentliche DE-Spam-Datensätze. + # Wir nutzen einen multilingualen E-Mail-Datensatz falls verfügbar, + # ansonsten den englischen als Basis. + # Für Produktion: Eigene DE-Mails unter data/train_de.csv ablegen. + print("Checking for custom German dataset at data/train_de.csv...") + de_path = Path("data/train_de.csv") + if de_path.exists(): + print(f" Found {de_path}, loading...") + de_data = load_dataset("csv", data_files=str(de_path), split="train") + # Erwartetes Format: Spalten "text" und "labels" (0=ham, 1=spam) + combined = concatenate_datasets([en_data, de_data]) + else: + print(" Not found. Using English-only dataset.") + print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv") + print(" Format: text,labels (labels: 0=ham, 1=spam)") + combined = en_data + + return combined + + +def load_custom_dataset(): + """Lädt benutzerdefinierte Datensätze aus data/.""" + train_path = Path("data/train.csv") + test_path = Path("data/test.csv") + + if not train_path.exists(): + return None + + print(f"Loading custom dataset from {train_path}...") + dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train") + + if test_path.exists(): + test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="train") + from datasets import DatasetDict + return DatasetDict({"train": dataset, "test": test_ds}) + + return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels") + + +def main(): + parser = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier") + parser.add_argument( + "--model-name", + default="distilbert-base-multilingual-cased", + help="Base model (default: multilingual)", + ) + parser.add_argument("--output-dir", default="./model", help="Where to save the trained model") + parser.add_argument("--epochs", type=int, default=3, help="Training epochs") + parser.add_argument("--batch-size", type=int, default=16, help="Batch size") + parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate") + parser.add_argument("--custom-data", action="store_true", help="Use only custom data from data/") + args = parser.parse_args() + + print(f"Loading base model: {args.model_name}") + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + model = AutoModelForSequenceClassification.from_pretrained( + args.model_name, + num_labels=2, + id2label={0: "ham", 1: "spam"}, + 
label2id={"ham": 0, "spam": 1}, + ) + + # --- Datensatz laden --- + if args.custom_data: + dataset = load_custom_dataset() + if dataset is None: + print("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!") + return + else: + combined = load_public_datasets() + dataset = combined.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels") + + tokenized = dataset.map(lambda batch: tokenize(batch, tokenizer), batched=True) + tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) + + print(f"Train: {len(tokenized['train'])} samples, Test: {len(tokenized['test'])} samples") + + # --- Training --- + training_args = TrainingArguments( + output_dir=args.output_dir, + num_train_epochs=args.epochs, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + learning_rate=args.lr, + weight_decay=0.01, + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="f1", + logging_steps=50, + fp16=torch.cuda.is_available(), + ) + + trainer = Trainer( + model=model, + args=training_args, + train_dataset=tokenized["train"], + eval_dataset=tokenized["test"], + compute_metrics=compute_metrics, + ) + + print("Starting training...") + trainer.train() + + # --- Evaluierung --- + results = trainer.evaluate() + print(f"\nResults: {results}") + + # --- Speichern --- + output_path = Path(args.output_dir) / "final" + trainer.save_model(str(output_path)) + tokenizer.save_pretrained(str(output_path)) + print(f"\nModel saved to {output_path}") + + +if __name__ == "__main__": + main()