Initial commit: SpamLLM - DistilBERT spam classifier for RSpamd
Multilingual spam classifier (DE/EN) with language detection. Non-DE/EN mails receive an additional spam score bonus. - train.py: Fine-tune distilbert-base-multilingual-cased on spam/ham data - server.py: FastAPI service with langdetect integration - rspamd/: Lua plugin and config for RSpamd integration - export_rspamd_data.py: Export Maildir folders to CSV training data - test_classify.py: Local model validation with DE/EN/foreign test cases Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
commit
38efd20b4d
7 changed files with 671 additions and 0 deletions
158
train.py
Normal file
158
train.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""
|
||||
Fine-Tune DistilBERT (multilingual) als Spam-Classifier.
|
||||
|
||||
Kombiniert englische und deutsche Spam-Datensätze.
|
||||
Für Produktion: Eigene gelabelte Spam/Ham-Mails unter data/ ablegen.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, concatenate_datasets, load_dataset
|
||||
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
|
||||
def tokenize(batch, tokenizer):
    """Tokenize a batch of texts, truncated/padded to a fixed 512-token window."""
    texts = batch["text"]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=512)
|
||||
|
||||
|
||||
def compute_metrics(pred):
    """Compute binary-classification metrics (accuracy/F1/precision/recall) for a Trainer eval."""
    y_true = pred.label_ids
    # Predicted class = argmax over the two logits.
    y_pred = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
|
||||
|
||||
|
||||
def load_public_datasets():
    """Load and combine public EN + DE spam datasets.

    Returns a single `datasets.Dataset` with columns "text" and "labels"
    (0 = ham, 1 = spam). The English SMS Spam Collection is always loaded;
    a custom German CSV at data/train_de.csv is appended when present.
    """
    # English: SMS Spam Collection (~5,500 messages)
    print("Loading English SMS spam dataset...")
    en_sms = load_dataset("sms_spam", split="train")
    en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels")

    # German: few public DE spam datasets exist, so fall back to a
    # user-supplied CSV. For production, place labelled DE mails at
    # data/train_de.csv.
    print("Checking for custom German dataset at data/train_de.csv...")
    de_path = Path("data/train_de.csv")
    if de_path.exists():
        print(f" Found {de_path}, loading...")
        de_data = load_dataset("csv", data_files=str(de_path), split="train")
        # Expected format: columns "text" and "labels" (0=ham, 1=spam).
        # CSV-loaded "labels" are plain integers while en_data's came from a
        # ClassLabel feature; cast the German split to en_data's schema,
        # otherwise concatenate_datasets() rejects the mismatched features
        # and the caller's stratified train/test split cannot work.
        de_data = de_data.cast(en_data.features)
        combined = concatenate_datasets([en_data, de_data])
    else:
        print(" Not found. Using English-only dataset.")
        print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv")
        print(" Format: text,labels (labels: 0=ham, 1=spam)")
        combined = en_data

    return combined
|
||||
|
||||
|
||||
def load_custom_dataset():
    """Load user-provided train/test CSVs from data/.

    Returns None when data/train.csv is missing, a DatasetDict when an
    explicit data/test.csv exists, and otherwise a stratified 80/20
    train/test split of data/train.csv.
    """
    train_path = Path("data/train.csv")
    test_path = Path("data/test.csv")

    if not train_path.exists():
        return None

    print(f"Loading custom dataset from {train_path}...")
    dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train")

    if test_path.exists():
        test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="train")
        from datasets import DatasetDict
        return DatasetDict({"train": dataset, "test": test_ds})

    # CSV-loaded "labels" arrive as plain integers, but train_test_split()
    # only accepts stratify_by_column on a ClassLabel feature — cast first,
    # otherwise the split raises ValueError.
    from datasets import ClassLabel
    dataset = dataset.cast_column("labels", ClassLabel(names=["ham", "spam"]))
    return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: fine-tune a multilingual DistilBERT spam classifier and save it."""
    cli = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier")
    cli.add_argument(
        "--model-name",
        default="distilbert-base-multilingual-cased",
        help="Base model (default: multilingual)",
    )
    cli.add_argument("--output-dir", default="./model", help="Where to save the trained model")
    cli.add_argument("--epochs", type=int, default=3, help="Training epochs")
    cli.add_argument("--batch-size", type=int, default=16, help="Batch size")
    cli.add_argument("--lr", type=float, default=2e-5, help="Learning rate")
    cli.add_argument("--custom-data", action="store_true", help="Use only custom data from data/")
    opts = cli.parse_args()

    # --- Base model and tokenizer ---
    print(f"Loading base model: {opts.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(opts.model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        opts.model_name,
        num_labels=2,
        id2label={0: "ham", 1: "spam"},
        label2id={"ham": 0, "spam": 1},
    )

    # --- Dataset selection: user-supplied CSVs or public corpora ---
    if opts.custom_data:
        splits = load_custom_dataset()
        if splits is None:
            print("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!")
            return
    else:
        splits = load_public_datasets().train_test_split(
            test_size=0.2, seed=42, stratify_by_column="labels"
        )

    encoded = splits.map(lambda batch: tokenize(batch, tokenizer), batched=True)
    encoded.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    print(f"Train: {len(encoded['train'])} samples, Test: {len(encoded['test'])} samples")

    # --- Training ---
    hparams = TrainingArguments(
        output_dir=opts.output_dir,
        num_train_epochs=opts.epochs,
        per_device_train_batch_size=opts.batch_size,
        per_device_eval_batch_size=opts.batch_size,
        learning_rate=opts.lr,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        logging_steps=50,
        # Mixed precision only when a CUDA device is available.
        fp16=torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=hparams,
        train_dataset=encoded["train"],
        eval_dataset=encoded["test"],
        compute_metrics=compute_metrics,
    )

    print("Starting training...")
    trainer.train()

    # --- Evaluation on the held-out split ---
    metrics = trainer.evaluate()
    print(f"\nResults: {metrics}")

    # --- Persist the fine-tuned model and tokenizer ---
    save_dir = Path(opts.output_dir) / "final"
    trainer.save_model(str(save_dir))
    tokenizer.save_pretrained(str(save_dir))
    print(f"\nModel saved to {save_dir}")
|
||||
|
||||
|
||||
# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue