Skip to content

Commit a3bc7f8

Browse files
Copilot and xadupre committed
Fix XGBRFClassifier multi-class ONNX conversion with num_parallel_tree > 1
Co-authored-by: xadupre <22452781+xadupre@users.noreply.github.com>
1 parent 930fae5 commit a3bc7f8

File tree

4 files changed, with 40 additions and 1 deletion.

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,3 +57,6 @@ docs/examples/*.onnx
57 57   docs/examples/*.dot
58 58   docs/examples/*.png
59 59   docs/examples/model
   60 + dump.raw.txt
   61 + model.json
   62 + model2.json

onnxmltools/convert/xgboost/common.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@ def get_xgb_params(xgb_node):
5353
gbp = config["learner"]["gradient_booster"]["gbtree_model_param"]
5454
if "num_trees" in gbp:
5555
params["best_ntree_limit"] = int(gbp["num_trees"])
56+
if "num_parallel_tree" in gbp:
57+
params["num_parallel_tree"] = int(gbp["num_parallel_tree"])
5658
return params
5759

5860

onnxmltools/convert/xgboost/operator_converters/XGBoost.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,11 @@ def convert(scope, operator, container):
508508
attr_pairs["base_values"] = base_score * ncl
509509
else:
510510
attr_pairs["base_values"] = base_score
511-
attr_pairs["class_ids"] = [v % ncl for v in attr_pairs["class_treeids"]]
511+
num_parallel_tree = params.get("num_parallel_tree", 1)
512+
attr_pairs["class_ids"] = [
513+
(v % (ncl * num_parallel_tree)) // num_parallel_tree
514+
for v in attr_pairs["class_treeids"]
515+
]
512516

513517
classes = xgb_node.classes_
514518
if (

tests/xgboost/test_xgboost_converters_rf.py

Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,36 @@ def test_xgbrf_classifier(self):
8484
)
8585
dump_data_and_model(x_test, xgb, conv_model, basename="SklearnXGBRFClassifier")
8686

87+
@unittest.skipIf(XGBRFRegressor is None, "xgboost is not available")
88+
def test_xgbrf_classifier_multiclass(self):
89+
"""Test XGBRFClassifier with multiple classes and multiple estimators."""
90+
import onnxruntime as rt
91+
92+
for n_estimators in [1, 3, 5]:
93+
xgb, x_test = _fit_classification_model(
94+
XGBRFClassifier(n_estimators=n_estimators), 3
95+
)
96+
conv_model = convert_xgboost(
97+
xgb,
98+
initial_types=[("input", FloatTensorType(shape=[None, None]))],
99+
target_opset=TARGET_OPSET,
100+
)
101+
sess = rt.InferenceSession(conv_model.SerializeToString())
102+
onnx_labels, onnx_probs = sess.run(None, {"input": x_test})
103+
xgb_labels = xgb.predict(x_test)
104+
xgb_probs = xgb.predict_proba(x_test)
105+
np.testing.assert_array_equal(
106+
onnx_labels,
107+
xgb_labels,
108+
err_msg=f"Label mismatch for n_estimators={n_estimators}",
109+
)
110+
np.testing.assert_allclose(
111+
onnx_probs,
112+
xgb_probs,
113+
atol=1e-5,
114+
err_msg=f"Probability mismatch for n_estimators={n_estimators}",
115+
)
116+
87117

88118
if __name__ == "__main__":
89119
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)