# Multiclass Classification with XGBoost (Done Right)

For tabular multiclass classification, default to XGBoost. Same reasoning as binary: it dominates Kaggle and real-world tabular benchmarks, handles class imbalance with `sample_weight`, and gives you SHAP-based explanations as a side effect.
The differences from binary are real and worth understanding:

- `objective="multi:softprob"` + `num_class=N` instead of `binary:logistic`
- Per-class metrics, not just accuracy. Accuracy can hide a catastrophic failure on a minority class.
- `sample_weight` per row for imbalance (no `scale_pos_weight` for multiclass)
- The confusion matrix is the most important diagnostic — which classes get confused matters more than how many
- `predict_proba` returns shape `(n_samples, n_classes)`, and you often want top-K, not just the argmax
## When to use this skill
- The target has > 2 discrete classes (sentiment 1-5, product category A-Z, intent buckets, fault types)
- The classes are mutually exclusive (each example belongs to exactly one class — multilabel is a different skill)
- The features are tabular
- You have at least a few hundred examples per class
## When NOT to use this skill
- Binary target → see binary-classification
- Multiple labels per example → see multilabel-classification
- Classes have natural ordering (rating 1-5) → consider an ordinal regression model (not XGBoost); XGBoost will work but ignores the order
- Very many classes (> 100) where you only care about top-K → consider hierarchical softmax or learning-to-rank approaches
## Project layout

```
<project>/
├── data/                # input parquet/csv
├── src/
│   ├── train.py         # Pipeline + XGBClassifier(multi:softprob) + MLflow
│   ├── predict.py       # reload, return top-K predictions per row
│   └── plots.py         # confusion matrix, per-class metrics, ROC OvR, SHAP
├── notebooks/
│   └── demo.py          # marimo walkthrough
└── mlruns/
```
## Data access — ibis at the source

Same pattern as the other tabular bundles. Use ibis to read; materialize once with `.execute()` for sklearn:
```python
import ibis

table = ibis.duckdb.connect().read_parquet("data/train.parquet")
feature_cols = [c for c in table.columns if c.startswith("feature_")]

data = (
    table
    .select(*feature_cols, "target")
    .execute()
)

X = data[feature_cols]
y = data["target"].astype(int)
n_classes = int(y.max()) + 1
```
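One caveat: `n_classes = int(y.max()) + 1` assumes the labels are already contiguous integers starting at 0, which is what `multi:softprob` expects. If they're strings or sparse integers, encode them first; a minimal sketch with sklearn's `LabelEncoder`:

```python
from sklearn.preprocessing import LabelEncoder

# Map arbitrary labels (strings, non-contiguous ints) to 0..N-1.
le = LabelEncoder()
y = le.fit_transform(data["target"])
n_classes = len(le.classes_)
# Keep `le` around: predictions come back as 0..N-1 and need
# le.inverse_transform() to map back to the original labels.
```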
## The pipeline

```python
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier

def build_pipeline(feature_cols, n_classes, seed):
    return Pipeline([
        ("preprocess", ColumnTransformer([("num", StandardScaler(), feature_cols)])),
        ("clf", XGBClassifier(
            n_estimators=300,
            max_depth=5,
            learning_rate=0.05,
            subsample=0.8,
            colsample_bytree=0.8,
            reg_lambda=1.0,
            objective="multi:softprob",
            num_class=n_classes,
            eval_metric="mlogloss",
            random_state=seed,
            n_jobs=-1,
        )),
    ])
```
The only changes from binary are `objective`, `num_class`, and `eval_metric`.
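A minimal usage sketch under the same assumptions as above (`X`, `y`, `n_classes` from the data-access step); the split parameters are illustrative. Later snippets assume `y_pred` and `y_proba` from here:

```python
from sklearn.model_selection import train_test_split

# Stratify so each class keeps its proportion in both splits;
# important when classes are imbalanced.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

pipeline = build_pipeline(feature_cols, n_classes, seed=42)
pipeline.fit(X_train, y_train)

y_pred = pipeline.predict(X_test)         # shape (n_samples,)
y_proba = pipeline.predict_proba(X_test)  # shape (n_samples, n_classes)
```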
## The five things that separate this from a tutorial

### 1. Per-class metrics — never just accuracy
Accuracy on a 5-class problem can be 80% while the model completely fails on one class. You need per-class precision, recall, F1, and support:
```python
from sklearn.metrics import precision_recall_fscore_support, classification_report

precision, recall, f1, support = precision_recall_fscore_support(
    y_test, y_pred, labels=list(range(n_classes)), zero_division=0,
)
print(classification_report(y_test, y_pred, digits=3))
```
A failing minority class shows up as F1 ≈ 0 for that class even when overall accuracy looks fine. Always log per-class F1 to MLflow, not just the macro/micro averages.
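A sketch of that per-class logging, using the `f1` array from above; the `f1_class_{i}` naming is an illustrative convention, not an MLflow requirement:

```python
import mlflow

# One metric per class, so a collapsing minority class is visible
# in the MLflow UI instead of being averaged away.
for i, score in enumerate(f1):
    mlflow.log_metric(f"f1_class_{i}", float(score))
```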
### 2. Macro vs micro vs weighted F1 — three different decisions
These three averaging strategies sound similar but encode very different priorities:
| Average | What it computes | When to use |
|---|---|---|
| macro | Unweighted mean of per-class F1 | All classes matter equally — minority classes drag the average down |
| micro | F1 over the full pooled prediction set (= accuracy on single-label multiclass) | You only care about overall correctness |
| weighted | Mean of per-class F1 weighted by class support | Class proportions reflect real-world frequencies; minority misses don't matter much |
```python
from sklearn.metrics import f1_score

f1_macro = f1_score(y_test, y_pred, average="macro")
f1_micro = f1_score(y_test, y_pred, average="micro")
f1_weighted = f1_score(y_test, y_pred, average="weighted")
```
Default to macro F1 when each class is independently important. A model with 90% accuracy and 0% F1 on a minority class is bad — macro F1 surfaces that, weighted F1 hides it.
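A toy illustration of that gap (synthetic labels, not from the demo): 90 majority samples, 10 minority, and a model that always predicts the majority class.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

y_true = np.array([0] * 90 + [1] * 10)
y_pred_all_zero = np.zeros(100, dtype=int)  # always predicts class 0

print(accuracy_score(y_true, y_pred_all_zero))                # 0.90
print(f1_score(y_true, y_pred_all_zero, average="macro"))     # ~0.47
print(f1_score(y_true, y_pred_all_zero, average="weighted"))  # ~0.85
```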
### 3. `sample_weight` per row, not `scale_pos_weight`

XGBoost's `scale_pos_weight` only works for binary classification. For multiclass, you pass an explicit `sample_weight` array — one weight per row — to `fit`. The standard "balanced" choice gives each class an equal total weight regardless of its frequency:
```python
from sklearn.utils.class_weight import compute_sample_weight

sample_weight = compute_sample_weight(class_weight="balanced", y=y_train)
pipeline.fit(X_train, y_train, clf__sample_weight=sample_weight)
```
The `clf__` prefix routes the parameter to the `clf` step of the Pipeline. Without it, the Pipeline doesn't know which step you mean.
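`compute_sample_weight` also accepts an explicit dict when "balanced" isn't the right policy. A sketch assuming five classes; the 5× weight on class 3 is a hypothetical choice, not a recommendation:

```python
# Upweight one class you particularly care about, leave the rest alone.
sample_weight = compute_sample_weight(
    class_weight={0: 1.0, 1: 1.0, 2: 1.0, 3: 5.0, 4: 1.0}, y=y_train
)
pipeline.fit(X_train, y_train, clf__sample_weight=sample_weight)
```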
### 4. Confusion matrix is the primary diagnostic
For multiclass, the confusion matrix tells you which classes get confused for which — and that's the input to feature engineering. "Class 2 is confused with class 4 60% of the time" tells you to look for features that distinguish those two.
```python
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, y_pred, labels=list(range(n_classes)))
# Normalize by row to see "given true class i, what % goes where"
cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
```
Always plot two confusion matrices: the raw counts and the row-normalized version. Counts show the magnitude of errors; the normalized version shows the per-class structure.
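A plotting sketch for those two views using sklearn's `ConfusionMatrixDisplay`; the figure size and filename are arbitrary:

```python
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Left: raw counts. Right: rows normalized to sum to 1.
ConfusionMatrixDisplay.from_predictions(
    y_test, y_pred, ax=axes[0], colorbar=False
)
axes[0].set_title("Counts")
ConfusionMatrixDisplay.from_predictions(
    y_test, y_pred, normalize="true", ax=axes[1], colorbar=False
)
axes[1].set_title("Row-normalized")
fig.savefig("confusion_matrices.png", dpi=150)
```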
### 5. Top-K accuracy + SHAP per class
For problems with many classes, top-K accuracy is often the actual product metric — "did the right answer appear in the top 3 suggestions?" rather than "was the top-1 prediction exactly right?":
```python
from sklearn.metrics import top_k_accuracy_score

top_3 = top_k_accuracy_score(y_test, y_proba, k=3, labels=list(range(n_classes)))
```
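If you need the actual top-K predictions per row (e.g. for `predict.py`'s top-K output), a common idiom is an argsort over the probability matrix:

```python
import numpy as np

top_k = 3
# Column indices of the top-3 classes per row, highest probability first.
topk_classes = np.argsort(y_proba, axis=1)[:, ::-1][:, :top_k]
```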
For SHAP, multiclass returns a 3D array `(n_samples, n_features, n_classes)` instead of binary's 2D. Slice to one class to plot:
```python
import shap

# Explain the fitted booster itself, not the whole Pipeline; note that
# X_test_sample must already be preprocessed (e.g. via
# pipeline.named_steps["preprocess"].transform(...)).
clf = pipeline.named_steps["clf"]
explainer = shap.TreeExplainer(clf)
shap_values = explainer(X_test_sample)

# shap_values.values shape: (n_samples, n_features, n_classes)
# Pick a class to explain
class_idx = 0
sliced = shap.Explanation(
    values=shap_values.values[:, :, class_idx],
    base_values=shap_values.base_values[:, class_idx],
    data=shap_values.data,
    feature_names=feature_cols,
)
shap.summary_plot(sliced, X_test_sample, feature_names=feature_cols)
```
The class-level SHAP explanation tells you "which features push predictions toward / away from class 0?" Repeat per class for a full picture, or focus on the class you care about most.
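If one plot per class is too much, a quick textual alternative (assuming the `shap_values` from above) is ranking features by mean |SHAP| per class:

```python
import numpy as np

# (n_features, n_classes): average absolute contribution of each
# feature to each class's output.
mean_abs = np.abs(shap_values.values).mean(axis=0)
for c in range(n_classes):
    top = np.argsort(mean_abs[:, c])[::-1][:5]
    print(f"class {c}:", [feature_cols[i] for i in top])
```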
## MLflow logging
Every run logs:
| Kind | What |
|---|---|
| params | data path, n_rows, n_features, n_classes, seed, hyperparameters, use_sample_weights |
| metrics | accuracy, F1 macro / micro / weighted (all three), log-loss, top-K accuracy, per-class F1 (one metric per class) |
| tags | data hash |
| artifacts | model, class balance bar, confusion matrix (raw + normalized), per-class metrics bar chart, ROC OvR, SHAP summary per class |
The per-class F1 logging is what catches the silent-minority-class failure mode. Don't skip it.
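A skeleton of that run, assuming the metrics and artifacts from the earlier sections are already computed; `data_hash` is a stand-in for however you hash the input data:

```python
import mlflow
import mlflow.sklearn
from sklearn.metrics import accuracy_score

with mlflow.start_run():
    mlflow.log_params({
        "n_rows": len(X), "n_features": len(feature_cols),
        "n_classes": n_classes, "use_sample_weights": True,
    })
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "f1_macro": f1_macro, "f1_micro": f1_micro,
        "f1_weighted": f1_weighted, "top_3_accuracy": top_3,
    })
    # ...plus the per-class F1 loop from section 1.
    mlflow.set_tag("data_hash", data_hash)
    mlflow.log_artifact("confusion_matrices.png")
    mlflow.sklearn.log_model(pipeline, "model")
```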
## Common pitfalls

- Reporting only accuracy. Hides minority-class failure. Use F1 macro alongside accuracy.
- Skipping the confusion matrix. It's the most informative single plot for multiclass debugging.
- Using `class_weight="balanced"` on the pipeline instead of passing `sample_weight` to `fit`. XGBoost's wrapper accepts `sample_weight`; the older sklearn `class_weight` parameter is ignored.
- Forgetting `num_class`. If you set `objective="multi:softprob"` but forget `num_class=N`, XGBoost will infer it from the labels; usually correctly, but explicit is safer.
- Treating SHAP output as 2D when multiclass makes it 3D. Multiclass SHAP returns one value per (sample, feature, class) triple. Slice before plotting.
- Using accuracy as the early-stopping metric on imbalanced data. Use `mlogloss` or weighted F1 instead.
- Ignoring class order in ordinal targets. If your classes have an order (1 < 2 < 3 < 4 < 5), XGBoost still treats them as nominal. It works but you're throwing away information. Consider ordinal regression for true ordinal data.
## Worked example

See `demo.py` (marimo notebook). It generates a 5-class imbalanced synthetic dataset (the most-frequent class is 5× more common than the rarest), fits XGBoost two ways — without and with `sample_weight` — and shows the difference in per-class F1. The demo's punchline is that on imbalanced data, the minority-class F1 is dramatically improved by `sample_weight` even though overall accuracy barely changes — the exact thing macro F1 catches and weighted F1 hides.