odtlearn.opt_ct
Classes
OptimalClassificationTree: A class for learning optimal classification trees using mixed-integer programming.
Module Contents
- class odtlearn.opt_ct.OptimalClassificationTree(solver: str, depth: int, time_limit: int, num_threads: None | int, verbose: bool)
Bases: odtlearn.opt_dt.OptimalDecisionTree
A class for learning optimal classification trees using mixed-integer programming.
- Parameters:
- solver : str
The solver to use for the MIP formulation. Currently, only “gurobi” and “CBC” are supported.
- depth : int
The maximum depth of the tree to be learned.
- time_limit : int
The time limit, in seconds, for solving the MIP formulation.
- num_threads : int, optional
The number of threads the solver should use. If not specified, the solver uses all available threads.
- verbose : bool, default=False
Whether to print verbose output during the tree learning process.
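The `depth` parameter fixes the size of the tree skeleton that the MIP optimizes over. The sketch below enumerates that skeleton, assuming the common convention of numbering nodes breadth-first from 1, so that node n has children 2n and 2n + 1; a tree of depth d then has 2^d - 1 nodes that may branch above 2^d bottom-level leaves. The helpers `tree_nodes` and `children` are hypothetical and not part of odtlearn.

```python
def tree_nodes(depth: int) -> tuple[list[int], list[int]]:
    """Return (internal nodes, bottom-level leaves) of a complete binary tree.

    Hypothetical helper: nodes are numbered breadth-first starting at 1,
    so node n has children 2n (left) and 2n + 1 (right).
    """
    if depth < 1:
        raise ValueError("depth must be at least 1")
    internal = list(range(1, 2 ** depth))             # 2**depth - 1 nodes that may branch
    leaves = list(range(2 ** depth, 2 ** (depth + 1)))  # 2**depth nodes on the bottom level
    return internal, leaves


def children(n: int) -> tuple[int, int]:
    """Left and right child of node n under breadth-first numbering."""
    return 2 * n, 2 * n + 1


internal, leaves = tree_nodes(2)
print(internal)  # [1, 2, 3]
print(leaves)    # [4, 5, 6, 7]
```

Because the number of nodes doubles with each extra level, the number of decision variables in the formulation grows exponentially with `depth`, which is why small depths and a `time_limit` are typical.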
Notes
This class extends the OptimalDecisionTree base class to learn optimal classification trees. It formulates the problem as a mixed-integer program and solves it using either the Gurobi or CBC solver.
- Attributes:
- b_value : numpy.ndarray
The values of the branching decision variables in the learned tree.
- w_value : numpy.ndarray
The values of the prediction decision variables in the learned tree.
- p_value : numpy.ndarray
The values of the pruning decision variables in the learned tree.
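To make the roles of the branching, prediction, and pruning values concrete, the sketch below routes one sample through a tree described by them. It assumes a flow-based encoding that may differ from odtlearn's internal array layout: binary features; `b[(n, f)] = 1` if node n branches on feature f; `p[n] = 1` if node n is a leaf that makes a prediction; `w[(n, k)] = 1` if leaf n predicts class k; and a sample goes right when the branching feature equals 1, left otherwise. Plain dicts stand in for the NumPy arrays, and `route` is a hypothetical helper.

```python
def route(x, b, p, w, classes):
    """Follow a sample down the tree encoded by b, p, w and return its prediction.

    Assumed encoding (illustrative, not odtlearn's storage format):
      b[(n, f)] == 1 -> node n branches on binary feature f
      p[n] == 1      -> node n is a leaf that makes a prediction
      w[(n, k)] == 1 -> leaf n predicts class k
    Nodes are numbered breadth-first from 1 (children 2n and 2n + 1).
    Values are compared against 0.5 because solvers return floats such as 0.9999999.
    """
    n = 1
    while True:
        if p.get(n, 0) > 0.5:
            # Leaf: return the class whose prediction variable is switched on.
            return next(k for k in classes if w.get((n, k), 0) > 0.5)
        # Branching node: find the single feature selected at this node.
        f = next(f for f in range(len(x)) if b.get((n, f), 0) > 0.5)
        n = 2 * n + 1 if x[f] == 1 else 2 * n


# Depth-1 tree: the root branches on feature 0; leaf 2 predicts 0, leaf 3 predicts 1.
b = {(1, 0): 1.0}
p = {1: 0.0, 2: 1.0, 3: 1.0}
w = {(2, 0): 1.0, (3, 1): 1.0}
print(route([1, 0], b, p, w, classes=[0, 1]))  # 1
print(route([0, 1], b, p, w, classes=[0, 1]))  # 0
```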
Methods
fit(X, y)
Fit the optimal classification tree to the given training data.
predict(X)
Make predictions using the fitted optimal classification tree.
print_tree()
Print the structure of the fitted optimal classification tree.
plot_tree(**kwargs)
Plot the fitted optimal classification tree using matplotlib.
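What "optimal" means in `fit` is determined by the MIP objective. Assuming that objective is simply the number of correctly classified training samples (no complexity penalty), the optimal depth-1 tree over binary features can also be found by exhaustive enumeration, which is a handy cross-check on tiny problems. The function `best_stump` is a hypothetical illustration, not how odtlearn solves the problem.

```python
from collections import Counter


def best_stump(X, y):
    """Exhaustively find the depth-1 tree with the most correct training predictions.

    Assumes binary (0/1) features and an objective of plain training accuracy.
    Returns (feature or None, left_class, right_class, n_correct); a feature of
    None means the root is a leaf predicting the overall majority class.
    """
    majority, count = Counter(y).most_common(1)[0]
    best = (None, majority, majority, count)  # option of not splitting at all
    for f in range(len(X[0])):
        left = Counter(yi for xi, yi in zip(X, y) if xi[f] == 0)
        right = Counter(yi for xi, yi in zip(X, y) if xi[f] == 1)
        # Each leaf predicts its majority class; an empty side falls back to the overall majority.
        lc, ln = left.most_common(1)[0] if left else (majority, 0)
        rc, rn = right.most_common(1)[0] if right else (majority, 0)
        if ln + rn > best[3]:
            best = (f, lc, rc, ln + rn)
    return best


X = [[0, 1], [0, 0], [1, 1], [1, 0], [1, 1]]
y = [0, 0, 1, 1, 0]
print(best_stump(X, y))  # (0, 0, 1, 4)
```

Enumeration scales only to very small depths and feature counts; the MIP formulation exists precisely to search deeper trees without listing them one by one.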
- print_tree() → None
Print a text representation of the fitted tree.
This method prints the structure of the fitted tree, including the branching features, the threshold value of each branching node's test, and the prediction made at each leaf node.
- Raises:
- NotFittedError
If the model has not been fitted yet.
Notes
The tree is printed in a depth-first manner, with each node represented by its index, branching feature and threshold (for internal nodes), or prediction (for leaf nodes).
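The depth-first traversal described above can be sketched as follows. The nested-dict tree, the `format_tree` name, and the line format are hypothetical; the exact text that `print_tree()` emits may differ.

```python
def format_tree(tree, node=1, lines=None):
    """Return one text line per node, visited depth-first (node, left subtree, right subtree).

    Hypothetical structure: `tree` maps a node index to either
    {"feature": name, "threshold": t} for a branching node or
    {"prediction": label} for a leaf. Children of node n are 2n and 2n + 1.
    """
    if lines is None:
        lines = []
    info = tree[node]
    if "prediction" in info:
        lines.append(f"#Node {node}: leaf, prediction = {info['prediction']}")
    else:
        lines.append(
            f"#Node {node}: branch on {info['feature']} (threshold {info['threshold']})"
        )
        format_tree(tree, 2 * node, lines)      # left subtree first
        format_tree(tree, 2 * node + 1, lines)  # then right subtree
    return lines


tree = {
    1: {"feature": "x0", "threshold": 0.5},
    2: {"prediction": 0},
    3: {"feature": "x1", "threshold": 0.5},
    6: {"prediction": 1},
    7: {"prediction": 0},
}
print("\n".join(format_tree(tree)))
```

Nodes are visited in the order 1, 2, 3, 6, 7: each node is printed before its left subtree, which is printed in full before the right subtree.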
- plot_tree(label='all', filled=True, rounded=False, precision=3, ax=None, fontsize=None, color_dict={'node': None, 'leaves': []}, edge_annotation=True, arrow_annotation_font_scale=0.8, debug=False, distance=1.0, feature_names=None)
Plot the fitted tree using matplotlib.
- Parameters:
- label : {‘all’, ‘root’, ‘none’}, default=‘all’
Whether to show informative labels for impurity, etc. Options are ‘all’ to show labels at every node, ‘root’ to show them only at the root node, or ‘none’ to show no labels.
- filled : bool, default=True
When set to True, paint nodes to indicate majority treatment for prescriptive trees.
- rounded : bool, default=False
When set to True, draw node boxes with rounded corners and use Helvetica fonts instead of Times-Roman.
- precision : int, default=3
Number of digits of precision for floating-point values of the impurity, threshold, and value attributes of each node.
- ax : matplotlib.axes.Axes, default=None
Axes to plot to. If None, use the current axes. Any previous content is cleared.
- fontsize : int, default=None
Size of the text font. If None, it is determined automatically to fit the figure.
- color_dict : dict, optional
A dictionary specifying the colors for nodes and leaves. Default: {“node”: None, “leaves”: []}
- edge_annotation : bool, default=True
Whether to display annotations on the edges.
- arrow_annotation_font_scale : float, default=0.8
The font scale for the arrow annotations.
- debug : bool, default=False
Whether to print debug information.
- distance : float, default=1.0
Adjusts the distance between levels in the tree.
- feature_names : list of str, default=None
A list of feature names to use for the plot. If None, the feature names from the fitted tree are used. The feature names should be in the same order as the columns of the data used to fit the tree.
- Returns:
- ax : matplotlib.axes.Axes
The matplotlib Axes containing the plot.
- Raises:
- NotFittedError
If the model has not been fitted yet.
Notes
This method visualizes the fitted tree structure using matplotlib. Each node in the tree is represented by a box, with arrows indicating the branching structure.
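One way the box-and-arrow layout could be computed is sketched below: each depth level sits `distance` units below the previous one, and the nodes on each level are centred in equal-width slots across the figure. This is an assumed layout for illustration only; `node_positions` is hypothetical and odtlearn's actual placement may differ. The resulting coordinates could then be drawn with matplotlib, for example with `Axes.annotate` for the boxes and arrows.

```python
def node_positions(depth, distance=1.0):
    """Map breadth-first node indices (1 .. 2**(depth+1) - 1) to (x, y) coordinates.

    Level d holds nodes 2**d .. 2**(d+1) - 1; they are centred in equal-width
    slots across [0, 1], and level d is placed at y = -d * distance.
    """
    positions = {}
    for n in range(1, 2 ** (depth + 1)):
        level = n.bit_length() - 1  # exact integer log2 of n
        slot = n - 2 ** level       # position within the level, 0-based
        width = 2 ** level          # number of slots on this level
        positions[n] = ((slot + 0.5) / width, -level * distance)
    return positions


pos = node_positions(depth=1, distance=2.0)
print(pos[2], pos[3])  # (0.25, -2.0) (0.75, -2.0)
```

Increasing `distance` stretches the tree vertically without changing the horizontal spacing, which matches the parameter's description of adjusting the distance between levels.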