spamBERT/train.py
Carsten Abele 38efd20b4d Initial commit: SpamLLM - DistilBERT spam classifier for RSpamd
Multilingual spam classifier (DE/EN) with language detection.
Non-DE/EN mails receive an additional spam score bonus.

- train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data
- server.py: FastAPI service with langdetect integration
- rspamd/: Lua plugin and config for RSpamd integration
- export_rspamd_data.py: Export Maildir folders to CSV training data
- test_classify.py: Local model validation with DE/EN/foreign test cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-19 22:27:05 +01:00

158 lines
5.6 KiB
Python

"""
Fine-Tune DistilBERT (multilingual) als Spam-Classifier.
Kombiniert englische und deutsche Spam-Datensätze.
Für Produktion: Eigene gelabelte Spam/Ham-Mails unter data/ ablegen.
"""
import argparse
from pathlib import Path
import torch
from datasets import Dataset, concatenate_datasets, load_dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
Trainer,
TrainingArguments,
)
def tokenize(batch, tokenizer, max_length=512):
    """Tokenize a batch of examples for sequence classification.

    Args:
        batch: A batched example dict as passed by ``Dataset.map(batched=True)``;
            ``batch["text"]`` is the list of message bodies.
        tokenizer: A Hugging Face tokenizer (any callable accepting the same
            keyword arguments works).
        max_length: Every sequence is truncated and padded to exactly this
            many tokens. Defaults to 512.

    Returns:
        The tokenizer's output for the batch (e.g. ``input_ids`` and
        ``attention_mask``), padded to ``max_length``.
    """
    # Rows loaded from CSV can hold None for empty cells (or numbers for
    # numeric-only cells); tokenizers reject non-string inputs, so normalise.
    texts = ["" if text is None else str(text) for text in batch["text"]]
    # Fixed-length padding gives every row the same length, so batches can be
    # stacked without a dedicated padding collator (none is configured in main()).
    return tokenizer(texts, truncation=True, padding="max_length", max_length=max_length)
def compute_metrics(pred):
    """Compute binary ham/spam metrics for ``Trainer`` evaluation.

    Args:
        pred: An ``EvalPrediction``-like object with ``predictions`` (logits
            with one column per class) and ``label_ids``.

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``. The
        positive class for precision/recall/F1 is label 1 (spam).
    """
    logits = pred.predictions
    # Guard against predictions arriving as a tuple (e.g. when the model
    # returns extra outputs); the logits are assumed to be the first element.
    if isinstance(logits, tuple):
        logits = logits[0]
    labels = pred.label_ids
    preds = logits.argmax(-1)
    # Early in training the model may predict no spam at all, leaving
    # precision undefined; report 0 explicitly instead of warning every eval.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary", pos_label=1, zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
def load_public_datasets():
    """Load the public English SMS spam set, optionally merged with a local CSV.

    The English part is the Hugging Face ``sms_spam`` dataset, with its
    ``sms``/``label`` columns renamed to ``text``/``labels``. If
    ``data/train_de.csv`` exists, its rows are appended. That file is where
    your own labelled German mails (e.g. exported from RSpamd) go. It must
    have a ``text`` column and a ``labels`` column (0 = ham, 1 = spam); any
    other columns are dropped.

    Returns:
        A single ``datasets.Dataset`` with ``text`` and ``labels`` columns.

    Raises:
        ValueError: If ``data/train_de.csv`` lacks ``text`` or ``labels``.
    """
    # English: SMS Spam Collection
    print("Loading English SMS spam dataset...")
    en_sms = load_dataset("sms_spam", split="train")
    en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels")

    # German: no public corpus is used; own mails go into data/train_de.csv.
    print("Checking for custom German dataset at data/train_de.csv...")
    de_path = Path("data/train_de.csv")
    if not de_path.exists():
        print(" Not found. Using English-only dataset.")
        print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv")
        print(" Format: text,labels (labels: 0=ham, 1=spam)")
        return en_data

    print(f" Found {de_path}, loading...")
    de_data = load_dataset("csv", data_files=str(de_path), split="train")

    missing = {"text", "labels"} - set(de_data.column_names)
    if missing:
        raise ValueError(
            f"{de_path} is missing required column(s): {', '.join(sorted(missing))}"
        )
    extra = [col for col in de_data.column_names if col not in ("text", "labels")]
    if extra:
        de_data = de_data.remove_columns(extra)

    # concatenate_datasets needs identical features, but the CSV loader reads
    # labels as plain integers. Casting to the English schema aligns them and
    # keeps "labels" the ClassLabel type that stratify_by_column in main()
    # requires.
    de_data = de_data.cast(en_data.features)
    return concatenate_datasets([en_data, de_data])
def load_custom_dataset():
    """Load user-provided train (and optional test) CSVs from data/.

    Expects ``data/train.csv`` and optionally ``data/test.csv``, each with a
    ``text`` column and a ``labels`` column (0 = ham, 1 = spam).

    Returns:
        ``None`` if ``data/train.csv`` does not exist. Otherwise a
        ``DatasetDict`` with ``train`` and ``test`` splits. Without a test
        CSV, 20% of the training rows are held out (stratified by label,
        seed 42).
    """
    from datasets import ClassLabel, DatasetDict

    train_path = Path("data/train.csv")
    test_path = Path("data/test.csv")
    if not train_path.exists():
        return None

    # The CSV loader reads labels as plain integers. stratify_by_column needs
    # a ClassLabel, and these names match the model's id2label in main().
    label_feature = ClassLabel(names=["ham", "spam"])

    print(f"Loading custom dataset from {train_path}...")
    dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train")
    dataset = dataset.cast_column("labels", label_feature)

    if test_path.exists():
        # The split name must match the data_files key ("test"), not "train".
        test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="test")
        test_ds = test_ds.cast_column("labels", label_feature)
        return DatasetDict({"train": dataset, "test": test_ds})
    return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
def main():
    """Command-line entry point: fine-tune, evaluate and save the classifier.

    Depending on ``--custom-data``, loads either the public datasets or only
    the custom CSVs in data/. It then fine-tunes the base model for binary
    ham/spam classification, evaluates it, and writes the final model and
    tokenizer to ``<output-dir>/final``.

    Exits with a non-zero status if ``--custom-data`` is given but
    ``data/train.csv`` is missing.
    """
    parser = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier")
    parser.add_argument(
        "--model-name",
        default="distilbert-base-multilingual-cased",
        help="Base model (default: multilingual)",
    )
    parser.add_argument("--output-dir", default="./model", help="Where to save the trained model")
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--custom-data", action="store_true", help="Use only custom data from data/")
    args = parser.parse_args()

    # --- Load dataset ---
    # Done before downloading the base model so a missing data file fails
    # fast instead of after a large download.
    if args.custom_data:
        dataset = load_custom_dataset()
        if dataset is None:
            # A string passed to SystemExit is printed to stderr and the
            # process exits with status 1, so scripts/CI can detect the failure.
            raise SystemExit("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!")
    else:
        combined = load_public_datasets()
        dataset = combined.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")

    # --- Load model ---
    print(f"Loading base model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
        id2label={0: "ham", 1: "spam"},
        label2id={"ham": 0, "spam": 1},
    )

    tokenized = dataset.map(lambda batch: tokenize(batch, tokenizer), batched=True)
    # Expose only the model inputs and labels, as torch tensors.
    tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print(f"Train: {len(tokenized['train'])} samples, Test: {len(tokenized['test'])} samples")

    # --- Training ---
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_steps=50,
        fp16=torch.cuda.is_available(),  # mixed precision only when CUDA is available
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        compute_metrics=compute_metrics,
    )
    print("Starting training...")
    trainer.train()

    # --- Evaluate ---
    results = trainer.evaluate()
    print(f"\nResults: {results}")

    # --- Save ---
    output_path = Path(args.output_dir) / "final"
    trainer.save_model(str(output_path))
    tokenizer.save_pretrained(str(output_path))
    print(f"\nModel saved to {output_path}")


if __name__ == "__main__":
    main()