""" Fine-Tune DistilBERT (multilingual) als Spam-Classifier. Kombiniert englische und deutsche Spam-Datensätze. Für Produktion: Eigene gelabelte Spam/Ham-Mails unter data/ ablegen. """ import argparse from pathlib import Path import torch from datasets import Dataset, concatenate_datasets, load_dataset from sklearn.metrics import accuracy_score, precision_recall_fscore_support from transformers import ( AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, ) def tokenize(batch, tokenizer): return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=512) def compute_metrics(pred): labels = pred.label_ids preds = pred.predictions.argmax(-1) precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary") acc = accuracy_score(labels, preds) return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall} def load_public_datasets(): """Lädt und kombiniert öffentliche EN + DE Spam-Datensätze.""" # Englisch: SMS Spam Collection (~5.500 Nachrichten) print("Loading English SMS spam dataset...") en_sms = load_dataset("sms_spam", split="train") en_data = en_sms.rename_column("sms", "text").rename_column("label", "labels") # Deutsch: Es gibt wenige öffentliche DE-Spam-Datensätze. # Wir nutzen einen multilingualen E-Mail-Datensatz falls verfügbar, # ansonsten den englischen als Basis. # Für Produktion: Eigene DE-Mails unter data/train_de.csv ablegen. print("Checking for custom German dataset at data/train_de.csv...") de_path = Path("data/train_de.csv") if de_path.exists(): print(f" Found {de_path}, loading...") de_data = load_dataset("csv", data_files=str(de_path), split="train") # Erwartetes Format: Spalten "text" und "labels" (0=ham, 1=spam) combined = concatenate_datasets([en_data, de_data]) else: print(" Not found. Using English-only dataset.") print(" Tipp: Exportiere deine RSpamd-Bayes-Daten als CSV nach data/train_de.csv") print(" Format: text,labels (labels: 0=ham, 1=spam)") combined = en_data return combined def load_custom_dataset(): """Lädt benutzerdefinierte Datensätze aus data/.""" train_path = Path("data/train.csv") test_path = Path("data/test.csv") if not train_path.exists(): return None print(f"Loading custom dataset from {train_path}...") dataset = load_dataset("csv", data_files={"train": str(train_path)}, split="train") if test_path.exists(): test_ds = load_dataset("csv", data_files={"test": str(test_path)}, split="train") from datasets import DatasetDict return DatasetDict({"train": dataset, "test": test_ds}) return dataset.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels") def main(): parser = argparse.ArgumentParser(description="Train multilingual DistilBERT spam classifier") parser.add_argument( "--model-name", default="distilbert-base-multilingual-cased", help="Base model (default: multilingual)", ) parser.add_argument("--output-dir", default="./model", help="Where to save the trained model") parser.add_argument("--epochs", type=int, default=3, help="Training epochs") parser.add_argument("--batch-size", type=int, default=16, help="Batch size") parser.add_argument("--lr", type=float, default=2e-5, help="Learning rate") parser.add_argument("--custom-data", action="store_true", help="Use only custom data from data/") args = parser.parse_args() print(f"Loading base model: {args.model_name}") tokenizer = AutoTokenizer.from_pretrained(args.model_name) model = AutoModelForSequenceClassification.from_pretrained( args.model_name, num_labels=2, id2label={0: "ham", 1: "spam"}, label2id={"ham": 0, "spam": 1}, ) # --- Datensatz laden --- if args.custom_data: dataset = load_custom_dataset() if dataset is None: print("ERROR: --custom-data gesetzt aber data/train.csv nicht gefunden!") return else: combined = load_public_datasets() dataset = combined.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels") tokenized = dataset.map(lambda batch: tokenize(batch, tokenizer), batched=True) tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) print(f"Train: {len(tokenized['train'])} samples, Test: {len(tokenized['test'])} samples") # --- Training --- training_args = TrainingArguments( output_dir=args.output_dir, num_train_epochs=args.epochs, per_device_train_batch_size=args.batch_size, per_device_eval_batch_size=args.batch_size, learning_rate=args.lr, weight_decay=0.01, eval_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, metric_for_best_model="f1", logging_steps=50, fp16=torch.cuda.is_available(), ) trainer = Trainer( model=model, args=training_args, train_dataset=tokenized["train"], eval_dataset=tokenized["test"], compute_metrics=compute_metrics, ) print("Starting training...") trainer.train() # --- Evaluierung --- results = trainer.evaluate() print(f"\nResults: {results}") # --- Speichern --- output_path = Path(args.output_dir) / "final" trainer.save_model(str(output_path)) tokenizer.save_pretrained(str(output_path)) print(f"\nModel saved to {output_path}") if __name__ == "__main__": main()