Multilingual spam classifier (DE/EN) with language detection. Mails in languages other than German or English receive an additional spam-score bonus.

- train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data
- server.py: FastAPI service with langdetect integration
- rspamd/: Lua plugin and config for RSpamd integration
- export_rspamd_data.py: Export Maildir folders to CSV training data
- test_classify.py: Local model validation with DE/EN/foreign test cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
158 lines
5.6 KiB
Python
158 lines
5.6 KiB
Python
"""
|
|
Fine-Tune DistilBERT (multilingual) als Spam-Classifier.
|
|
|
|
Kombiniert englische und deutsche Spam-Datensätze.
|
|
Für Produktion: Eigene gelabelte Spam/Ham-Mails unter data/ ablegen.
|
|
"""
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from datasets import Dataset, concatenate_datasets, load_dataset
|
|
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
|
from transformers import (
|
|
AutoModelForSequenceClassification,
|
|
AutoTokenizer,
|
|
Trainer,
|
|
TrainingArguments,
|
|
)
|
|
|
|
|
|
def tokenize(batch, tokenizer, *, max_length=512, padding="max_length"):
    """Tokenize the ``"text"`` entry of a (batched) ``datasets.map`` call.

    Args:
        batch: Mapping with a ``"text"`` entry; with ``batched=True`` this is a
            list of strings.
        tokenizer: Hugging Face tokenizer (or any callable accepting the texts
            plus ``truncation``/``padding``/``max_length`` keywords).
        max_length: Maximum number of tokens; longer texts are truncated.
            The default of 512 matches the maximum sequence length of the
            DistilBERT base models.
        padding: Padding strategy forwarded to the tokenizer. ``"max_length"``
            produces fixed-size sequences, which a non-padding data collator
            requires; only use ``False``/``"longest"`` together with a padding
            collator such as ``DataCollatorWithPadding``.

    Returns:
        Whatever the tokenizer returns (a ``BatchEncoding`` for HF tokenizers).
    """
    return tokenizer(batch["text"], truncation=True, padding=padding, max_length=max_length)
|
|
|
|
|
|
def compute_metrics(pred):
    """Compute accuracy, precision, recall and F1 for ``Trainer`` evaluation.

    Args:
        pred: ``EvalPrediction``-like object with ``predictions`` (logits whose
            last axis indexes the classes, or a tuple whose first element is
            the logits) and ``label_ids``. Label 1 is the positive class
            (scikit-learn's default ``pos_label`` for ``average="binary"``).

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Some models return extra outputs alongside the logits; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # An early epoch may predict no positives at all. Report 0.0 explicitly
    # instead of emitting an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average="binary", zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}
|
|
|
|
|
|
def load_public_datasets():
    """Load the public English spam dataset, optionally merged with German CSV data.

    The English data is the Hugging Face ``sms_spam`` dataset, with its
    ``sms``/``label`` columns renamed to ``text``/``labels``. If
    ``data/train_de.csv`` exists, it is appended. That CSV must provide
    ``text`` and ``labels`` columns (0=ham, 1=spam); any other columns are
    dropped, as are rows with an empty ``text`` or ``labels`` value.

    Returns:
        A single ``datasets.Dataset`` with the columns ``text`` and ``labels``.

    Raises:
        ValueError: if the German CSV lacks the ``text`` or ``labels`` column,
            or contains label values that are not valid ham/spam classes.
    """
    # English: SMS Spam Collection (~5,500 messages)
    print("Loading English SMS spam dataset...")
    en_sms = load_dataset("sms_spam", split="train")
    en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels")

    # German: there are few public German spam datasets, so German data is
    # only used when a custom CSV is provided; otherwise training is
    # English-only. For production, place your own German mails in
    # data/train_de.csv.
    print("Checking for custom German dataset at data/train_de.csv...")
    de_path = Path("data/train_de.csv")
    if not de_path.exists():
        print(" Not found. Using English-only dataset.")
        print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv")
        print(" Format: text,labels (labels: 0=ham, 1=spam)")
        return en_data

    print(f" Found {de_path}, loading...")
    de_data = load_dataset("csv", data_files=str(de_path), split="train")

    required = ("text", "labels")
    missing = [col for col in required if col not in de_data.column_names]
    if missing:
        raise ValueError(f"{de_path} is missing required column(s): {missing}")
    extra = [col for col in de_data.column_names if col not in required]
    if extra:
        de_data = de_data.remove_columns(extra)

    # Empty CSV cells load as None; the tokenizer cannot handle those.
    before = len(de_data)
    de_data = de_data.filter(lambda row: row["text"] is not None and row["labels"] is not None)
    if len(de_data) != before:
        print(f" Dropped {before - len(de_data)} row(s) with empty text or labels from {de_path}")

    # concatenate_datasets() requires identical features. The CSV yields a
    # plain integer "labels" column, while sms_spam uses ClassLabel(ham/spam).
    # Casting to the English schema makes them compatible and keeps the
    # ClassLabel type that train_test_split(stratify_by_column=...) requires.
    de_data = de_data.cast(en_data.features)
    return concatenate_datasets([en_data, de_data])
|
|
|
|
|
|
def load_custom_dataset():
    """Load user-provided spam/ham data from ``data/``.

    Reads ``data/train.csv`` (required) and, if present, ``data/test.csv``.
    Both files need a ``text`` column and a ``labels`` column (0=ham, 1=spam).
    Without a test file, 20% of the training data is held out with a split
    stratified on ``labels`` (seed 42).

    Returns:
        A ``DatasetDict`` with ``train`` and ``test`` splits, or ``None`` if
        ``data/train.csv`` does not exist.

    Raises:
        ValueError: if a ``labels`` value is not a valid ham/spam class.
    """
    from datasets import ClassLabel, DatasetDict

    train_path = Path("data/train.csv")
    test_path = Path("data/test.csv")

    if not train_path.exists():
        return None

    # CSV integer labels load as Value("int64"), but stratified splitting only
    # works on ClassLabel columns. Cast explicitly with the model's label names.
    label_feature = ClassLabel(names=["ham", "spam"])

    print(f"Loading custom dataset from {train_path}...")
    dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train")
    dataset = dataset.cast_column("labels", label_feature)

    if test_path.exists():
        # The split name requested must match the key used in data_files.
        test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="test")
        test_ds = test_ds.cast_column("labels", label_feature)
        return DatasetDict({"train": dataset, "test": test_ds})

    return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
|
|
|
|
|
|
def main():
    """Command-line entry point: fine-tune a spam/ham sequence classifier.

    Parses the CLI options and builds train/test splits. With
    ``--custom-data`` these come from ``load_custom_dataset()``; otherwise
    ``load_public_datasets()`` provides the data, with a stratified 80/20
    split. It then loads the base model and tokenizer, trains with the
    Hugging Face ``Trainer``, evaluates, and saves the best model (by eval
    F1) plus the tokenizer to ``<output-dir>/final``.

    Raises:
        SystemExit: with status 1 if ``--custom-data`` is given but
            ``data/train.csv`` does not exist.
    """
    parser = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier")
    parser.add_argument(
        "--model-name",
        default="distilbert-base-multilingual-cased",
        help="Base model (default: multilingual)",
    )
    parser.add_argument("--output-dir", default="./model", help="Where to save the trained model")
    parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    parser.add_argument("--custom-data", action="store_true", help="Use only custom data from data/")
    args = parser.parse_args()

    # --- Load dataset ---
    # Load the data before the model so a missing data file fails fast,
    # without first downloading/initialising the base model.
    if args.custom_data:
        dataset = load_custom_dataset()
        if dataset is None:
            # SystemExit(str) prints the message to stderr and exits with
            # status 1, so shell scripts and CI can detect the failure.
            raise SystemExit("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!")
    else:
        combined = load_public_datasets()
        # Stratify so the spam/ham ratio is the same in train and test.
        dataset = combined.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")

    print(f"Loading base model: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=2,
        id2label={0: "ham", 1: "spam"},
        label2id={"ham": 0, "spam": 1},
    )

    tokenized = dataset.map(lambda batch: tokenize(batch, tokenizer), batched=True)
    # Expose only the model inputs as torch tensors; the raw "text" column is dropped.
    tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    print(f"Train: {len(tokenized['train'])} samples, Test: {len(tokenized['test'])} samples")

    # --- Training ---
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.lr,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        # Restore the checkpoint with the best eval F1 once training ends.
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_steps=50,
        # Mixed precision only when a CUDA GPU is available.
        fp16=torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized["test"],
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    # --- Evaluation ---
    results = trainer.evaluate()
    print(f"\nResults: {results}")

    # --- Save ---
    output_path = Path(args.output_dir) / "final"
    trainer.save_model(str(output_path))
    tokenizer.save_pretrained(str(output_path))
    print(f"\nModel saved to {output_path}")


if __name__ == "__main__":
    main()
|