Source code for bartpy.tree

from typing import List

import numpy as np

from bartpy.mutation import TreeMutation
from bartpy.node import TreeNode, LeafNode, DecisionNode


class Tree:
    """
    An encapsulation of the structure of a single decision tree

    Contains no logic, but keeps track of four different kinds of nodes within the tree:
      - leaf nodes
      - decision nodes
      - splittable leaf nodes
      - prunable decision nodes

    Parameters
    ----------
    nodes: List[TreeNode]
        All nodes contained in the tree, i.e. decision and leaf nodes
    """

    def __init__(self, nodes: List[TreeNode]):
        self._nodes = nodes
        self.cache_up_to_date = False
        self._prediction = np.zeros_like(self._nodes[0]._split._data._y)

    @property
    def nodes(self) -> List[TreeNode]:
        """
        List of all nodes contained in the tree
        """
        return self._nodes

    @property
    def leaf_nodes(self) -> List[LeafNode]:
        """
        List of all of the leaf nodes in the tree
        """
        return [x for x in self._nodes if x.is_leaf_node()]

    @property
    def splittable_leaf_nodes(self) -> List[LeafNode]:
        """
        List of all leaf nodes in the tree which can be split in a non-degenerate way,
        i.e. not all rows of the covariate matrix are duplicates
        """
        return [x for x in self.leaf_nodes if x.is_splittable()]

    @property
    def decision_nodes(self) -> List[DecisionNode]:
        """
        List of decision nodes in the tree.
        Decision nodes are internal split nodes, i.e. not leaf nodes
        """
        return [x for x in self._nodes if x.is_decision_node()]

    @property
    def prunable_decision_nodes(self) -> List[DecisionNode]:
        """
        List of decision nodes in the tree that are suitable for pruning,
        in particular decision nodes whose children are both leaf nodes
        """
        return [x for x in self.decision_nodes if x.is_prunable()]
    def update_y(self, y: np.ndarray) -> None:
        """
        Update the cached value of the target array in all nodes

        Used to pass in the residuals from the sum of all of the other trees
        """
        self.cache_up_to_date = False
        for node in self.nodes:
            node.split.update_y(y)
    def predict(self) -> np.ndarray:
        """
        Generate a set of predictions with the same dimensionality as the target array

        Note that the prediction is from one tree, so it represents only (1 / number_of_trees) of the target
        """
        if self.cache_up_to_date:
            return self._prediction
        for leaf in self.leaf_nodes:
            self._prediction[leaf.split.condition()] = leaf.predict()
        self.cache_up_to_date = True
        return self._prediction
    def out_of_sample_predict(self, X) -> np.ndarray:
        """
        Prediction for a covariate matrix not used for training

        Note that this is quite slow

        Parameters
        ----------
        X: pd.DataFrame
            Covariates to predict for

        Returns
        -------
        np.ndarray
        """
        prediction = np.array([0.] * len(X))
        for leaf in self.leaf_nodes:
            prediction[leaf.split.out_of_sample_condition(X)] = leaf.predict()
        return prediction
    def remove_node(self, node: TreeNode) -> None:
        """
        Remove a single node from the tree

        Note that this is non-recursive: it drops only the node itself, not any of its children
        """
        self._nodes.remove(node)
    def add_node(self, node: TreeNode) -> None:
        """
        Add a node to the tree

        Note that this is non-recursive: it adds only the node itself, not any of its children
        """
        self._nodes.append(node)
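
# --- Usage sketch (illustrative; not part of the original module) ----------
# A minimal sketch of the caching contract described in the docstrings above.
# It assumes `nodes` is a list of already-constructed TreeNode objects (built
# elsewhere in bartpy), `residuals` is a numpy array shaped like the training
# target, and `X_new` is a held-out covariate matrix; all three names are
# placeholders, not part of this module.

def _example_tree_usage(nodes: List[TreeNode], residuals: np.ndarray, X_new) -> np.ndarray:
    tree = Tree(nodes)
    tree.update_y(residuals)              # invalidates the cached prediction
    in_sample = tree.predict()            # recomputes and caches the leaf predictions
    assert tree.predict() is in_sample    # second call is served from the cache
    return tree.out_of_sample_predict(X_new)  # recomputed on every call, never cached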
def mutate(tree: Tree, mutation: TreeMutation) -> None:
    """
    Apply a change to the structure of the tree

    Modifies not only the tree, but also the links between the TreeNodes

    Parameters
    ----------
    tree: Tree
        The tree to mutate
    mutation: TreeMutation
        The mutation to apply to the tree
    """
    tree.cache_up_to_date = False

    if mutation.kind == "prune":
        tree.remove_node(mutation.existing_node)
        tree.remove_node(mutation.existing_node.left_child)
        tree.remove_node(mutation.existing_node.right_child)
        tree.add_node(mutation.updated_node)

    if mutation.kind == "grow":
        tree.remove_node(mutation.existing_node)
        tree.add_node(mutation.updated_node.left_child)
        tree.add_node(mutation.updated_node.right_child)
        tree.add_node(mutation.updated_node)

    for node in tree.nodes:
        if node.right_child == mutation.existing_node:
            node._right_child = mutation.updated_node
        if node.left_child == mutation.existing_node:
            node._left_child = mutation.updated_node
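
# --- Usage sketch (illustrative; not part of the original module) ----------
# A minimal sketch of how a sampler might apply a proposed structural change,
# assuming `proposal` is a TreeMutation produced elsewhere (e.g. by a grow or
# prune proposer) with kind "grow" or "prune" and its existing_node and
# updated_node already constructed; those details live outside this module.

def _example_apply_mutation(tree: Tree, proposal: TreeMutation) -> np.ndarray:
    mutate(tree, proposal)   # rewires the node list and parent links, and drops the cache
    return tree.predict()    # prediction is recomputed to reflect the new structure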