feat: initial commit - Phase 1 & 2 core features

Author: hiderfong
Date: 2026-04-22 17:07:33 +08:00
Commit: 1773bda06b

25005 changed files with 6252106 additions and 0 deletions
sklearn/tree/__init__.py
@@ -0,0 +1,24 @@
"""
The :mod:`sklearn.tree` module includes decision tree-based models for
classification and regression.
"""
from ._classes import (
BaseDecisionTree,
DecisionTreeClassifier,
DecisionTreeRegressor,
ExtraTreeClassifier,
ExtraTreeRegressor,
)
from ._export import export_graphviz, export_text, plot_tree
__all__ = [
"BaseDecisionTree",
"DecisionTreeClassifier",
"DecisionTreeRegressor",
"ExtraTreeClassifier",
"ExtraTreeRegressor",
"export_graphviz",
"plot_tree",
"export_text",
]
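
As a quick orientation, the public surface exported above can be exercised in a few lines. A minimal sketch using the standard scikit-learn API (the iris data is only an illustration):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Fit a shallow tree and render it with the text exporter declared above.
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)
print(export_text(clf, feature_names=list(iris.feature_names)))
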
File diff suppressed because it is too large
sklearn/tree/_criterion.pxd
@@ -0,0 +1,115 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
#
# License: BSD 3 clause
# See _criterion.pyx for implementation details.
from ..utils._typedefs cimport float64_t, int8_t, intp_t
cdef class Criterion:
# The criterion computes the impurity of a node and the reduction of
# impurity of a split on that node. It also computes the output statistics
# such as the mean in regression and class probabilities in classification.
# Internal structures
cdef const float64_t[:, ::1] y # Values of y
cdef const float64_t[:] sample_weight # Sample weights
cdef const intp_t[:] sample_indices # Sample indices in X, y
cdef intp_t start # samples[start:pos] are the samples in the left node
cdef intp_t pos # samples[pos:end] are the samples in the right node
cdef intp_t end
cdef intp_t n_missing # Number of missing values for the feature being evaluated
cdef bint missing_go_to_left # Whether missing values go to the left node
cdef intp_t n_outputs # Number of outputs
cdef intp_t n_samples # Number of samples
cdef intp_t n_node_samples # Number of samples in the node (end-start)
cdef float64_t weighted_n_samples # Weighted number of samples (in total)
cdef float64_t weighted_n_node_samples # Weighted number of samples in the node
cdef float64_t weighted_n_left # Weighted number of samples in the left node
cdef float64_t weighted_n_right # Weighted number of samples in the right node
cdef float64_t weighted_n_missing # Weighted number of samples that are missing
# The criterion object is maintained such that left and right collected
# statistics correspond to samples[start:pos] and samples[pos:end].
# Methods
cdef int init(
self,
const float64_t[:, ::1] y,
const float64_t[:] sample_weight,
float64_t weighted_n_samples,
const intp_t[:] sample_indices,
intp_t start,
intp_t end
) except -1 nogil
cdef void init_sum_missing(self)
cdef void init_missing(self, intp_t n_missing) noexcept nogil
cdef int reset(self) except -1 nogil
cdef int reverse_reset(self) except -1 nogil
cdef int update(self, intp_t new_pos) except -1 nogil
cdef float64_t node_impurity(self) noexcept nogil
cdef void children_impurity(
self,
float64_t* impurity_left,
float64_t* impurity_right
) noexcept nogil
cdef void node_value(
self,
float64_t* dest
) noexcept nogil
cdef void clip_node_value(
self,
float64_t* dest,
float64_t lower_bound,
float64_t upper_bound
) noexcept nogil
cdef float64_t middle_value(self) noexcept nogil
cdef float64_t impurity_improvement(
self,
float64_t impurity_parent,
float64_t impurity_left,
float64_t impurity_right
) noexcept nogil
cdef float64_t proxy_impurity_improvement(self) noexcept nogil
cdef bint check_monotonicity(
self,
int8_t monotonic_cst,
float64_t lower_bound,
float64_t upper_bound,
) noexcept nogil
cdef inline bint _check_monotonicity(
self,
int8_t monotonic_cst,
float64_t lower_bound,
float64_t upper_bound,
float64_t sum_left,
float64_t sum_right,
) noexcept nogil
cdef class ClassificationCriterion(Criterion):
"""Abstract criterion for classification."""
cdef intp_t[::1] n_classes
cdef intp_t max_n_classes
cdef float64_t[:, ::1] sum_total # The sum of the weighted count of each label.
cdef float64_t[:, ::1] sum_left # Same as above, but for the left side of the split
cdef float64_t[:, ::1] sum_right # Same as above, but for the right side of the split
cdef float64_t[:, ::1] sum_missing # Same as above, but for missing values in X
cdef class RegressionCriterion(Criterion):
"""Abstract regression criterion."""
cdef float64_t sq_sum_total
cdef float64_t[::1] sum_total # The sum of w*y.
cdef float64_t[::1] sum_left # Same as above, but for the left side of the split
cdef float64_t[::1] sum_right # Same as above, but for the right side of the split
cdef float64_t[::1] sum_missing # Same as above, but for missing values in X
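
To make the bookkeeping above concrete: for a single output, a classification criterion boils down to a weighted per-node impurity plus the improvement formula that `impurity_improvement` declares. A pure-Python sketch for intuition (the helper names are ours, not the Cython implementation):

import numpy as np

def weighted_gini(y, sample_weight, classes):
    # Impurity of one node: 1 - sum_k p_k^2 with weighted class frequencies.
    w = np.asarray(sample_weight, dtype=float)
    p = np.array([w[np.asarray(y) == c].sum() for c in classes]) / w.sum()
    return 1.0 - np.sum(p ** 2)

def impurity_improvement(w_node, w_left, w_right, w_total,
                         imp_parent, imp_left, imp_right):
    # Mirrors Criterion.impurity_improvement: the weighted impurity decrease
    # of a split, normalized by the total weight of all training samples.
    return (w_node / w_total) * (
        imp_parent
        - (w_right / w_node) * imp_right
        - (w_left / w_node) * imp_left
    )
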
File diff suppressed because it is too large
File diff suppressed because it is too large
sklearn/tree/_reingold_tilford.py
@@ -0,0 +1,188 @@
# Authors: William Mill (bill@billmill.org)
# License: BSD 3 clause
import numpy as np
class DrawTree:
def __init__(self, tree, parent=None, depth=0, number=1):
self.x = -1.0
self.y = depth
self.tree = tree
self.children = [
DrawTree(c, self, depth + 1, i + 1) for i, c in enumerate(tree.children)
]
self.parent = parent
self.thread = None
self.mod = 0
self.ancestor = self
self.change = self.shift = 0
self._lmost_sibling = None
# this is the number of the node in its group of siblings 1..n
self.number = number
def left(self):
return self.thread or len(self.children) and self.children[0]
def right(self):
return self.thread or len(self.children) and self.children[-1]
def lbrother(self):
n = None
if self.parent:
for node in self.parent.children:
if node == self:
return n
else:
n = node
return n
def get_lmost_sibling(self):
if not self._lmost_sibling and self.parent and self != self.parent.children[0]:
self._lmost_sibling = self.parent.children[0]
return self._lmost_sibling
lmost_sibling = property(get_lmost_sibling)
def __str__(self):
return "%s: x=%s mod=%s" % (self.tree, self.x, self.mod)
def __repr__(self):
return self.__str__()
def max_extents(self):
extents = [c.max_extents() for c in self.children]
extents.append((self.x, self.y))
return np.max(extents, axis=0)
def buchheim(tree):
dt = first_walk(DrawTree(tree))
min = second_walk(dt)
if min < 0:
third_walk(dt, -min)
return dt
def third_walk(tree, n):
tree.x += n
for c in tree.children:
third_walk(c, n)
def first_walk(v, distance=1.0):
if len(v.children) == 0:
if v.lmost_sibling:
v.x = v.lbrother().x + distance
else:
v.x = 0.0
else:
default_ancestor = v.children[0]
for w in v.children:
first_walk(w)
default_ancestor = apportion(w, default_ancestor, distance)
# print("finished v =", v.tree, "children")
execute_shifts(v)
midpoint = (v.children[0].x + v.children[-1].x) / 2
w = v.lbrother()
if w:
v.x = w.x + distance
v.mod = v.x - midpoint
else:
v.x = midpoint
return v
def apportion(v, default_ancestor, distance):
w = v.lbrother()
if w is not None:
# in buchheim notation:
# i == inner; o == outer; r == right; l == left; r = +; l = -
vir = vor = v
vil = w
vol = v.lmost_sibling
sir = sor = v.mod
sil = vil.mod
sol = vol.mod
while vil.right() and vir.left():
vil = vil.right()
vir = vir.left()
vol = vol.left()
vor = vor.right()
vor.ancestor = v
shift = (vil.x + sil) - (vir.x + sir) + distance
if shift > 0:
move_subtree(ancestor(vil, v, default_ancestor), v, shift)
sir = sir + shift
sor = sor + shift
sil += vil.mod
sir += vir.mod
sol += vol.mod
sor += vor.mod
if vil.right() and not vor.right():
vor.thread = vil.right()
vor.mod += sil - sor
else:
if vir.left() and not vol.left():
vol.thread = vir.left()
vol.mod += sir - sol
default_ancestor = v
return default_ancestor
def move_subtree(wl, wr, shift):
subtrees = wr.number - wl.number
# print(wl.tree, "is conflicted with", wr.tree, 'moving', subtrees,
# 'shift', shift)
# print wl, wr, wr.number, wl.number, shift, subtrees, shift/subtrees
wr.change -= shift / subtrees
wr.shift += shift
wl.change += shift / subtrees
wr.x += shift
wr.mod += shift
def execute_shifts(v):
shift = change = 0
for w in v.children[::-1]:
# print("shift:", w, shift, w.change)
w.x += shift
w.mod += shift
change += w.change
shift += w.shift + change
def ancestor(vil, v, default_ancestor):
# the relevant text is at the bottom of page 7 of
# "Improving Walker's Algorithm to Run in Linear Time" by Buchheim et al,
# (2002)
# https://citeseerx.ist.psu.edu/doc_view/pid/1f41c3c2a4880dc49238e46d555f16d28da2940d
if vil.ancestor in v.parent.children:
return vil.ancestor
else:
return default_ancestor
def second_walk(v, m=0, depth=0, min=None):
v.x += m
v.y = depth
if min is None or v.x < min:
min = v.x
for w in v.children:
min = second_walk(w, m + v.mod, depth + 1, min)
return min
class Tree:
def __init__(self, label="", node_id=-1, *children):
self.label = label
self.node_id = node_id
if children:
self.children = children
else:
self.children = []
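
Since `Tree` and `buchheim` are self-contained, the layout can be smoke-tested directly; node labels and ids below are arbitrary:

# A root with two children; the left child has one child of its own.
t = Tree("root", 0, Tree("a", 1, Tree("c", 3)), Tree("b", 2))
dt = buchheim(t)

# Each DrawTree node now carries drawing coordinates.
print(dt.x, dt.y)                          # root position (y == depth == 0)
print([(c.x, c.y) for c in dt.children])   # children at depth 1
print(dt.max_extents())                    # [max_x, max_y] over the drawing
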
sklearn/tree/_splitter.pxd
@@ -0,0 +1,110 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
#
# License: BSD 3 clause
# See _splitter.pyx for details.
from ._criterion cimport Criterion
from ._tree cimport ParentInfo
from ..utils._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint32_t
cdef struct SplitRecord:
# Data to track sample split
intp_t feature # Which feature to split on.
intp_t pos # Split samples array at the given position,
# # i.e. count of samples below threshold for feature.
# # pos is >= end if the node is a leaf.
float64_t threshold # Threshold to split at.
float64_t improvement # Impurity improvement given parent node.
float64_t impurity_left # Impurity of the left split.
float64_t impurity_right # Impurity of the right split.
float64_t lower_bound # Lower bound on value of both children for monotonicity
float64_t upper_bound # Upper bound on value of both children for monotonicity
unsigned char missing_go_to_left # Controls if missing values go to the left node.
intp_t n_missing # Number of missing values for the feature being split on
cdef class Splitter:
# The splitter searches in the input space for a feature and a threshold
    # to split the samples in samples[start:end].
#
# The impurity computations are delegated to a criterion object.
# Internal structures
cdef public Criterion criterion # Impurity criterion
cdef public intp_t max_features # Number of features to test
cdef public intp_t min_samples_leaf # Min samples in a leaf
cdef public float64_t min_weight_leaf # Minimum weight in a leaf
cdef object random_state # Random state
cdef uint32_t rand_r_state # sklearn_rand_r random number state
cdef intp_t[::1] samples # Sample indices in X, y
cdef intp_t n_samples # X.shape[0]
cdef float64_t weighted_n_samples # Weighted number of samples
cdef intp_t[::1] features # Feature indices in X
cdef intp_t[::1] constant_features # Constant features indices
cdef intp_t n_features # X.shape[1]
cdef float32_t[::1] feature_values # temp. array holding feature values
cdef intp_t start # Start position for the current node
cdef intp_t end # End position for the current node
cdef const float64_t[:, ::1] y
# Monotonicity constraints for each feature.
# The encoding is as follows:
# -1: monotonic decrease
# 0: no constraint
# +1: monotonic increase
cdef const int8_t[:] monotonic_cst
cdef bint with_monotonic_cst
cdef const float64_t[:] sample_weight
# The samples vector `samples` is maintained by the Splitter object such
# that the samples contained in a node are contiguous. With this setting,
# `node_split` reorganizes the node samples `samples[start:end]` in two
# subsets `samples[start:pos]` and `samples[pos:end]`.
# The 1-d `features` array of size n_features contains the features
# indices and allows fast sampling without replacement of features.
# The 1-d `constant_features` array of size n_features holds in
# `constant_features[:n_constant_features]` the feature ids with
# constant values for all the samples that reached a specific node.
# The value `n_constant_features` is given by the parent node to its
# child nodes. The content of the range `[n_constant_features:]` is left
    # undefined, but preallocated for performance reasons.
# This allows optimization with depth-based tree building.
# Methods
cdef int init(
self,
object X,
const float64_t[:, ::1] y,
const float64_t[:] sample_weight,
const unsigned char[::1] missing_values_in_feature_mask,
) except -1
cdef int node_reset(
self,
intp_t start,
intp_t end,
float64_t* weighted_n_node_samples
) except -1 nogil
cdef int node_split(
self,
ParentInfo* parent,
SplitRecord* split,
) except -1 nogil
cdef void node_value(self, float64_t* dest) noexcept nogil
cdef void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil
cdef float64_t node_impurity(self) noexcept nogil
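
The monotonicity fields declared here back the public `monotonic_cst` estimator parameter, using the -1/0/+1 encoding from the comment above. A minimal sketch of the observable guarantee on synthetic data (illustrative only):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = rng.rand(300, 2)
y = 5 * X[:, 0] - 3 * X[:, 1] + 0.1 * rng.randn(300)

# Feature 0: monotonic increase (+1); feature 1: monotonic decrease (-1).
reg = DecisionTreeRegressor(monotonic_cst=[1, -1], random_state=0).fit(X, y)

X_up = X.copy()
X_up[:, 0] += 10.0
# Raising a +1-constrained feature can never decrease the prediction.
assert np.all(reg.predict(X_up) >= reg.predict(X))
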
File diff suppressed because it is too large
sklearn/tree/_tree.pxd
@@ -0,0 +1,123 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Brian Holt <bdholt1@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause
# See _tree.pyx for details.
import numpy as np
cimport numpy as cnp
from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t
from ._splitter cimport Splitter
from ._splitter cimport SplitRecord
cdef struct Node:
# Base storage structure for the nodes in a Tree object
intp_t left_child # id of the left child of the node
intp_t right_child # id of the right child of the node
intp_t feature # Feature used for splitting the node
float64_t threshold # Threshold value at the node
float64_t impurity # Impurity of the node (i.e., the value of the criterion)
intp_t n_node_samples # Number of samples at the node
float64_t weighted_n_node_samples # Weighted number of samples at the node
    unsigned char missing_go_to_left     # Whether samples with missing values go to the left child
cdef struct ParentInfo:
# Structure to store information about the parent of a node
# This is passed to the splitter, to provide information about the previous split
    float64_t lower_bound           # the lower bound on the node values, for monotonicity constraints
    float64_t upper_bound           # the upper bound on the node values, for monotonicity constraints
float64_t impurity # the impurity of the parent
intp_t n_constant_features # the number of constant features found in parent
cdef class Tree:
# The Tree object is a binary tree structure constructed by the
# TreeBuilder. The tree structure is used for predictions and
# feature importances.
# Input/Output layout
cdef public intp_t n_features # Number of features in X
cdef intp_t* n_classes # Number of classes in y[:, k]
cdef public intp_t n_outputs # Number of outputs in y
cdef public intp_t max_n_classes # max(n_classes)
# Inner structures: values are stored separately from node structure,
# since size is determined at runtime.
cdef public intp_t max_depth # Max depth of the tree
cdef public intp_t node_count # Counter for node IDs
cdef public intp_t capacity # Capacity of tree, in terms of nodes
cdef Node* nodes # Array of nodes
cdef float64_t* value # (capacity, n_outputs, max_n_classes) array of values
cdef intp_t value_stride # = n_outputs * max_n_classes
# Methods
cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf,
intp_t feature, float64_t threshold, float64_t impurity,
intp_t n_node_samples,
float64_t weighted_n_node_samples,
unsigned char missing_go_to_left) except -1 nogil
cdef int _resize(self, intp_t capacity) except -1 nogil
cdef int _resize_c(self, intp_t capacity=*) except -1 nogil
cdef cnp.ndarray _get_value_ndarray(self)
cdef cnp.ndarray _get_node_ndarray(self)
cpdef cnp.ndarray predict(self, object X)
cpdef cnp.ndarray apply(self, object X)
cdef cnp.ndarray _apply_dense(self, object X)
cdef cnp.ndarray _apply_sparse_csr(self, object X)
cpdef object decision_path(self, object X)
cdef object _decision_path_dense(self, object X)
cdef object _decision_path_sparse_csr(self, object X)
cpdef compute_node_depths(self)
cpdef compute_feature_importances(self, normalize=*)
# =============================================================================
# Tree builder
# =============================================================================
cdef class TreeBuilder:
# The TreeBuilder recursively builds a Tree object from training samples,
# using a Splitter object for splitting internal nodes and assigning
# values to leaves.
#
# This class controls the various stopping criteria and the node splitting
# evaluation order, e.g. depth-first or best-first.
cdef Splitter splitter # Splitting algorithm
cdef intp_t min_samples_split # Minimum number of samples in an internal node
cdef intp_t min_samples_leaf # Minimum number of samples in a leaf
cdef float64_t min_weight_leaf # Minimum weight in a leaf
cdef intp_t max_depth # Maximal tree depth
cdef float64_t min_impurity_decrease # Impurity threshold for early stopping
cpdef build(
self,
Tree tree,
object X,
const float64_t[:, ::1] y,
const float64_t[:] sample_weight=*,
const unsigned char[::1] missing_values_in_feature_mask=*,
)
cdef _check_input(
self,
object X,
const float64_t[:, ::1] y,
const float64_t[:] sample_weight,
)
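
On a fitted estimator, the Node array above is exposed through the public `tree_` attribute as parallel arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`). A hedged sketch that reproduces `predict` for dense samples without missing values by walking those arrays:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)
t = clf.tree_

def predict_one(x):
    node = 0
    while t.children_left[node] != -1:      # -1 (TREE_LEAF) marks a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]
        else:
            node = t.children_right[node]
    # value has shape (node_count, n_outputs, max_n_classes)
    return np.argmax(t.value[node, 0])

assert all(predict_one(x) == p for x, p in zip(X, clf.predict(X)))
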
File diff suppressed because it is too large
sklearn/tree/_utils.pxd
@@ -0,0 +1,104 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
# License: BSD 3 clause
# See _utils.pyx for details.
cimport numpy as cnp
from ._tree cimport Node
from ..neighbors._quad_tree cimport Cell
from ..utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t
cdef enum:
# Max value for our rand_r replacement (near the bottom).
# We don't use RAND_MAX because it's different across platforms and
# particularly tiny on Windows/MSVC.
# It corresponds to the maximum representable value for
# 32-bit signed integers (i.e. 2^31 - 1).
RAND_R_MAX = 2147483647
# safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or
# raises a MemoryError. It never calls free, since that's __dealloc__'s job.
# cdef float32_t *p = NULL
# safe_realloc(&p, n)
# is equivalent to p = malloc(n * sizeof(*p)) with error checking.
ctypedef fused realloc_ptr:
# Add pointer types here as needed.
(float32_t*)
(intp_t*)
(unsigned char*)
(WeightedPQueueRecord*)
(float64_t*)
(float64_t**)
(Node*)
(Cell*)
(Node**)
cdef int safe_realloc(realloc_ptr* p, size_t nelems) except -1 nogil
cdef cnp.ndarray sizet_ptr_to_ndarray(intp_t* data, intp_t size)
cdef intp_t rand_int(intp_t low, intp_t high,
uint32_t* random_state) noexcept nogil
cdef float64_t rand_uniform(float64_t low, float64_t high,
uint32_t* random_state) noexcept nogil
cdef float64_t log(float64_t x) noexcept nogil
# =============================================================================
# WeightedPQueue data structure
# =============================================================================
# A record stored in the WeightedPQueue
cdef struct WeightedPQueueRecord:
float64_t data
float64_t weight
cdef class WeightedPQueue:
cdef intp_t capacity
cdef intp_t array_ptr
cdef WeightedPQueueRecord* array_
cdef bint is_empty(self) noexcept nogil
cdef int reset(self) except -1 nogil
cdef intp_t size(self) noexcept nogil
cdef int push(self, float64_t data, float64_t weight) except -1 nogil
cdef int remove(self, float64_t data, float64_t weight) noexcept nogil
cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil
cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil
cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil
cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil
# =============================================================================
# WeightedMedianCalculator data structure
# =============================================================================
cdef class WeightedMedianCalculator:
cdef intp_t initial_capacity
cdef WeightedPQueue samples
cdef float64_t total_weight
cdef intp_t k
cdef float64_t sum_w_0_k # represents sum(weights[0:k]) = w[0] + w[1] + ... + w[k-1]
cdef intp_t size(self) noexcept nogil
cdef int push(self, float64_t data, float64_t weight) except -1 nogil
cdef int reset(self) except -1 nogil
cdef int update_median_parameters_post_push(
self, float64_t data, float64_t weight,
float64_t original_median) noexcept nogil
cdef int remove(self, float64_t data, float64_t weight) noexcept nogil
cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil
cdef int update_median_parameters_post_remove(
self, float64_t data, float64_t weight,
float64_t original_median) noexcept nogil
cdef float64_t get_median(self) noexcept nogil
sklearn/tree/_utils.pyx
@@ -0,0 +1,466 @@
# Authors: Gilles Louppe <g.louppe@gmail.com>
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Arnaud Joly <arnaud.v.joly@gmail.com>
# Jacob Schreiber <jmschreiber91@gmail.com>
# Nelson Liu <nelson@nelsonliu.me>
#
#
# License: BSD 3 clause
from libc.stdlib cimport free
from libc.stdlib cimport realloc
from libc.math cimport log as ln
from libc.math cimport isnan
import numpy as np
cimport numpy as cnp
cnp.import_array()
from ..utils._random cimport our_rand_r
# =============================================================================
# Helper functions
# =============================================================================
cdef int safe_realloc(realloc_ptr* p, size_t nelems) except -1 nogil:
# sizeof(realloc_ptr[0]) would be more like idiomatic C, but causes Cython
# 0.20.1 to crash.
cdef size_t nbytes = nelems * sizeof(p[0][0])
if nbytes / sizeof(p[0][0]) != nelems:
# Overflow in the multiplication
raise MemoryError(f"could not allocate ({nelems} * {sizeof(p[0][0])}) bytes")
cdef realloc_ptr tmp = <realloc_ptr>realloc(p[0], nbytes)
if tmp == NULL:
raise MemoryError(f"could not allocate {nbytes} bytes")
p[0] = tmp
return 0
def _realloc_test():
# Helper for tests. Tries to allocate <size_t>(-1) / 2 * sizeof(size_t)
# bytes, which will always overflow.
cdef intp_t* p = NULL
safe_realloc(&p, <size_t>(-1) / 2)
if p != NULL:
free(p)
assert False
cdef inline cnp.ndarray sizet_ptr_to_ndarray(intp_t* data, intp_t size):
"""Return copied data as 1D numpy array of intp's."""
cdef cnp.npy_intp shape[1]
shape[0] = <cnp.npy_intp> size
return cnp.PyArray_SimpleNewFromData(1, shape, cnp.NPY_INTP, data).copy()
cdef inline intp_t rand_int(intp_t low, intp_t high,
uint32_t* random_state) noexcept nogil:
"""Generate a random integer in [low; end)."""
return low + our_rand_r(random_state) % (high - low)
cdef inline float64_t rand_uniform(float64_t low, float64_t high,
uint32_t* random_state) noexcept nogil:
"""Generate a random float64_t in [low; high)."""
return ((high - low) * <float64_t> our_rand_r(random_state) /
<float64_t> RAND_R_MAX) + low
cdef inline float64_t log(float64_t x) noexcept nogil:
return ln(x) / ln(2.0)
# =============================================================================
# WeightedPQueue data structure
# =============================================================================
cdef class WeightedPQueue:
"""A priority queue class, always sorted in increasing order.
Attributes
----------
capacity : intp_t
The capacity of the priority queue.
array_ptr : intp_t
The water mark of the priority queue; the priority queue grows from
left to right in the array ``array_``. ``array_ptr`` is always
less than ``capacity``.
array_ : WeightedPQueueRecord*
The array of priority queue records. The minimum element is on the
left at index 0, and the maximum element is on the right at index
``array_ptr-1``.
"""
def __cinit__(self, intp_t capacity):
self.capacity = capacity
self.array_ptr = 0
safe_realloc(&self.array_, capacity)
def __dealloc__(self):
free(self.array_)
cdef int reset(self) except -1 nogil:
"""Reset the WeightedPQueue to its state at construction
Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
self.array_ptr = 0
# Since safe_realloc can raise MemoryError, use `except -1`
safe_realloc(&self.array_, self.capacity)
return 0
cdef bint is_empty(self) noexcept nogil:
return self.array_ptr <= 0
cdef intp_t size(self) noexcept nogil:
return self.array_ptr
cdef int push(self, float64_t data, float64_t weight) except -1 nogil:
"""Push record on the array.
Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
cdef intp_t array_ptr = self.array_ptr
cdef WeightedPQueueRecord* array = NULL
cdef intp_t i
# Resize if capacity not sufficient
if array_ptr >= self.capacity:
self.capacity *= 2
# Since safe_realloc can raise MemoryError, use `except -1`
safe_realloc(&self.array_, self.capacity)
# Put element as last element of array
array = self.array_
array[array_ptr].data = data
array[array_ptr].weight = weight
        # bubble the last element up until the array is sorted
        # in ascending order
i = array_ptr
while(i != 0 and array[i].data < array[i-1].data):
array[i], array[i-1] = array[i-1], array[i]
i -= 1
# Increase element count
self.array_ptr = array_ptr + 1
return 0
cdef int remove(self, float64_t data, float64_t weight) noexcept nogil:
"""Remove a specific value/weight record from the array.
Returns 0 if successful, -1 if record not found."""
cdef intp_t array_ptr = self.array_ptr
cdef WeightedPQueueRecord* array = self.array_
cdef intp_t idx_to_remove = -1
cdef intp_t i
if array_ptr <= 0:
return -1
# find element to remove
for i in range(array_ptr):
if array[i].data == data and array[i].weight == weight:
idx_to_remove = i
break
if idx_to_remove == -1:
return -1
# shift the elements after the removed element
# to the left.
for i in range(idx_to_remove, array_ptr-1):
array[i] = array[i+1]
self.array_ptr = array_ptr - 1
return 0
cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil:
"""Remove the top (minimum) element from array.
Returns 0 if successful, -1 if nothing to remove."""
cdef intp_t array_ptr = self.array_ptr
cdef WeightedPQueueRecord* array = self.array_
cdef intp_t i
if array_ptr <= 0:
return -1
data[0] = array[0].data
weight[0] = array[0].weight
# shift the elements after the removed element
# to the left.
for i in range(0, array_ptr-1):
array[i] = array[i+1]
self.array_ptr = array_ptr - 1
return 0
cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil:
"""Write the top element from array to a pointer.
Returns 0 if successful, -1 if nothing to write."""
cdef WeightedPQueueRecord* array = self.array_
if self.array_ptr <= 0:
return -1
# Take first value
data[0] = array[0].data
weight[0] = array[0].weight
return 0
cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil:
"""Given an index between [0,self.current_capacity], access
the appropriate heap and return the requested weight"""
cdef WeightedPQueueRecord* array = self.array_
# get weight at index
return array[index].weight
cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil:
"""Given an index between [0,self.current_capacity], access
the appropriate heap and return the requested value"""
cdef WeightedPQueueRecord* array = self.array_
# get value at index
return array[index].data
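
The class above is essentially a sorted array of weighted records; for intuition, a pure-Python analogue of its push/peek/pop contract (our own sketch, not the Cython code):

import bisect

class PyWeightedPQueue:
    """Pure-Python analogue of WeightedPQueue: (data, weight) records
    kept sorted by data, with the minimum on the left."""

    def __init__(self):
        self._records = []                  # sorted list of (data, weight)

    def push(self, data, weight):
        # insort keeps the list sorted; ties on data break on weight here,
        # unlike the Cython version, which only compares data.
        bisect.insort(self._records, (data, weight))

    def pop(self):
        return self._records.pop(0) if self._records else None

    def peek(self):
        return self._records[0] if self._records else None

q = PyWeightedPQueue()
for d, w in [(3.0, 1.0), (1.0, 2.0), (2.0, 0.5)]:
    q.push(d, w)
print(q.peek())   # (1.0, 2.0) -- the minimum is always at index 0
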
# =============================================================================
# WeightedMedianCalculator data structure
# =============================================================================
cdef class WeightedMedianCalculator:
"""A class to handle calculation of the weighted median from streams of
data. To do so, it maintains a parameter ``k`` such that the sum of the
weights in the range [0,k) is greater than or equal to half of the total
weight. By minimizing the value of ``k`` that fulfills this constraint,
calculating the median is done by either taking the value of the sample
at index ``k-1`` of ``samples`` (samples[k-1].data) or the average of
the samples at index ``k-1`` and ``k`` of ``samples``
((samples[k-1] + samples[k]) / 2).
Attributes
----------
initial_capacity : intp_t
The initial capacity of the WeightedMedianCalculator.
samples : WeightedPQueue
Holds the samples (consisting of values and their weights) used in the
weighted median calculation.
total_weight : float64_t
The sum of the weights of items in ``samples``. Represents the total
weight of all samples used in the median calculation.
k : intp_t
Index used to calculate the median.
sum_w_0_k : float64_t
The sum of the weights from samples[0:k]. Used in the weighted
median calculation; minimizing the value of ``k`` such that
``sum_w_0_k`` >= ``total_weight / 2`` provides a mechanism for
calculating the median in constant time.
"""
def __cinit__(self, intp_t initial_capacity):
self.initial_capacity = initial_capacity
self.samples = WeightedPQueue(initial_capacity)
self.total_weight = 0
self.k = 0
self.sum_w_0_k = 0
cdef intp_t size(self) noexcept nogil:
"""Return the number of samples in the
WeightedMedianCalculator"""
return self.samples.size()
cdef int reset(self) except -1 nogil:
"""Reset the WeightedMedianCalculator to its state at construction
Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
# samples.reset (WeightedPQueue.reset) uses safe_realloc, hence
# except -1
self.samples.reset()
self.total_weight = 0
self.k = 0
self.sum_w_0_k = 0
return 0
cdef int push(self, float64_t data, float64_t weight) except -1 nogil:
"""Push a value and its associated weight to the WeightedMedianCalculator
Return -1 in case of failure to allocate memory (and raise MemoryError)
or 0 otherwise.
"""
cdef int return_value
cdef float64_t original_median = 0.0
if self.size() != 0:
original_median = self.get_median()
# samples.push (WeightedPQueue.push) uses safe_realloc, hence except -1
return_value = self.samples.push(data, weight)
self.update_median_parameters_post_push(data, weight,
original_median)
return return_value
cdef int update_median_parameters_post_push(
self, float64_t data, float64_t weight,
float64_t original_median) noexcept nogil:
"""Update the parameters used in the median calculation,
namely `k` and `sum_w_0_k` after an insertion"""
# trivial case of one element.
if self.size() == 1:
self.k = 1
self.total_weight = weight
self.sum_w_0_k = self.total_weight
return 0
# get the original weighted median
self.total_weight += weight
if data < original_median:
# inserting below the median, so increment k and
# then update self.sum_w_0_k accordingly by adding
# the weight that was added.
self.k += 1
# update sum_w_0_k by adding the weight added
self.sum_w_0_k += weight
# minimize k such that sum(W[0:k]) >= total_weight / 2
# minimum value of k is 1
while(self.k > 1 and ((self.sum_w_0_k -
self.samples.get_weight_from_index(self.k-1))
>= self.total_weight / 2.0)):
self.k -= 1
self.sum_w_0_k -= self.samples.get_weight_from_index(self.k)
return 0
if data >= original_median:
# inserting above or at the median
# minimize k such that sum(W[0:k]) >= total_weight / 2
while(self.k < self.samples.size() and
(self.sum_w_0_k < self.total_weight / 2.0)):
self.k += 1
self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1)
return 0
cdef int remove(self, float64_t data, float64_t weight) noexcept nogil:
"""Remove a value from the MedianHeap, removing it
from consideration in the median calculation
"""
cdef int return_value
cdef float64_t original_median = 0.0
if self.size() != 0:
original_median = self.get_median()
return_value = self.samples.remove(data, weight)
self.update_median_parameters_post_remove(data, weight,
original_median)
return return_value
cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil:
"""Pop a value from the MedianHeap, starting from the
left and moving to the right.
"""
cdef int return_value
cdef float64_t original_median = 0.0
if self.size() != 0:
original_median = self.get_median()
# no elements to pop
if self.samples.size() == 0:
return -1
return_value = self.samples.pop(data, weight)
self.update_median_parameters_post_remove(data[0],
weight[0],
original_median)
return return_value
cdef int update_median_parameters_post_remove(
self, float64_t data, float64_t weight,
float64_t original_median) noexcept nogil:
"""Update the parameters used in the median calculation,
namely `k` and `sum_w_0_k` after a removal"""
        # reset parameters because there are no elements left
if self.samples.size() == 0:
self.k = 0
self.total_weight = 0
self.sum_w_0_k = 0
return 0
# trivial case of one element.
if self.samples.size() == 1:
self.k = 1
self.total_weight -= weight
self.sum_w_0_k = self.total_weight
return 0
# get the current weighted median
self.total_weight -= weight
if data < original_median:
# removing below the median, so decrement k and
# then update self.sum_w_0_k accordingly by subtracting
# the removed weight
self.k -= 1
# update sum_w_0_k by removing the weight at index k
self.sum_w_0_k -= weight
# minimize k such that sum(W[0:k]) >= total_weight / 2
# by incrementing k and updating sum_w_0_k accordingly
# until the condition is met.
while(self.k < self.samples.size() and
(self.sum_w_0_k < self.total_weight / 2.0)):
self.k += 1
self.sum_w_0_k += self.samples.get_weight_from_index(self.k-1)
return 0
if data >= original_median:
# removing above the median
# minimize k such that sum(W[0:k]) >= total_weight / 2
while(self.k > 1 and ((self.sum_w_0_k -
self.samples.get_weight_from_index(self.k-1))
>= self.total_weight / 2.0)):
self.k -= 1
self.sum_w_0_k -= self.samples.get_weight_from_index(self.k)
return 0
cdef float64_t get_median(self) noexcept nogil:
"""Write the median to a pointer, taking into account
sample weights."""
if self.sum_w_0_k == (self.total_weight / 2.0):
# split median
return (self.samples.get_value_from_index(self.k) +
self.samples.get_value_from_index(self.k-1)) / 2.0
if self.sum_w_0_k > (self.total_weight / 2.0):
# whole median
return self.samples.get_value_from_index(self.k-1)
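
The `k`/`sum_w_0_k` bookkeeping above can be cross-checked against a direct computation. A pure-Python reference implementing the same median rule (our sketch; assumes strictly positive weights):

def weighted_median(values, weights):
    # Sort the samples, then find the smallest k with sum(w[:k]) >= total / 2,
    # mirroring the WeightedMedianCalculator invariant.
    pairs = sorted(zip(values, weights))
    total = sum(w for _, w in pairs)
    acc = 0.0
    for k, (value, w) in enumerate(pairs, start=1):
        acc += w
        if acc > total / 2.0:
            return value                              # "whole" median
        if acc == total / 2.0:
            return (value + pairs[k][0]) / 2.0        # "split" median

print(weighted_median([1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))  # 2.0
print(weighted_median([1.0, 2.0], [1.0, 1.0]))            # 1.5
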
def _any_isnan_axis0(const float32_t[:, :] X):
"""Same as np.any(np.isnan(X), axis=0)"""
cdef:
intp_t i, j
intp_t n_samples = X.shape[0]
intp_t n_features = X.shape[1]
unsigned char[::1] isnan_out = np.zeros(X.shape[1], dtype=np.bool_)
with nogil:
for i in range(n_samples):
for j in range(n_features):
if isnan_out[j]:
continue
                if isnan(X[i, j]):
                    # No early break: other columns of this row may
                    # also contain NaN and still need to be flagged.
                    isnan_out[j] = True
return np.asarray(isnan_out)
sklearn/tree/meson.build
@@ -0,0 +1,26 @@
tree_extension_metadata = {
'_tree':
{'sources': ['_tree.pyx'],
'override_options': ['cython_language=cpp', 'optimization=3']},
'_splitter':
{'sources': ['_splitter.pyx'],
'override_options': ['optimization=3']},
'_criterion':
{'sources': ['_criterion.pyx'],
'override_options': ['optimization=3']},
'_utils':
{'sources': ['_utils.pyx'],
'override_options': ['optimization=3']},
}
foreach ext_name, ext_dict : tree_extension_metadata
py.extension_module(
ext_name,
[ext_dict.get('sources'), utils_cython_tree],
dependencies: [np_dep],
override_options : ext_dict.get('override_options', []),
cython_args: cython_args,
subdir: 'sklearn/tree',
install: true
)
endforeach
sklearn/tree/tests/test_export.py
@@ -0,0 +1,546 @@
"""
Testing for export functions of decision trees (sklearn.tree.export).
"""
from io import StringIO
from re import finditer, search
from textwrap import dedent
import numpy as np
import pytest
from numpy.random import RandomState
from sklearn.base import is_classifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.exceptions import NotFittedError
from sklearn.tree import (
DecisionTreeClassifier,
DecisionTreeRegressor,
export_graphviz,
export_text,
plot_tree,
)
# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y = [-1, -1, -1, 1, 1, 1]
y2 = [[-1, 1], [-1, 1], [-1, 1], [1, 2], [1, 2], [1, 3]]
w = [1, 1, 1, 0.5, 0.5, 0.5]
y_degraded = [1, 1, 1, 1, 1, 1]
def test_graphviz_toy():
# Check correctness of export_graphviz
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
)
clf.fit(X, y)
# Test export code
contents1 = export_graphviz(clf, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test plot_options
contents1 = export_graphviz(
clf,
filled=True,
impurity=False,
proportion=True,
special_characters=True,
rounded=True,
out_file=None,
fontname="sans",
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled, rounded", color="black", '
'fontname="sans"] ;\n'
'edge [fontname="sans"] ;\n'
"0 [label=<x<SUB>0</SUB> &le; 0.0<br/>samples = 100.0%<br/>"
'value = [0.5, 0.5]>, fillcolor="#ffffff"] ;\n'
"1 [label=<samples = 50.0%<br/>value = [1.0, 0.0]>, "
'fillcolor="#e58139"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
"2 [label=<samples = 50.0%<br/>value = [0.0, 1.0]>, "
'fillcolor="#399de5"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test max_depth
contents1 = export_graphviz(clf, max_depth=0, class_names=True, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]\\nclass = y[0]"] ;\n'
'1 [label="(...)"] ;\n'
"0 -> 1 ;\n"
'2 [label="(...)"] ;\n'
"0 -> 2 ;\n"
"}"
)
assert contents1 == contents2
# Test max_depth with plot_options
contents1 = export_graphviz(
clf, max_depth=0, filled=True, out_file=None, node_ids=True
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled", color="black", '
'fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="node #0\\nx[0] <= 0.0\\ngini = 0.5\\n'
'samples = 6\\nvalue = [3, 3]", fillcolor="#ffffff"] ;\n'
'1 [label="(...)", fillcolor="#C0C0C0"] ;\n'
"0 -> 1 ;\n"
'2 [label="(...)", fillcolor="#C0C0C0"] ;\n'
"0 -> 2 ;\n"
"}"
)
assert contents1 == contents2
# Test multi-output with weighted samples
clf = DecisionTreeClassifier(
max_depth=2, min_samples_split=2, criterion="gini", random_state=2
)
clf = clf.fit(X, y2, sample_weight=w)
contents1 = export_graphviz(clf, filled=True, impurity=False, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled", color="black", '
'fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="x[0] <= 0.0\\nsamples = 6\\n'
"value = [[3.0, 1.5, 0.0]\\n"
'[3.0, 1.0, 0.5]]", fillcolor="#ffffff"] ;\n'
'1 [label="samples = 3\\nvalue = [[3, 0, 0]\\n'
'[3, 0, 0]]", fillcolor="#e58139"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="x[0] <= 1.5\\nsamples = 3\\n'
"value = [[0.0, 1.5, 0.0]\\n"
'[0.0, 1.0, 0.5]]", fillcolor="#f1bd97"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
'3 [label="samples = 2\\nvalue = [[0, 1, 0]\\n'
'[0, 1, 0]]", fillcolor="#e58139"] ;\n'
"2 -> 3 ;\n"
'4 [label="samples = 1\\nvalue = [[0.0, 0.5, 0.0]\\n'
'[0.0, 0.0, 0.5]]", fillcolor="#e58139"] ;\n'
"2 -> 4 ;\n"
"}"
)
assert contents1 == contents2
# Test regression output with plot_options
clf = DecisionTreeRegressor(
max_depth=3, min_samples_split=2, criterion="squared_error", random_state=2
)
clf.fit(X, y)
contents1 = export_graphviz(
clf,
filled=True,
leaves_parallel=True,
out_file=None,
rotate=True,
rounded=True,
fontname="sans",
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled, rounded", color="black", '
'fontname="sans"] ;\n'
"graph [ranksep=equally, splines=polyline] ;\n"
'edge [fontname="sans"] ;\n'
"rankdir=LR ;\n"
'0 [label="x[0] <= 0.0\\nsquared_error = 1.0\\nsamples = 6\\n'
'value = 0.0", fillcolor="#f2c09c"] ;\n'
'1 [label="squared_error = 0.0\\nsamples = 3\\'
'nvalue = -1.0", '
'fillcolor="#ffffff"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=-45, "
'headlabel="True"] ;\n'
'2 [label="squared_error = 0.0\\nsamples = 3\\nvalue = 1.0", '
'fillcolor="#e58139"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=45, "
'headlabel="False"] ;\n'
"{rank=same ; 0} ;\n"
"{rank=same ; 1; 2} ;\n"
"}"
)
assert contents1 == contents2
# Test classifier with degraded learning set
clf = DecisionTreeClassifier(max_depth=3)
clf.fit(X, y_degraded)
contents1 = export_graphviz(clf, filled=True, out_file=None)
contents2 = (
"digraph Tree {\n"
'node [shape=box, style="filled", color="black", '
'fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="gini = 0.0\\nsamples = 6\\nvalue = 6.0", '
'fillcolor="#ffffff"] ;\n'
"}"
    )
    assert contents1 == contents2
@pytest.mark.parametrize("constructor", [list, np.array])
def test_graphviz_feature_class_names_array_support(constructor):
# Check that export_graphviz treats feature names
# and class names correctly and supports arrays
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="gini", random_state=2
)
clf.fit(X, y)
# Test with feature_names
contents1 = export_graphviz(
clf, feature_names=constructor(["feature0", "feature1"]), out_file=None
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="feature0 <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
# Test with class_names
contents1 = export_graphviz(
clf, class_names=constructor(["yes", "no"]), out_file=None
)
contents2 = (
"digraph Tree {\n"
'node [shape=box, fontname="helvetica"] ;\n'
'edge [fontname="helvetica"] ;\n'
'0 [label="x[0] <= 0.0\\ngini = 0.5\\nsamples = 6\\n'
'value = [3, 3]\\nclass = yes"] ;\n'
'1 [label="gini = 0.0\\nsamples = 3\\nvalue = [3, 0]\\n'
'class = yes"] ;\n'
"0 -> 1 [labeldistance=2.5, labelangle=45, "
'headlabel="True"] ;\n'
'2 [label="gini = 0.0\\nsamples = 3\\nvalue = [0, 3]\\n'
'class = no"] ;\n'
"0 -> 2 [labeldistance=2.5, labelangle=-45, "
'headlabel="False"] ;\n'
"}"
)
assert contents1 == contents2
def test_graphviz_errors():
# Check for errors of export_graphviz
clf = DecisionTreeClassifier(max_depth=3, min_samples_split=2)
# Check not-fitted decision tree error
out = StringIO()
with pytest.raises(NotFittedError):
export_graphviz(clf, out)
clf.fit(X, y)
# Check if it errors when length of feature_names
# mismatches with number of features
message = "Length of feature_names, 1 does not match number of features, 2"
with pytest.raises(ValueError, match=message):
export_graphviz(clf, None, feature_names=["a"])
message = "Length of feature_names, 3 does not match number of features, 2"
with pytest.raises(ValueError, match=message):
export_graphviz(clf, None, feature_names=["a", "b", "c"])
# Check error when argument is not an estimator
message = "is not an estimator instance"
with pytest.raises(TypeError, match=message):
export_graphviz(clf.fit(X, y).tree_)
# Check class_names error
out = StringIO()
with pytest.raises(IndexError):
export_graphviz(clf, out, class_names=[])
def test_friedman_mse_in_graphviz():
clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
clf.fit(X, y)
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data)
clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
clf.fit(X, y)
for estimator in clf.estimators_:
export_graphviz(estimator[0], out_file=dot_data)
for finding in finditer(r"\[.*?samples.*?\]", dot_data.getvalue()):
assert "friedman_mse" in finding.group()
def test_precision():
rng_reg = RandomState(2)
rng_clf = RandomState(8)
for X, y, clf in zip(
(rng_reg.random_sample((5, 2)), rng_clf.random_sample((1000, 4))),
(rng_reg.random_sample((5,)), rng_clf.randint(2, size=(1000,))),
(
DecisionTreeRegressor(
criterion="friedman_mse", random_state=0, max_depth=1
),
DecisionTreeClassifier(max_depth=1, random_state=0),
),
):
clf.fit(X, y)
for precision in (4, 3):
dot_data = export_graphviz(
clf, out_file=None, precision=precision, proportion=True
)
            # With the current random state, the impurity and the threshold
            # are reported with exactly the requested precision, so their
            # digit counts are checked with strict equality. The "value"
            # entries may be reported with fewer digits, so only a
            # less-than-or-equal comparison is made for them.
# check value
for finding in finditer(r"value = \d+\.\d+", dot_data):
assert len(search(r"\.\d+", finding.group()).group()) <= precision + 1
# check impurity
if is_classifier(clf):
pattern = r"gini = \d+\.\d+"
else:
pattern = r"friedman_mse = \d+\.\d+"
for finding in finditer(pattern, dot_data):
assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
# check threshold
for finding in finditer(r"<= \d+\.\d+", dot_data):
assert len(search(r"\.\d+", finding.group()).group()) == precision + 1
def test_export_text_errors():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
err_msg = "feature_names must contain 2 elements, got 1"
with pytest.raises(ValueError, match=err_msg):
export_text(clf, feature_names=["a"])
err_msg = (
"When `class_names` is an array, it should contain as"
" many items as `decision_tree.classes_`. Got 1 while"
" the tree was fitted with 2 classes."
)
with pytest.raises(ValueError, match=err_msg):
export_text(clf, class_names=["a"])
def test_export_text():
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- class: -1
|--- feature_1 > 0.00
| |--- class: 1
"""
).lstrip()
assert export_text(clf) == expected_report
# testing that leaves at level 1 are not truncated
assert export_text(clf, max_depth=0) == expected_report
    # testing that the rest of the tree is not truncated
assert export_text(clf, max_depth=10) == expected_report
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- weights: [3.00, 0.00] class: -1
|--- feature_1 > 0.00
| |--- weights: [0.00, 3.00] class: 1
"""
).lstrip()
assert export_text(clf, show_weights=True) == expected_report
expected_report = dedent(
"""
|- feature_1 <= 0.00
| |- class: -1
|- feature_1 > 0.00
| |- class: 1
"""
).lstrip()
assert export_text(clf, spacing=1) == expected_report
X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
y_l = [-1, -1, -1, 1, 1, 1, 2]
clf = DecisionTreeClassifier(max_depth=4, random_state=0)
clf.fit(X_l, y_l)
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- class: -1
|--- feature_1 > 0.00
| |--- truncated branch of depth 2
"""
).lstrip()
assert export_text(clf, max_depth=0) == expected_report
X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X_mo, y_mo)
expected_report = dedent(
"""
|--- feature_1 <= 0.0
| |--- value: [-1.0, -1.0]
|--- feature_1 > 0.0
| |--- value: [1.0, 1.0]
"""
).lstrip()
assert export_text(reg, decimals=1) == expected_report
assert export_text(reg, decimals=1, show_weights=True) == expected_report
X_single = [[-2], [-1], [-1], [1], [1], [2]]
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(X_single, y_mo)
expected_report = dedent(
"""
|--- first <= 0.0
| |--- value: [-1.0, -1.0]
|--- first > 0.0
| |--- value: [1.0, 1.0]
"""
).lstrip()
assert export_text(reg, decimals=1, feature_names=["first"]) == expected_report
assert (
export_text(reg, decimals=1, show_weights=True, feature_names=["first"])
== expected_report
)
@pytest.mark.parametrize("constructor", [list, np.array])
def test_export_text_feature_class_names_array_support(constructor):
    # Check that export_text treats feature names
# and class names correctly and supports arrays
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(X, y)
expected_report = dedent(
"""
|--- b <= 0.00
| |--- class: -1
|--- b > 0.00
| |--- class: 1
"""
).lstrip()
assert export_text(clf, feature_names=constructor(["a", "b"])) == expected_report
expected_report = dedent(
"""
|--- feature_1 <= 0.00
| |--- class: cat
|--- feature_1 > 0.00
| |--- class: dog
"""
).lstrip()
assert export_text(clf, class_names=constructor(["cat", "dog"])) == expected_report
def test_plot_tree_entropy(pyplot):
# mostly smoke tests
    # Check correctness of plot_tree for criterion = entropy
clf = DecisionTreeClassifier(
max_depth=3, min_samples_split=2, criterion="entropy", random_state=2
)
clf.fit(X, y)
# Test export code
feature_names = ["first feat", "sepal_width"]
nodes = plot_tree(clf, feature_names=feature_names)
assert len(nodes) == 5
assert (
nodes[0].get_text()
== "first feat <= 0.0\nentropy = 1.0\nsamples = 6\nvalue = [3, 3]"
)
assert nodes[1].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "True "
assert nodes[3].get_text() == "entropy = 0.0\nsamples = 3\nvalue = [0, 3]"
assert nodes[4].get_text() == " False"
@pytest.mark.parametrize("fontsize", [None, 10, 20])
def test_plot_tree_gini(pyplot, fontsize):
# mostly smoke tests
    # Check correctness of plot_tree for criterion = gini
clf = DecisionTreeClassifier(
max_depth=3,
min_samples_split=2,
criterion="gini",
random_state=2,
)
clf.fit(X, y)
# Test export code
feature_names = ["first feat", "sepal_width"]
nodes = plot_tree(clf, feature_names=feature_names, fontsize=fontsize)
assert len(nodes) == 5
if fontsize is not None:
assert all(node.get_fontsize() == fontsize for node in nodes)
assert (
nodes[0].get_text()
== "first feat <= 0.0\ngini = 0.5\nsamples = 6\nvalue = [3, 3]"
)
assert nodes[1].get_text() == "gini = 0.0\nsamples = 3\nvalue = [3, 0]"
assert nodes[2].get_text() == "True "
assert nodes[3].get_text() == "gini = 0.0\nsamples = 3\nvalue = [0, 3]"
assert nodes[4].get_text() == " False"
def test_not_fitted_tree(pyplot):
    # Check that plotting a tree that is not fitted raises the correct error
clf = DecisionTreeRegressor()
with pytest.raises(NotFittedError):
plot_tree(clf)
sklearn/tree/tests/test_monotonic_tree.py
@@ -0,0 +1,508 @@
import numpy as np
import pytest
from sklearn.datasets import make_classification, make_regression
from sklearn.ensemble import (
ExtraTreesClassifier,
ExtraTreesRegressor,
RandomForestClassifier,
RandomForestRegressor,
)
from sklearn.tree import (
DecisionTreeClassifier,
DecisionTreeRegressor,
ExtraTreeClassifier,
ExtraTreeRegressor,
)
from sklearn.utils._testing import assert_allclose
from sklearn.utils.fixes import CSC_CONTAINERS
TREE_CLASSIFIER_CLASSES = [DecisionTreeClassifier, ExtraTreeClassifier]
TREE_REGRESSOR_CLASSES = [DecisionTreeRegressor, ExtraTreeRegressor]
TREE_BASED_CLASSIFIER_CLASSES = TREE_CLASSIFIER_CLASSES + [
RandomForestClassifier,
ExtraTreesClassifier,
]
TREE_BASED_REGRESSOR_CLASSES = TREE_REGRESSOR_CLASSES + [
RandomForestRegressor,
ExtraTreesRegressor,
]
@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("sparse_splitter", (True, False))
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_monotonic_constraints_classifications(
TreeClassifier,
depth_first_builder,
sparse_splitter,
global_random_seed,
csc_container,
):
n_samples = 1000
n_samples_train = 900
X, y = make_classification(
n_samples=n_samples,
n_classes=2,
n_features=5,
n_informative=5,
n_redundant=0,
random_state=global_random_seed,
)
X_train, y_train = X[:n_samples_train], y[:n_samples_train]
X_test, _ = X[n_samples_train:], y[n_samples_train:]
X_test_0incr, X_test_0decr = np.copy(X_test), np.copy(X_test)
X_test_1incr, X_test_1decr = np.copy(X_test), np.copy(X_test)
X_test_0incr[:, 0] += 10
X_test_0decr[:, 0] -= 10
X_test_1incr[:, 1] += 10
X_test_1decr[:, 1] -= 10
monotonic_cst = np.zeros(X.shape[1])
monotonic_cst[0] = 1
monotonic_cst[1] = -1
if depth_first_builder:
est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst)
else:
est = TreeClassifier(
max_depth=None,
monotonic_cst=monotonic_cst,
max_leaf_nodes=n_samples_train,
)
if hasattr(est, "random_state"):
est.set_params(**{"random_state": global_random_seed})
if hasattr(est, "n_estimators"):
est.set_params(**{"n_estimators": 5})
if sparse_splitter:
X_train = csc_container(X_train)
est.fit(X_train, y_train)
proba_test = est.predict_proba(X_test)
assert np.logical_and(
proba_test >= 0.0, proba_test <= 1.0
).all(), "Probability should always be in [0, 1] range."
assert_allclose(proba_test.sum(axis=1), 1.0)
    # The monotonic increase constraint applies to the positive class
assert np.all(est.predict_proba(X_test_0incr)[:, 1] >= proba_test[:, 1])
assert np.all(est.predict_proba(X_test_0decr)[:, 1] <= proba_test[:, 1])
    # The monotonic decrease constraint applies to the positive class
assert np.all(est.predict_proba(X_test_1incr)[:, 1] <= proba_test[:, 1])
assert np.all(est.predict_proba(X_test_1decr)[:, 1] >= proba_test[:, 1])
@pytest.mark.parametrize("TreeRegressor", TREE_BASED_REGRESSOR_CLASSES)
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("sparse_splitter", (True, False))
@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error"))
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
def test_monotonic_constraints_regressions(
TreeRegressor,
depth_first_builder,
sparse_splitter,
criterion,
global_random_seed,
csc_container,
):
n_samples = 1000
n_samples_train = 900
# Build a regression task using 5 informative features
X, y = make_regression(
n_samples=n_samples,
n_features=5,
n_informative=5,
random_state=global_random_seed,
)
train = np.arange(n_samples_train)
test = np.arange(n_samples_train, n_samples)
X_train = X[train]
y_train = y[train]
X_test = np.copy(X[test])
X_test_incr = np.copy(X_test)
X_test_decr = np.copy(X_test)
X_test_incr[:, 0] += 10
X_test_decr[:, 1] += 10
monotonic_cst = np.zeros(X.shape[1])
monotonic_cst[0] = 1
monotonic_cst[1] = -1
if depth_first_builder:
est = TreeRegressor(
max_depth=None,
monotonic_cst=monotonic_cst,
criterion=criterion,
)
else:
est = TreeRegressor(
max_depth=8,
monotonic_cst=monotonic_cst,
criterion=criterion,
max_leaf_nodes=n_samples_train,
)
if hasattr(est, "random_state"):
est.set_params(random_state=global_random_seed)
if hasattr(est, "n_estimators"):
est.set_params(**{"n_estimators": 5})
if sparse_splitter:
X_train = csc_container(X_train)
est.fit(X_train, y_train)
y = est.predict(X_test)
# Monotonic increase constraint
y_incr = est.predict(X_test_incr)
# y_incr should always be greater than y
assert np.all(y_incr >= y)
# Monotonic decrease constraint
y_decr = est.predict(X_test_decr)
# y_decr should always be lower than y
assert np.all(y_decr <= y)
@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
def test_multiclass_raises(TreeClassifier):
X, y = make_classification(
n_samples=100, n_features=5, n_classes=3, n_informative=3, random_state=0
)
y[0] = 0
monotonic_cst = np.zeros(X.shape[1])
monotonic_cst[0] = -1
monotonic_cst[1] = 1
est = TreeClassifier(max_depth=None, monotonic_cst=monotonic_cst, random_state=0)
msg = "Monotonicity constraints are not supported with multiclass classification"
with pytest.raises(ValueError, match=msg):
est.fit(X, y)
@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
def test_multiple_output_raises(TreeClassifier):
X = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
y = [[1, 0, 1, 0, 1], [1, 0, 1, 0, 1]]
est = TreeClassifier(
max_depth=None, monotonic_cst=np.array([-1, 1]), random_state=0
)
msg = "Monotonicity constraints are not supported with multiple output"
with pytest.raises(ValueError, match=msg):
est.fit(X, y)
@pytest.mark.parametrize(
"DecisionTreeEstimator", [DecisionTreeClassifier, DecisionTreeRegressor]
)
def test_missing_values_raises(DecisionTreeEstimator):
X, y = make_classification(
n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=0
)
X[0, 0] = np.nan
monotonic_cst = np.zeros(X.shape[1])
monotonic_cst[0] = 1
est = DecisionTreeEstimator(
max_depth=None, monotonic_cst=monotonic_cst, random_state=0
)
msg = "Input X contains NaN"
with pytest.raises(ValueError, match=msg):
est.fit(X, y)
@pytest.mark.parametrize("TreeClassifier", TREE_BASED_CLASSIFIER_CLASSES)
def test_bad_monotonic_cst_raises(TreeClassifier):
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
y = [1, 0, 1, 0, 1]
msg = "monotonic_cst has shape 3 but the input data X has 2 features."
est = TreeClassifier(
max_depth=None, monotonic_cst=np.array([-1, 1, 0]), random_state=0
)
with pytest.raises(ValueError, match=msg):
est.fit(X, y)
msg = "monotonic_cst must be None or an array-like of -1, 0 or 1."
est = TreeClassifier(
max_depth=None, monotonic_cst=np.array([-2, 2]), random_state=0
)
with pytest.raises(ValueError, match=msg):
est.fit(X, y)
est = TreeClassifier(
max_depth=None, monotonic_cst=np.array([-1, 0.8]), random_state=0
)
with pytest.raises(ValueError, match=msg + "(.*)0.8]"):
est.fit(X, y)
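

# Illustrative counterpart (hypothetical sketch, not collected by pytest):
# a well-formed monotonic_cst with entries drawn from {-1, 0, 1}, one per
# feature, fits a binary classifier without raising.
def _sketch_valid_monotonic_cst():
    X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    y = [1, 0, 1, 0, 1]
    DecisionTreeClassifier(monotonic_cst=np.array([-1, 1]), random_state=0).fit(X, y)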


def assert_1d_reg_tree_children_monotonic_bounded(tree_, monotonic_sign):
    """Check that children values at each split are ordered according to
    monotonic_sign, and that the inner grand-children values are bounded
    by the middle value of the children."""
values = tree_.value
for i in range(tree_.node_count):
if tree_.children_left[i] > i and tree_.children_right[i] > i:
# Check monotonicity on children
i_left = tree_.children_left[i]
i_right = tree_.children_right[i]
if monotonic_sign == 1:
assert values[i_left] <= values[i_right]
elif monotonic_sign == -1:
assert values[i_left] >= values[i_right]
val_middle = (values[i_left] + values[i_right]) / 2
# Check bounds on grand-children, filtering out leaf nodes
if tree_.feature[i_left] >= 0:
i_left_right = tree_.children_right[i_left]
if monotonic_sign == 1:
assert values[i_left_right] <= val_middle
elif monotonic_sign == -1:
assert values[i_left_right] >= val_middle
if tree_.feature[i_right] >= 0:
i_right_left = tree_.children_left[i_right]
if monotonic_sign == 1:
assert val_middle <= values[i_right_left]
elif monotonic_sign == -1:
assert val_middle >= values[i_right_left]
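

# Positive-case sketch (illustrative, not collected by pytest): on a tree
# fitted with a +1 constraint, the assertion helper above passes silently.
def _sketch_1d_children_monotonic_bounded_passes():
    rng = np.random.RandomState(0)
    X = rng.rand(100, 1)
    y = rng.rand(100)
    reg = DecisionTreeRegressor(monotonic_cst=[1], random_state=0).fit(X, y)
    assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, 1)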


def test_assert_1d_reg_tree_children_monotonic_bounded():
    # Check that assert_1d_reg_tree_children_monotonic_bounded detects trees
    # whose node values violate either monotonicity direction.
X = np.linspace(-1, 1, 7).reshape(-1, 1)
y = np.sin(2 * np.pi * X.ravel())
reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y)
with pytest.raises(AssertionError):
assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, 1)
with pytest.raises(AssertionError):
assert_1d_reg_tree_children_monotonic_bounded(reg.tree_, -1)


def assert_1d_reg_monotonic(clf, monotonic_sign, min_x, max_x, n_steps):
    """Check that predictions over a regular 1d grid follow monotonic_sign."""
X_grid = np.linspace(min_x, max_x, n_steps).reshape(-1, 1)
y_pred_grid = clf.predict(X_grid)
if monotonic_sign == 1:
assert (np.diff(y_pred_grid) >= 0.0).all()
elif monotonic_sign == -1:
assert (np.diff(y_pred_grid) <= 0.0).all()
@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES)
def test_1d_opposite_monotonicity_cst_data(TreeRegressor):
# Check that positive monotonic data with negative monotonic constraint
# yield constant predictions, equal to the average of target values
X = np.linspace(-2, 2, 10).reshape(-1, 1)
y = X.ravel()
clf = TreeRegressor(monotonic_cst=[-1])
clf.fit(X, y)
assert clf.tree_.node_count == 1
assert clf.tree_.value[0] == 0.0
# Swap monotonicity
clf = TreeRegressor(monotonic_cst=[1])
clf.fit(X, -y)
assert clf.tree_.node_count == 1
assert clf.tree_.value[0] == 0.0
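

# Worked numeric sketch (illustrative only): with strictly increasing data
# and a -1 constraint, no split can satisfy the constraint, so the tree
# collapses to a single leaf predicting the target mean.
def _sketch_opposite_constraint_predicts_mean():
    X = np.linspace(0.0, 1.0, 10).reshape(-1, 1)
    y = X.ravel() + 1.0  # increasing target, mean(y) == 1.5
    reg = DecisionTreeRegressor(monotonic_cst=[-1], random_state=0).fit(X, y)
    assert reg.tree_.node_count == 1
    assert np.isclose(reg.tree_.value[0], np.mean(y))

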
@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES)
@pytest.mark.parametrize("monotonic_sign", (-1, 1))
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error"))
def test_1d_tree_nodes_values(
TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed
):
    # Adapted from test_nodes_values in test_monotonic_constraints.py
    # in sklearn.ensemble._hist_gradient_boosting.
    # Build a single tree with only one feature, and make sure the node
    # values respect the monotonicity constraints.
    # Considering the following tree with a monotonic +1 constraint, we
    # should have:
    #
    #            root
    #           /    \
    #          a      b
    #         / \    / \
    #        c   d  e   f
    #
    # a <= root <= b
    # c <= d <= (a + b) / 2 <= e <= f
rng = np.random.RandomState(global_random_seed)
n_samples = 1000
n_features = 1
X = rng.rand(n_samples, n_features)
y = rng.rand(n_samples)
if depth_first_builder:
# No max_leaf_nodes, default depth first tree builder
clf = TreeRegressor(
monotonic_cst=[monotonic_sign],
criterion=criterion,
random_state=global_random_seed,
)
else:
# max_leaf_nodes triggers best first tree builder
clf = TreeRegressor(
monotonic_cst=[monotonic_sign],
max_leaf_nodes=n_samples,
criterion=criterion,
random_state=global_random_seed,
)
clf.fit(X, y)
assert_1d_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_sign)
assert_1d_reg_monotonic(clf, monotonic_sign, np.min(X), np.max(X), 100)


def assert_nd_reg_tree_children_monotonic_bounded(tree_, monotonic_cst):
    """Walk the tree and check that every node value lies within the bounds
    induced by the monotonicity constraints of the splits above it."""
upper_bound = np.full(tree_.node_count, np.inf)
lower_bound = np.full(tree_.node_count, -np.inf)
for i in range(tree_.node_count):
feature = tree_.feature[i]
node_value = tree_.value[i][0][0] # unpack value from nx1x1 array
        # While building the tree, the computed middle value can differ
        # slightly from the average of the sibling values, because
        # sum_right / weighted_n_right is not exactly equal to the value of
        # the right sibling. This can make clipping off by numerical noise,
        # which we absorb by comparing at float32 precision.
assert np.float32(node_value) <= np.float32(upper_bound[i])
assert np.float32(node_value) >= np.float32(lower_bound[i])
if feature < 0:
# Leaf: nothing to do
continue
# Split node: check and update bounds for the children.
i_left = tree_.children_left[i]
i_right = tree_.children_right[i]
# unpack value from nx1x1 array
middle_value = (tree_.value[i_left][0][0] + tree_.value[i_right][0][0]) / 2
if monotonic_cst[feature] == 0:
# Feature without monotonicity constraint: propagate bounds
# down the tree to both children.
# Otherwise, with 2 features and a monotonic increase constraint
# (encoded by +1) on feature 0, the following tree can be accepted,
# although it does not respect the monotonic increase constraint:
#
            #                       X[0] <= 0
            #                      value = 100
            #                     /           \
            #           X[0] <= -1             X[1] <= 0
            #           value = 50             value = 150
            #          /          \           /          \
            #       leaf         leaf      leaf          leaf
            #    value = 25   value = 75   value = 50    value = 250
lower_bound[i_left] = lower_bound[i]
upper_bound[i_left] = upper_bound[i]
lower_bound[i_right] = lower_bound[i]
upper_bound[i_right] = upper_bound[i]
elif monotonic_cst[feature] == 1:
# Feature with constraint: check monotonicity
assert tree_.value[i_left] <= tree_.value[i_right]
# Propagate bounds down the tree to both children.
lower_bound[i_left] = lower_bound[i]
upper_bound[i_left] = middle_value
lower_bound[i_right] = middle_value
upper_bound[i_right] = upper_bound[i]
elif monotonic_cst[feature] == -1:
# Feature with constraint: check monotonicity
assert tree_.value[i_left] >= tree_.value[i_right]
# Update and propagate bounds down the tree to both children.
lower_bound[i_left] = middle_value
upper_bound[i_left] = upper_bound[i]
lower_bound[i_right] = lower_bound[i]
upper_bound[i_right] = middle_value
else: # pragma: no cover
raise ValueError(f"monotonic_cst[{feature}]={monotonic_cst[feature]}")
def test_assert_nd_reg_tree_children_monotonic_bounded():
# Check that assert_nd_reg_tree_children_monotonic_bounded can detect
# non-monotonic tree predictions.
X = np.linspace(0, 2 * np.pi, 30).reshape(-1, 1)
y = np.sin(X).ravel()
reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y)
with pytest.raises(AssertionError):
assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1])
with pytest.raises(AssertionError):
assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1])
assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [0])
# Check that assert_nd_reg_tree_children_monotonic_bounded raises
# when the data (and therefore the model) is naturally monotonic in the
# opposite direction.
X = np.linspace(-5, 5, 5).reshape(-1, 1)
y = X.ravel() ** 3
reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, y)
with pytest.raises(AssertionError):
assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [-1])
# For completeness, check that the converse holds when swapping the sign.
reg = DecisionTreeRegressor(max_depth=None, random_state=0).fit(X, -y)
with pytest.raises(AssertionError):
assert_nd_reg_tree_children_monotonic_bounded(reg.tree_, [1])
@pytest.mark.parametrize("TreeRegressor", TREE_REGRESSOR_CLASSES)
@pytest.mark.parametrize("monotonic_sign", (-1, 1))
@pytest.mark.parametrize("depth_first_builder", (True, False))
@pytest.mark.parametrize("criterion", ("absolute_error", "squared_error"))
def test_nd_tree_nodes_values(
TreeRegressor, monotonic_sign, depth_first_builder, criterion, global_random_seed
):
    # Build a tree with several features, and make sure the node values
    # respect the monotonicity constraints.
    # Considering the following tree with a monotonic increase constraint on
    # X[0], we should have:
    #
    #               root
    #             X[0]<=t
    #            /       \
    #           a         b
    #        X[0]<=u   X[1]<=v
    #        /     \   /     \
    #       c       d e       f
    #
    # i)   a <= root <= b
    # ii)  c <= a <= d <= (a + b) / 2
    # iii) (a + b) / 2 <= min(e, f)
    # For iii) we check that each node value is within the proper lower and
    # upper bounds.
rng = np.random.RandomState(global_random_seed)
n_samples = 1000
n_features = 2
monotonic_cst = [monotonic_sign, 0]
X = rng.rand(n_samples, n_features)
y = rng.rand(n_samples)
if depth_first_builder:
# No max_leaf_nodes, default depth first tree builder
clf = TreeRegressor(
monotonic_cst=monotonic_cst,
criterion=criterion,
random_state=global_random_seed,
)
else:
# max_leaf_nodes triggers best first tree builder
clf = TreeRegressor(
monotonic_cst=monotonic_cst,
max_leaf_nodes=n_samples,
criterion=criterion,
random_state=global_random_seed,
)
clf.fit(X, y)
assert_nd_reg_tree_children_monotonic_bounded(clf.tree_, monotonic_cst)
@@ -0,0 +1,49 @@
import numpy as np
import pytest
from sklearn.tree._reingold_tilford import Tree, buchheim


# Tree takes a label, a node id, then any number of child subtrees.
simple_tree = Tree("", 0, Tree("", 1), Tree("", 2))
bigger_tree = Tree(
"",
0,
Tree(
"",
1,
Tree("", 3),
Tree("", 4, Tree("", 7), Tree("", 8)),
),
Tree("", 2, Tree("", 5), Tree("", 6)),
)
@pytest.mark.parametrize("tree, n_nodes", [(simple_tree, 3), (bigger_tree, 9)])
def test_buchheim(tree, n_nodes):
def walk_tree(draw_tree):
res = [(draw_tree.x, draw_tree.y)]
for child in draw_tree.children:
            # children sit exactly one level below their parent:
assert child.y == draw_tree.y + 1
res.extend(walk_tree(child))
        if draw_tree.children:
            # these trees are always binary, so a parent
            # should be centered above its two children
assert (
draw_tree.x == (draw_tree.children[0].x + draw_tree.children[1].x) / 2
)
return res

    layout = buchheim(tree)
coordinates = walk_tree(layout)
assert len(coordinates) == n_nodes
    # Test that x values are unique per depth / level.
    # A defaultdict grouping would be faster, but this loop is clearer.
depth = 0
while True:
x_at_this_depth = [node[0] for node in coordinates if node[1] == depth]
if not x_at_this_depth:
            # past the deepest level: all leaves have been visited
break
assert len(np.unique(x_at_this_depth)) == len(x_at_this_depth)
depth += 1
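

# Minimal usage sketch (illustrative only): buchheim lays out a tree so that
# every node carries (x, y) coordinates, children sit one level below their
# parent, and a parent is centered above its two children.
def _sketch_buchheim_layout():
    layout = buchheim(simple_tree)
    left, right = layout.children
    assert left.y == layout.y + 1 and right.y == layout.y + 1
    assert layout.x == (left.x + right.x) / 2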
File diff suppressed because it is too large