odtlearn.robust_oct

Classes

RobustOCT

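An optimal robust decision tree classifier, fit on a given integer-valued data set and a given cost-and-budget uncertainty set to produce a tree robust against distribution shifts.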

Module Contents

class odtlearn.robust_oct.RobustOCT(solver, depth=1, time_limit=1800, num_threads=None, verbose=False)

Bases: odtlearn.opt_ct.OptimalClassificationTree

An optimal robust decision tree classifier, fit on a given integer-valued data set and a given cost-and-budget uncertainty set to produce a tree robust against distribution shifts.

Parameters:
solver: str

A string specifying the name of the solver to use to solve the MIP. Options are “Gurobi” and “CBC”. If the CBC binaries are not found, Gurobi will be used by default.

depth: int, default=1

A parameter specifying the depth of the tree.

time_limit: int, default=1800

The time limit for solving the MIP, in seconds.

num_threads: int, default=None

The number of threads the solver should use. If not specified, the solver uses all available threads.

verbose: bool, default=False

Flag for logging solver outputs.
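A minimal usage sketch, based only on the signatures documented on this page. The toy data, the helper name `train_and_predict`, the lowercase solver string `"cbc"` and the conversion to NumPy arrays are illustrative assumptions, not part of the documented API. The imports are done inside the function so that the block can be loaded even when odtlearn or a MIP solver is not installed.

```python
# Toy integer-valued data: 6 samples, 2 features (illustrative only).
X_train = [[1, 2], [1, 3], [2, 2], [3, 1], [3, 3], [2, 1]]
y_train = [0, 0, 0, 1, 1, 1]


def train_and_predict(X, y, X_new):
    """Fit a depth-2 RobustOCT with the default uncertainty settings
    and return predictions for X_new.

    Assumes odtlearn, NumPy and the CBC binaries are installed.
    """
    import numpy as np
    from odtlearn.robust_oct import RobustOCT

    model = RobustOCT(solver="cbc", depth=2, time_limit=60, verbose=False)
    # costs=None and budget=-1 are the documented defaults of fit().
    model.fit(np.asarray(X), np.asarray(y))
    return model.predict(np.asarray(X_new))
```

Calling `train_and_predict(X_train, y_train, [[1, 2], [3, 3]])` requires a working odtlearn installation with CBC or Gurobi available; the result depends on the tree the solver returns.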

fit(X, y, costs=None, budget=-1)

Fit a robust optimal classification tree to the given training data.

Parameters:
X: array-like of shape (n_samples, n_features)

The training input samples. Features should be integer-valued.

y: array-like of shape (n_samples,)

The target values (class labels) for the training samples.

costs: array-like of shape (n_samples, n_features), optional

The cost of uncertainty for each feature of each sample. If None, every entry defaults to budget + 1, which exceeds the budget.

budget: float, optional

The budget of uncertainty. Default is -1.

Returns:
self: object

Returns self.

Raises:
ValueError

If X contains non-integer values or if inputs have inconsistent numbers of samples.

Notes

This method fits the RobustOCT model using mixed-integer optimization while considering potential adversarial perturbations within the given budget. It sets up the optimization problem, solves it, and stores the results.
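The interaction between `costs` and `budget` can be illustrated with a small sketch. It assumes one common reading of a cost-and-budget uncertainty set: an integer shift of a feature value costs the corresponding entry of `costs` per unit, and a set of shifts is admissible when its total cost does not exceed `budget`. The function names are hypothetical and this is not the library's internal code.

```python
def default_costs(n_samples, n_features, budget):
    """Cost matrix matching the documented default: every entry is budget + 1."""
    return [[budget + 1] * n_features for _ in range(n_samples)]


def perturbation_cost(costs, shifts):
    """Total cost of the shifts: sum of cost * |shift| over all entries."""
    return sum(
        c * abs(s)
        for cost_row, shift_row in zip(costs, shifts)
        for c, s in zip(cost_row, shift_row)
    )


def is_affordable(costs, shifts, budget):
    """True when the total cost of the shifts fits within the budget."""
    return perturbation_cost(costs, shifts) <= budget


costs = [[1, 2], [1, 2]]      # per-unit cost for 2 samples, 2 features
shifts = [[1, 0], [0, -1]]    # total cost: 1*1 + 2*1 = 3

print(is_affordable(costs, shifts, budget=3))   # fits exactly
print(is_affordable(costs, shifts, budget=2))   # too expensive
```

Under this reading, the default cost matrix makes every single-unit shift cost budget + 1, so no nonzero perturbation is admissible.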

predict(X)

Predict class labels for samples in X using the fitted RobustOCT model.

Parameters:
X: array-like of shape (n_samples, n_features)

The input samples for which to make predictions. Features should be integer-valued.

Returns:
y_pred: ndarray of shape (n_samples,)

The predicted class labels for each sample in X.

Raises:
NotFittedError

If the model has not been fitted yet.

ValueError

If X contains non-integer values or has a different number of features than the training data.

Notes

This method uses the robust decision tree learned during the fit process to classify new samples. It traverses the tree for each sample in X, following the branching decisions until reaching a leaf node, and returns the corresponding class prediction.
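The traversal described above can be sketched in plain Python. The tree representation here is hypothetical: internal nodes are stored as (feature index, threshold) pairs, leaves as class labels, the children of node n are numbered 2n and 2n + 1, and a sample goes left when its feature value is at most the threshold. None of these conventions are confirmed by this page; they are assumptions made for illustration.

```python
# Hypothetical fitted tree of depth 2.
branches = {1: (0, 2), 2: (1, 2)}   # node -> (feature index, threshold)
leaves = {3: 1, 4: 0, 5: 1}         # node -> predicted class label


def predict_one(x, branches, leaves):
    """Follow branching decisions from the root until a leaf is reached."""
    node = 1
    while node not in leaves:
        feature, threshold = branches[node]
        # Assumed rule: go left if the feature value is <= threshold.
        node = 2 * node if x[feature] <= threshold else 2 * node + 1
    return leaves[node]


def predict_all(X, branches, leaves):
    """Classify every sample in X independently."""
    return [predict_one(x, branches, leaves) for x in X]


print(predict_all([[1, 2], [1, 3], [3, 1]], branches, leaves))  # [0, 1, 1]
```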

print_tree()

Print a text representation of the fitted tree.

This method prints the structure of the fitted tree, including the branching feature and threshold value tested at each branching node and the prediction assigned to each leaf node.

Raises:
NotFittedError

If the model has not been fitted yet.

Notes

The tree is printed in a depth-first manner, with each node represented by its index, branching feature and threshold (for internal nodes), or prediction (for leaf nodes).
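A depth-first listing of this kind can be sketched as follows. The dictionary-based tree, the node numbering (children of node n are 2n and 2n + 1) and the text of each line are assumptions chosen for illustration; the exact output format of `print_tree` is not reproduced here.

```python
def describe_tree(branches, leaves, node=1):
    """Return one line per node, visiting the tree depth-first (pre-order)."""
    if node in leaves:
        return [f"node {node}: predict class {leaves[node]}"]
    feature, threshold = branches[node]
    lines = [f"node {node}: branch on feature {feature}, threshold {threshold}"]
    lines += describe_tree(branches, leaves, 2 * node)       # left subtree
    lines += describe_tree(branches, leaves, 2 * node + 1)   # right subtree
    return lines


# Hypothetical tree: node -> (feature index, threshold) and node -> class.
example_branches = {1: (0, 2), 2: (1, 2)}
example_leaves = {3: 1, 4: 0, 5: 1}

for line in describe_tree(example_branches, example_leaves):
    print(line)
```

The left subtree of each node is listed in full before its right sibling, so node 3 appears last in this example.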

plot_tree(label='all', filled=True, rounded=False, precision=3, ax=None, fontsize=None, color_dict={'node': None, 'leaves': []}, edge_annotation=True, arrow_annotation_font_scale=0.8, debug=False, feature_names=None)

Plot the fitted robust classification tree using matplotlib.

Parameters:
label: {‘all’, ‘root’, ‘none’}, default=‘all’

Whether to show informative labels for impurity, etc. Options include ‘all’ to show at every node, ‘root’ to show only at the top root node, or ‘none’ to not show at any node.

filled: bool, default=True

When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.

rounded: bool, default=False

When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.

precision: int, default=3

Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.

ax: matplotlib axis, default=None

Axes to plot to. If None, use current axis. Any previous content is cleared.

fontsize: int, default=None

Size of text font. If None, determined automatically to fit figure.

color_dict: dict, default={“node”: None, “leaves”: []}

A dictionary specifying the colors for nodes and leaves in the plot in #RRGGBB format. If None, the colors are chosen using the sklearn plot_tree color palette.

edge_annotation: bool, default=True

Whether to display annotations on the edges.

arrow_annotation_font_scale: float, default=0.8

The font scale for the arrow annotations.

debug: bool, default=False

Whether to print debug information.

feature_names: list of str, default=None

A list of feature names to use for the plot. If None, the feature names from the fitted tree will be used. The feature names should be in the same order as the columns of the data used to fit the tree.

Returns:
matplotlib.axes.Axes

The matplotlib Axes containing the plotted tree.

Notes

This method uses the MPLPlotter class to visualize the robust classification tree.