文章目录
代码
运行结果

代码
# Decision trees used for curve fitting (regression).
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor


def main():
    """Fit decision-tree regressors to noisy sine samples and plot the fits.

    Two figures are shown one after the other:

    1. A single tree with ``max_depth=9`` plotted against the training samples.
    2. Trees with depths 2, 4, 6, 8 and 10 overlaid, to compare how the depth
       limit changes the fitted curve.

    The training targets and inputs are also printed to stdout.
    """
    # Build the data: N points drawn uniformly from [-3, 3), sorted so the
    # samples are in ascending x order; x has shape (N,) at this point.
    N = 100
    x = np.random.rand(N) * 6 - 3
    x.sort()
    # Targets are sin(x) plus Gaussian noise scaled by 0.05.
    y = np.sin(x) + np.random.randn(N) * 0.05
    print(y)
    # Reshape (this is not a transpose) into a column: N samples, each with a
    # single feature, which is the 2-D layout fit()/predict() expect.
    x = x.reshape(-1, 1)
    print(x)

    # Evenly spaced evaluation grid, shared by both figures.
    x_test = np.linspace(-3, 3, 50).reshape(-1, 1)

    # Decision-tree regressor. The criterion is left at its default, which is
    # squared error in every scikit-learn release; the old spelling 'mse' was
    # removed in scikit-learn 1.2 and raises there.
    dt = DecisionTreeRegressor(max_depth=9)
    dt.fit(x, y)
    y_hat = dt.predict(x_test)
    plt.plot(x, y, 'r*', ms=10, label='Actual')
    plt.plot(x_test, y_hat, 'g-', linewidth=2, label='Predict')
    plt.legend(loc='upper left')
    plt.grid(True)
    plt.show()

    # Compare the effect of the tree depth.
    depth = [2, 4, 6, 8, 10]
    # One matplotlib colour code per depth.
    clr = 'rgbmy'
    dtr = DecisionTreeRegressor()
    plt.plot(x, y, 'ko', ms=6, label='Actual')
    for d, c in zip(depth, clr):
        # Reuse the same estimator; only the depth limit changes per fit.
        dtr.set_params(max_depth=d)
        dtr.fit(x, y)
        y_hat = dtr.predict(x_test)
        plt.plot(x_test, y_hat, '-', color=c, linewidth=2, label=f'Depth={d}')
    plt.legend(loc='upper left')
    # The flag is passed positionally because its keyword name changed from
    # `b` to `visible` across Matplotlib releases.
    plt.grid(True)
    plt.show()


if __name__ == "__main__":
    main()
运行结果
(100,)[-0.13747212 -0.32696352 -0.34299033 -0.37734289 -0.30216829 -0.41908633-0.42649759 -0.55874875 -0.47470554 -0.50349372 -0.60084058 -0.72667652-0.88731673 -0.85007184 -0.80980603 -0.89046954 -0.92967645 -1.01708456-0.96413472 -1.00831613 -1.06009149 -0.98629175 -0.99021064 -0.88084281-0.90996548 -0.89476142 -0.80952269 -0.83540464 -0.76614234 -0.75365537-0.51213752 -0.53558931 -0.5158306 -0.51753766 -0.47760662 -0.49621367-0.35078086 -0.4007496 -0.37787176 -0.35708106 -0.33543894 0.05607983-0.04710956 0.02358386 0.13753866 0.22134074 0.36428241 0.381515420.42788242 0.47056583 0.47299773 0.57728474 0.69424008 0.688668460.74362813 0.85661517 0.79570145 0.72801613 0.83298817 0.913787560.92111679 1.01043268 0.96942097 0.989228 0.97144073 0.95990.90630972 0.94775525 1.00992384 1.00577511 1.0092611 1.066418451.01056367 0.92489214 0.99751525 0.9716967 0.90643779 0.934102050.90237971 0.93908154 0.88156985 0.84080906 0.81336031 0.811845130.77923751 0.71039144 0.65860142 0.68686109 0.66221666 0.467245690.49525938 0.33146802 0.26010888 0.33738618 0.2700388 0.251141230.25704015 0.16070012 0.10970704 0.24002726][[-2.98291781][-2.84917313][-2.83638979][-2.7860988 ][-2.77855502][-2.71508228][-2.69430568][-2.66073423][-2.61173817][-2.58799644][-2.44903871][-2.25958656][-2.23913379][-2.23306881][-2.22936331][-2.10650054][-1.8938904 ][-1.7209747 ][-1.66178009][-1.41848746][-1.30166897][-1.23206966][-1.21439096][-1.06802454][-1.0338541 ][-0.98878396][-0.97796016][-0.95109667][-0.87134449][-0.80900001][-0.56761021][-0.54240145][-0.53489081][-0.44861858][-0.43542907][-0.42543931][-0.41576859][-0.37648098][-0.36547507][-0.35646849][-0.33521726][ 0.03031728][ 0.03678361][ 0.08126219][ 0.11222954][ 0.18367956][ 0.36263569][ 0.37936774][ 0.39229896][ 0.40675994][ 0.57247375][ 0.59730802][ 0.75540724][ 0.80358303][ 0.8803318 ][ 0.92195617][ 0.92913085][ 0.98071443][ 1.08071656][ 1.1088603 ][ 1.14930987][ 1.32788011][ 1.34569531][ 1.36755187][ 1.40643578][ 1.43714342][ 1.44562189][ 
1.51583702][ 1.53488103][ 1.58985047][ 1.6181127 ][ 1.641521 ][ 1.65421212][ 1.6710831 ][ 1.68895352][ 1.76660029][ 1.84145428][ 1.88972944][ 1.96540222][ 1.98953008][ 1.9968826 ][ 2.02694326][ 2.09358637][ 2.22188409][ 2.24214962][ 2.38443801][ 2.38563411][ 2.39967366][ 2.45955993][ 2.59606735][ 2.66426555][ 2.7828691 ][ 2.78397954][ 2.83419346][ 2.84315114][ 2.89945669][ 2.90104871][ 2.97192459][ 2.97521555][ 2.99002727]]