odtlearn.robust_oct
Classes
RobustOCT | An optimal robust decision tree classifier, fit on a given integer-valued data set and a given cost-and-budget uncertainty set.
Module Contents
- class odtlearn.robust_oct.RobustOCT(solver, depth=1, time_limit=1800, num_threads=None, verbose=False)
Bases: odtlearn.opt_ct.OptimalClassificationTree
An optimal robust decision tree classifier, fit on a given integer-valued data set and a given cost-and-budget uncertainty set to produce a tree robust against distribution shifts.
- Parameters:
- solver: str
A string specifying the name of the solver to use to solve the MIP. Options are “Gurobi” and “CBC”. If the CBC binaries are not found, Gurobi will be used by default.
- depth: int, default=1
The depth of the tree.
- time_limit: int, default=1800
The time limit, in seconds, for solving the MIP.
- num_threads: int, default=None
The number of threads the solver should use. If not specified, the solver uses all available threads.
- verbose: bool, default=False
Whether to log solver output.
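A minimal usage sketch, based only on the signatures documented on this page. The toy data, the cost values, and the lowercase solver spelling "cbc" are assumptions made for illustration, and the function is defined but not called because running it requires odtlearn, NumPy, and a working CBC or Gurobi installation.

```python
def robust_oct_usage_sketch():
    """Fit a depth-2 RobustOCT on toy integer data and predict on it.

    Imports are kept inside the function so that defining it does not
    require odtlearn or a MIP solver to be installed.
    """
    import numpy as np
    from odtlearn.robust_oct import RobustOCT

    # Toy integer-valued features and class labels (made up for illustration).
    X = np.array([[0, 1], [1, 0], [1, 1], [0, 0], [2, 1], [2, 0]])
    y = np.array([0, 1, 1, 0, 1, 1])

    # One cost per (sample, feature) entry, as the documented shape requires.
    # The value 1.0 is arbitrary.
    costs = np.full(X.shape, 1.0)

    # Assumption: the solver name is accepted in lowercase; the parameter
    # description above lists the options as "Gurobi" and "CBC".
    clf = RobustOCT(solver="cbc", depth=2, time_limit=60, verbose=False)
    clf.fit(X, y, costs=costs, budget=2)
    return clf.predict(X)
```

Calling `robust_oct_usage_sketch()` should return an array of predicted labels, one per row of `X`, provided a solver is available.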
- fit(X, y, costs=None, budget=-1)
Fit a robust optimal classification tree to the given training data.
- Parameters:
- X: array-like of shape (n_samples, n_features)
The training input samples. Features should be integer-valued.
- y: array-like of shape (n_samples,)
The target values (class labels) for the training samples.
- costs: array-like of shape (n_samples, n_features), optional
The costs of uncertainty for each feature and sample. If None, every cost defaults to budget + 1.
- budget: float, default=-1
The budget of uncertainty.
- Returns:
- self: object
Returns self.
- Raises:
- ValueError
If X contains non-integer values or if inputs have inconsistent numbers of samples.
Notes
This method fits the RobustOCT model using mixed-integer optimization while considering potential adversarial perturbations within the given budget. It sets up the optimization problem, solves it, and stores the results.
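The documented input checks and the `costs` default can be illustrated with a small pure-Python sketch. The helper name `prepare_fit_inputs` is hypothetical and this is not odtlearn's implementation; it only mirrors the behaviour described for `fit` (a ValueError for non-integer features or mismatched sample counts, and costs of budget + 1 when none are given).

```python
def prepare_fit_inputs(X, y, costs=None, budget=-1):
    """Hypothetical helper mirroring the documented checks and defaults of fit().

    X and costs are lists of rows; y is a list of labels.
    Returns the cost matrix that would be used.
    """
    if len(X) != len(y):
        raise ValueError("X and y have inconsistent numbers of samples")

    for row in X:
        for value in row:
            # Whole-number floats such as 2.0 are accepted; 2.5 is not.
            if float(value) != int(value):
                raise ValueError("X must contain integer values only")

    if costs is None:
        # Documented default: every (sample, feature) cost is budget + 1.
        return [[budget + 1 for _ in row] for row in X]

    # Explicit costs must match the shape of X entry for entry.
    if len(costs) != len(X) or any(len(c) != len(r) for c, r in zip(costs, X)):
        raise ValueError("costs must have shape (n_samples, n_features)")
    return costs


X = [[0, 1], [2, 3]]
default_costs = prepare_fit_inputs(X, [0, 1], budget=2)
```

With `budget=2`, every default cost is 3. If the total cost of perturbations must stay within the budget, which is an assumption about the uncertainty set, then no single perturbation is affordable under the default and the fit is effectively non-robust.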
- predict(X)
Predict class labels for samples in X using the fitted RobustOCT model.
- Parameters:
- X: array-like of shape (n_samples, n_features)
The input samples for which to make predictions. Features should be integer-valued.
- Returns:
- y_pred: ndarray of shape (n_samples,)
The predicted class labels for each sample in X.
- Raises:
- NotFittedError
If the model has not been fitted yet.
- ValueError
If X contains non-integer values or has a different number of features than the training data.
Notes
This method uses the robust decision tree learned during the fit process to classify new samples. It traverses the tree for each sample in X, following the branching decisions until reaching a leaf node, and returns the corresponding class prediction.
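The traversal described in the notes above can be sketched in plain Python. The dictionary layout (nodes numbered from 1, with the children of node n at 2n and 2n + 1), the key names, and the rule "go left when the feature value is at most the threshold" are assumptions made for illustration; they are not taken from odtlearn's internals.

```python
def predict_one(tree, x):
    """Follow branching decisions from the root (node 1) down to a leaf."""
    node = 1
    while "prediction" not in tree[node]:
        feature = tree[node]["feature"]
        threshold = tree[node]["threshold"]
        # Assumed convention: left child on <=, right child otherwise.
        node = 2 * node if x[feature] <= threshold else 2 * node + 1
    return tree[node]["prediction"]


# Hypothetical fitted tree: branching nodes hold a test, leaves hold a class.
toy_tree = {
    1: {"feature": 0, "threshold": 1},
    2: {"prediction": 0},
    3: {"feature": 1, "threshold": 0},
    6: {"prediction": 1},
    7: {"prediction": 0},
}

samples = [[0, 5], [2, 0], [2, 3]]
predictions = [predict_one(toy_tree, x) for x in samples]
print(predictions)  # [0, 1, 0]
```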
- print_tree()
Print a text representation of the fitted tree.
This method prints the structure of the fitted tree, including the branching feature and threshold value of each branching node's test and the prediction assigned to each leaf node.
- Raises:
- NotFittedError
If the model has not been fitted yet.
Notes
The tree is printed depth-first. Each node is shown with its index, followed by its branching feature and threshold (internal nodes) or its prediction (leaf nodes).
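A depth-first listing of this kind might look like the sketch below. The tree layout (children of node n at 2n and 2n + 1), the key names, and the output wording are all hypothetical; the text actually produced by `print_tree` may differ.

```python
def format_tree(tree, node=1, lines=None):
    """Collect one line per node, visiting a node before its children."""
    if lines is None:
        lines = []
    info = tree[node]
    if "prediction" in info:
        lines.append(f"node {node}: leaf, predict {info['prediction']}")
    else:
        lines.append(
            f"node {node}: branch on feature {info['feature']} "
            f"with threshold {info['threshold']}"
        )
        # Depth-first: finish the whole left subtree before the right one.
        format_tree(tree, 2 * node, lines)
        format_tree(tree, 2 * node + 1, lines)
    return lines


# Hypothetical fitted tree used only to exercise the printer.
toy_tree = {
    1: {"feature": 0, "threshold": 1},
    2: {"prediction": 0},
    3: {"feature": 1, "threshold": 0},
    6: {"prediction": 1},
    7: {"prediction": 0},
}

tree_lines = format_tree(toy_tree)
print("\n".join(tree_lines))
```

The nodes are visited in the order 1, 2, 3, 6, 7, which is what distinguishes a depth-first listing from a level-by-level one on deeper trees.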
- plot_tree(label='all', filled=True, rounded=False, precision=3, ax=None, fontsize=None, color_dict={'node': None, 'leaves': []}, edge_annotation=True, arrow_annotation_font_scale=0.8, debug=False, feature_names=None)
Plot the fitted robust classification tree using matplotlib.
- Parameters:
- label: {‘all’, ‘root’, ‘none’}, default=‘all’
Whether to show informative labels for impurity, etc. Options include ‘all’ to show at every node, ‘root’ to show only at the top root node, or ‘none’ to not show at any node.
- filled: bool, default=True
When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
- rounded: bool, default=False
When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
- precision: int, default=3
Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
- ax: matplotlib axis, default=None
Axes to plot to. If None, use current axis. Any previous content is cleared.
- fontsize: int, default=None
Size of text font. If None, determined automatically to fit figure.
- color_dict: dict, default={“node”: None, “leaves”: []}
A dictionary specifying the colors for nodes and leaves in the plot in #RRGGBB format. If None, the colors are chosen using the sklearn plot_tree color palette.
- edge_annotation: bool, default=True
Whether to display annotations on the edges.
- arrow_annotation_font_scale: float, default=0.8
The font scale for the arrow annotations.
- debug: bool, default=False
Whether to print debug information.
- feature_names: list of str, default=None
A list of feature names to use for the plot. If None, the feature names from the fitted tree will be used. The feature names should be in the same order as the columns of the data used to fit the tree.
- Returns:
- matplotlib.axes.Axes
The matplotlib Axes containing the plotted tree.
Notes
This method uses the MPLPlotter class to visualize the robust classification tree.
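Because `color_dict` expects colors in #RRGGBB format, a small helper can build one from RGB triples. The helper names are hypothetical, and the layout assumed here (a single color under "node", a list of colors under "leaves") is inferred from the documented default rather than confirmed by the source.

```python
def rgb_to_hex(rgb):
    """Convert an (r, g, b) triple of 0-255 integers to a '#RRGGBB' string."""
    r, g, b = rgb
    for channel in (r, g, b):
        if not 0 <= channel <= 255:
            raise ValueError("each RGB channel must be between 0 and 255")
    return "#{:02X}{:02X}{:02X}".format(r, g, b)


def make_color_dict(node_rgb, leaf_rgbs):
    """Build a dictionary shaped like the documented color_dict default.

    Assumption: 'node' holds one color and 'leaves' holds a list of colors.
    """
    return {
        "node": rgb_to_hex(node_rgb),
        "leaves": [rgb_to_hex(rgb) for rgb in leaf_rgbs],
    }


color_dict = make_color_dict((31, 119, 180), [(255, 127, 14), (44, 160, 44)])
print(color_dict)
# {'node': '#1F77B4', 'leaves': ['#FF7F0E', '#2CA02C']}
```

The resulting dictionary could then be passed as `color_dict=color_dict` to `plot_tree` on a fitted model, if the assumed layout matches what the plotter expects.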