odtlearn.robust_oct

Module Contents

Classes

RobustOCT

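An optimal robust decision tree classifier, fitted on a given integer-valued data set and a given cost-and-budget uncertainty set to produce a tree robust against distribution shifts.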

class odtlearn.robust_oct.RobustOCT(solver, depth=1, time_limit=1800, num_threads=None, verbose=False)

Bases: odtlearn.opt_ct.OptimalClassificationTree

An optimal robust decision tree classifier, fitted on a given integer-valued data set and a given cost-and-budget uncertainty set to produce a tree robust against distribution shifts.
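To make the cost-and-budget uncertainty set concrete, here is a minimal sketch that checks whether a given perturbation of the data lies inside it. It assumes the set has the form used in the robust classification tree literature: integer perturbations ξ whose cost-weighted size, the sum over samples i and features f of c[i, f]·|ξ[i, f]|, does not exceed the budget, where c plays the role of the costs and the budget that of the budget passed to fit. The helper name in_uncertainty_set is hypothetical and is not part of odtlearn.

```python
import numpy as np


def in_uncertainty_set(perturbation, costs, budget):
    """Hypothetical helper (not part of odtlearn).

    Assumes the uncertainty set is
        {xi integer-valued : sum_{i,f} costs[i, f] * |xi[i, f]| <= budget}.
    """
    perturbation = np.asarray(perturbation)
    costs = np.asarray(costs, dtype=float)
    if perturbation.shape != costs.shape:
        raise ValueError("perturbation and costs must have the same shape")
    # Perturbations of integer-valued data are assumed to be integers.
    if not np.all(np.equal(np.mod(perturbation, 1), 0)):
        return False
    # Total cost of the perturbation, weighted per sample and per feature.
    total_cost = float(np.sum(costs * np.abs(perturbation)))
    return total_cost <= budget


costs = np.array([[1.0, 2.0], [1.0, 0.5]])
# Shift feature 0 of sample 0 by +1 and feature 1 of sample 1 by -2.
xi = np.array([[1, 0], [0, -2]])
print(in_uncertainty_set(xi, costs, budget=2.0))  # True: 1*1 + 0.5*2 = 2.0
print(in_uncertainty_set(xi, costs, budget=1.5))  # False: 2.0 > 1.5
```

Under this assumed model, a lower cost makes a feature cheaper to perturb, so the adversary can shift it further within the same budget.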

Parameters:
solver: str

A string specifying the name of the solver to use to solve the MIP. Options are “Gurobi” and “CBC”. If the CBC binaries are not found, Gurobi will be used by default.

depth: int, default=1

The depth of the tree.

time_limit: int, default=1800

The time limit, in seconds, for solving the MIP.

num_threads: int, default=None

The number of threads the solver should use. If None, the solver uses all available threads.

verbose: bool, default=False

Whether to log the solver output.

fit(X, y, costs=None, budget=-1)

Fit an optimal robust classification tree given the data, the labels, the costs of uncertainty, and the budget of uncertainty.

Parameters:
X: array-like, shape (n_samples, n_features)

The training input samples.

y: array-like, shape (n_samples,)

The target values, as an array of integers.

costs: array-like, shape (n_samples, n_features), default=None

The costs of uncertainty, one per sample and feature. If None, every entry is set to budget + 1.

budget: float, default=-1

The budget of uncertainty.

Returns:
self: object

Returns self.
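The default for costs can be reproduced directly. The sketch below builds the per-sample, per-feature cost matrix that fit is documented to use when costs is None (every entry equal to budget + 1) and checks that, under the cost-weighted budget constraint assumed here, even a one-unit change to a single feature then costs more than the budget. Both helper names are hypothetical, not odtlearn functions.

```python
import numpy as np


def default_costs(n_samples, n_features, budget):
    """Hypothetical helper: mirror the documented default, costs = budget + 1."""
    return np.full((n_samples, n_features), budget + 1.0)


def cheapest_single_perturbation(costs):
    """Cost of the cheapest one-unit change to one feature of one sample."""
    return float(np.min(costs))


budget = 3.0
costs = default_costs(n_samples=4, n_features=2, budget=budget)
print(costs.shape)  # (4, 2)
# Every single-unit perturbation costs budget + 1, which exceeds the budget.
print(cheapest_single_perturbation(costs) > budget)  # True
```

Passing an explicit costs array with some entries at or below the budget is what makes perturbations of those sample-feature pairs affordable under this assumed model.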

predict(X)

Given the input covariates, predict the class label of each sample using the fitted optimal robust classification tree.

Parameters:
X: array-like, shape (n_samples, n_features)

The input samples.

Returns:
y: ndarray, shape (n_samples,)

The predicted class label for each sample, given by the leaf of the fitted tree that the sample reaches.
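Prediction with a fitted tree amounts to routing each sample from the root down to a leaf and returning that leaf's label. The minimal sketch below shows this for a hypothetical nested-dict tree in which each branching node sends a sample left when x[feature] <= threshold; neither this representation nor the left/right convention is taken from odtlearn's internals.

```python
import numpy as np


def predict_one(node, x):
    """Route one sample down a hypothetical nested-dict tree to a leaf."""
    while "label" not in node:
        # Branching node: compare one feature against an integer threshold.
        if x[node["feature"]] <= node["threshold"]:
            node = node["left"]
        else:
            node = node["right"]
    # Leaf (assignment) node: return its class label.
    return node["label"]


def predict(tree, X):
    """Predict a label for every row of X; returns shape (n_samples,)."""
    X = np.asarray(X)
    return np.array([predict_one(tree, row) for row in X])


# A depth-2 tree on two integer-valued features.
tree = {
    "feature": 0, "threshold": 1,
    "left": {"feature": 1, "threshold": 0,
             "left": {"label": 0}, "right": {"label": 1}},
    "right": {"label": 1},
}
X = np.array([[0, 0], [0, 3], [2, 0]])
print(predict(tree, X))  # [0 1 1]
```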

print_tree()

Print the fitted tree with the branching features, the threshold values for each branching node’s test, and the predictions asserted for each assignment node.

The method uses the Gurobi model’s name to determine how to generate the tree.
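To illustrate the kind of information print_tree reports, here is a hedged sketch that renders a tree as indented text, listing the feature and threshold at each branching node and the prediction at each assignment node. It uses a hypothetical nested-dict representation and an output format of its own; odtlearn's actual output format and internal tree structure may differ.

```python
def tree_lines(node, depth=0):
    """Return an indented text rendering of a hypothetical nested-dict tree."""
    indent = "    " * depth
    if "label" in node:
        # Assignment (leaf) node: show the predicted class.
        return [f"{indent}predict: {node['label']}"]
    # Branching node: show the feature tested and its threshold.
    lines = [f"{indent}if x[{node['feature']}] <= {node['threshold']}:"]
    lines += tree_lines(node["left"], depth + 1)
    lines.append(f"{indent}else:")
    lines += tree_lines(node["right"], depth + 1)
    return lines


tree = {
    "feature": 0, "threshold": 1,
    "left": {"label": 0},
    "right": {"label": 1},
}
print("\n".join(tree_lines(tree)))
```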

plot_tree(label='all', filled=True, rounded=False, precision=3, ax=None, fontsize=None, color_dict={'node': None, 'leaves': []}, edge_annotation=True, arrow_annotation_font_scale=0.8, debug=False)

Plot the fitted tree with matplotlib, showing the branching features, the threshold values for each branching node’s test, and the predictions asserted for each assignment node. The method uses the Gurobi model’s name to determine how to generate the tree. It preprocesses the tree before passing it to the _MPLTreeExporter class from the sklearn package, and its arguments are based on those of the sklearn plot_tree function.

Parameters:
label: {‘all’, ‘root’, ‘none’}, default=‘all’

Whether to show informative labels for impurity, etc. Options include ‘all’ to show at every node, ‘root’ to show only at the top root node, or ‘none’ to not show at any node.

filled: bool, default=True

When set to True, paint nodes to indicate the majority class.

rounded: bool, default=False

When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.

precision: int, default=3

Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.

ax: matplotlib axis, default=None

Axes to plot to. If None, use the current axis. Any previous content is cleared.

fontsize: int, default=None

Size of text font. If None, determined automatically to fit figure.

color_dict: dict, default={'node': None, 'leaves': []}

A dictionary specifying the colors for nodes and leaves in the plot, in #RRGGBB format. If None, the colors are chosen from the sklearn plot_tree color palette.