实现最小二乘回归树

实现gbdt或者xgboost 之前,先完成回归树,和先前的ID3算法基本流程一致,所以很简单:

cart 回归树

最小二乘回归树生成算法

思路——递归,参考《统计学习方法》李航

  1. 结束条件:

    • 没有特征了
    • 没有样本了(说明上一步归入了同一类——把父节点标记为叶节点)
    • 所有样本的y都是一样的(说明上一步归入了同一类——把父节点标记为叶节点)
  2. 遍历

    • 遍历特征集合 + 遍历所有可能的阈值
    • 计算切分后落在左右的离差平方和之和
    • 找到使得“方差之和”最小的特征、阈值、左右均值/预测值
  3. 对于左右两部分数据集递归调用算法

# 使用boston 房价数据
from keras.datasets import boston_housing
(train_x, train_y), (test_x, test_y) = boston_housing.load_data()
train_x.shape, test_x.shape

## 将代码块运行结果全部输出,而不是只输出最后的,适用于全文
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"    
((404, 13), (102, 13))
train_x.shape
(404, 13)

1.遍历选取最优分割方式

输出:分割特征、阈值

import pandas as pd
import numpy as np


def square_loss(x):
    xx  = x - np.mean(x)
    return np.dot(xx, xx.T)


def argmin_col_val(data_x, data_y):
    p = data_x.shape[1]
    n = data_x.shape[0]
    if n <= 5:
        return None, None, 0
        
    
    # 1. 遍历所有特征
    for i in range(p):
        col0 = [x[i] for x in data_x]
        values = np.unique(col0)
        feature_values = [v for v in values if (not v==np.min(values)) and \
                          (not v==np.max(values))]
        square_parent = square_loss(data_y) 
        square_new = square_parent
        
        # 2. 遍历该特征所有取值
        for v in feature_values:
            col1 = [data_x[x][i] >= v for x in range(n)]
            col2 = [data_x[x][i] < v for x in range(n)]
            tmp = square_loss(data_y[col1]) + square_loss(data_y[col2])
            # 找到平方和最小的分割方式
            if tmp <= square_new:
                square_new = tmp
                min_col = i
                min_value = v
                mean_right  = np.mean(data_y[col1])
                mean_left = np.mean(data_y[col2])
    # 输出最优分割特征、阈值、左右均值(预测)、分割前后平方和的变化量
    return min_col, min_value, square_parent - square_new
    
    
# 进行第一次划分:
col_1, value_1, square_change = argmin_col_val(train_x, train_y)
print(col_1, value_1)
12 9.64

2.按照指定方式分割后的数据集

# 按照指定方式分割,输出分割后的data
def split_data(data_x, data_y, min_col, min_value):
    col1 = [x[min_col] >= min_value for x in data_x]
    col2 = [x[min_col] < min_value for x in data_x]
    data_x = np.array([np.delete(x, min_col) for x in data_x])
    subset_left_x = data_x[col1]
    subset_right_x = data_x[col2]
    subset_left_y = data_y[col1]
    subset_right_y = data_y[col2]
    return subset_left_x, subset_right_x, subset_left_y,subset_right_y

3.定义节点

class Node:                       # 定义树节点  
    def __init__(self, name):
        self.name = name
        self.connections = {}    
            
    def connect(self, label, node):   # 节点对应多个子节点
        self.connections[label] = node 

4.递归生成树

这我直接设置最小square_change为300(以免过拟合)懒得加参数了...

def bulid_cart(train_x, train_y, parent_node, parent_connection_label):
    min_col, min_val, square_change = argmin_col_val(train_x, train_y)
    # 1.结束条件
    if (not min_col)  or  (square_change < 300): 
        node = Node(np.mean(train_y))
        parent_node.connect(parent_connection_label, node)   # 设置为叶节点
        return

    #print("最优分割列{} 阈值{} 损失降低{}".format(min_col, min_val, square_change))
    # 2.生成新节点
    node = Node(min_col)       # 生成新节点
    parent_node.connect(parent_connection_label, node)
    
    # 3.对左右递归
    subset_left_x, subset_right_x, subset_left_y, subset_right_y = split_data(train_x, train_y, min_col, min_val)
    bulid_cart(subset_left_x, subset_left_y, node, "列"+str(min_col) +"<" +str(min_val) )
    bulid_cart(subset_right_x, subset_right_y, node,  "列"+str(min_col) +">=" +str(min_val))
    

root = Node('root')
bulid_cart(train_x, train_y, root, "")  

5.递归打印树

def print_tree(node, tabs):
    print(str(tabs) + str(node.name))        
    for connection, child_node in node.connections.items():
        print(str(tabs) + "\t" + "(" + str(connection) + ")")
        print_tree(child_node, str(tabs) + "\t\t") 

print_tree(root,"")
root
	()
		12
			(列12<9.64)
				11
					(列11<288.99)
						10
							(列10<19.7)
								16.116190476190475
							(列10>=19.7)
								20.366666666666667
					(列11>=288.99)
						13.183333333333332
			(列12>=9.64)
				11
					(列11<395.58)
						10
							(列10<19.6)
								21.5
							(列10>=19.6)
								27.7452380952381
					(列11>=395.58)
						10
							(列10<14.8)
								9
									(列9<430.0)
										38.75555555555556
									(列9>=430.0)
										8
											(列8<7.0)
												7
													(列7<4.148)
														26.928571428571427
													(列7>=4.148)
														38.230000000000004
											(列8>=7.0)
												27.894915254237286
							(列10>=14.8)
								9
									(列9<264.0)
										38.306250000000006
									(列9>=264.0)
										49.625