odtlearn.robust_oct
Module Contents
Classes
- class odtlearn.robust_oct.RobustOCT(solver, depth=1, time_limit=1800, num_threads=None, verbose=False)
Bases: odtlearn.opt_ct.OptimalClassificationTree
An optimal robust decision tree classifier, fitted on a given integer-valued data set and a given cost-and-budget uncertainty set to produce a tree robust against distribution shifts.
- Parameters:
- solver: str
A string specifying the name of the solver to use to solve the MIP. Options are “Gurobi” and “CBC”. If the CBC binaries are not found, Gurobi will be used by default.
- depth: int, default=1
The depth of the tree.
- time_limit: int, default=1800
The time limit for solving the MIP, in seconds.
- num_threads: int, default=None
The number of threads the solver should use. If not specified, the solver uses all available threads.
- verbose: bool, default=False
Flag for logging solver outputs.
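A minimal sketch of constructing the classifier from the parameters listed above. The helper names `robust_oct_params` and `build_robust_oct`, the validation checks, and the chosen values are illustrative assumptions, not part of odtlearn. The import is deferred so the snippet loads even where odtlearn or a solver is not installed.

```python
def robust_oct_params(solver="CBC", depth=2, time_limit=300,
                      num_threads=None, verbose=False):
    """Collect constructor arguments (hypothetical helper, not part of odtlearn)."""
    # These checks are assumptions made for this sketch; the library may
    # validate its inputs differently.
    if depth < 1:
        raise ValueError("depth must be a positive integer")
    if time_limit <= 0:
        raise ValueError("time_limit must be positive (seconds)")
    return {
        "solver": solver,            # "Gurobi" or "CBC", as documented above
        "depth": depth,
        "time_limit": time_limit,
        "num_threads": num_threads,  # None lets the solver use all threads
        "verbose": verbose,
    }


def build_robust_oct(**overrides):
    """Instantiate RobustOCT; requires odtlearn and a supported solver."""
    from odtlearn.robust_oct import RobustOCT  # deferred import

    return RobustOCT(**robust_oct_params(**overrides))
```

Calling `build_robust_oct(depth=3)` would then return an unfitted classifier, provided the package and a solver are available.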
- fit(X, y, costs=None, budget=-1)
Fit an optimal robust classification tree given data, labels, costs of uncertainty, and budget of uncertainty.
- Parameters:
- X: array-like, shape (n_samples, n_features)
The training input samples.
- y: array-like, shape (n_samples,)
The target values, as an array of integers.
- costs: array-like, shape (n_samples, n_features), default=budget + 1
The costs of uncertainty.
- budget: float, default=-1
The budget of uncertainty.
- Returns:
- self: object
Returns self.
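The shapes required by `fit` can be sketched as follows. This is an illustration under assumptions: the data are synthetic, `uniform_costs` and `fit_robust_tree` are hypothetical helpers, and the sketch assumes that `costs[i, f]` is the cost of perturbing feature `f` of sample `i`, with the total perturbation cost limited by `budget`. The call into odtlearn is wrapped in a function with a deferred import, so the data preparation runs without the package.

```python
import numpy as np

# Synthetic integer-valued data (the classifier expects integer features).
rng = np.random.default_rng(0)
X = rng.integers(0, 4, size=(20, 3))
y = (X[:, 0] >= 2).astype(int)


def uniform_costs(X, cost_per_unit):
    """Hypothetical helper: one identical cost for every (sample, feature) pair."""
    return np.full(X.shape, cost_per_unit, dtype=float)


costs = uniform_costs(X, 1.0)  # same shape as X: (n_samples, n_features)
budget = 2.0                   # illustrative value only


def fit_robust_tree(X, y, costs, budget, depth=1, time_limit=60):
    """Fit a RobustOCT; requires odtlearn and a supported solver."""
    from odtlearn.robust_oct import RobustOCT  # deferred import

    clf = RobustOCT(solver="CBC", depth=depth, time_limit=time_limit)
    return clf.fit(X, y, costs=costs, budget=budget)  # fit returns self
```

Because `fit` returns `self`, `fit_robust_tree(X, y, costs, budget)` yields the fitted classifier directly.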
- predict(X)
Given the input covariates, predict the class label of each sample based on the fitted optimal robust classification tree.
- Parameters:
- X: array-like, shape (n_samples, n_features)
The input samples.
- Returns:
- y: ndarray, shape (n_samples,)
The predicted class label for each sample, as assigned by the fitted tree.
- print_tree()
Print the fitted tree with the branching features, the threshold values for each branching node’s test, and the predictions asserted for each assignment node.
The method uses the Gurobi model’s name to determine how to generate the tree.
- plot_tree(label='all', filled=True, rounded=False, precision=3, ax=None, fontsize=None, color_dict={'node': None, 'leaves': []}, edge_annotation=True, arrow_annotation_font_scale=0.8, debug=False)
Plot the fitted tree with the branching features, the threshold values for each branching node’s test, and the predictions asserted for each assignment node using matplotlib. The method uses the Gurobi model’s name for determining how to generate the tree. It does some preprocessing before passing the tree to the _MPLTreeExporter class from the sklearn package. The arguments for the plot_tree method are based on the arguments of the sklearn plot_tree function.
- Parameters:
- label: {‘all’, ‘root’, ‘none’}, default=‘all’
Whether to show informative labels for impurity, etc. Options include ‘all’ to show at every node, ‘root’ to show only at the top root node, or ‘none’ to not show at any node.
- filled: bool, default=True
When set to True, paint nodes to indicate majority class for classification, extremity of values for regression, or purity of node for multi-output.
- rounded: bool, default=False
When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
- precision: int, default=3
Number of digits of precision for floating point in the values of impurity, threshold and value attributes of each node.
- ax: matplotlib axis, default=None
Axes to plot to. If None, use current axis. Any previous content is cleared.
- fontsize: int, default=None
Size of text font. If None, determined automatically to fit figure.
- color_dict: dict, default={“node”: None, “leaves”: []}
A dictionary specifying the colors for nodes and leaves in the plot, in #RRGGBB format. If None, the colors are chosen using the sklearn plot_tree color palette.