In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import ot


import gc # memory cleaning


import os
import psutil
process = psutil.Process(os.getpid())
print("Memory usage:", process.memory_info().rss/1024/1024,"MB") 
Memory usage: 108.40625 MB
In [2]:
# Wasserstein distance between the Y-values of the two children induced by a candidate split theta_i
def calculate_W2_difference(A, Mtry, theta_i, X, Y, p = 2):
    Y_left = []
    Y_right = []
    split_direction, split_value = theta_i
    for i in A:
        if X[i,split_direction] < split_value:
            Y_left.append(Y[i])
        else:
            Y_right.append(Y[i])
    return ot.wasserstein_1d(np.asarray(Y_left), np.asarray(Y_right), p = p)

# split the index set A into left/right children according to the chosen split theta_star
def getALAR(A, theta_star, X):
    AL = []
    AR = []
    split_direction, split_value = theta_star
    for i in A:
        if X[i,split_direction] < split_value:
            AL.append(i)
        else:
            AR.append(i)
    return AL, AR
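
The split criterion relies on POT's 1D Wasserstein distance between the responses falling on each side of a candidate split. Below is a quick sanity check of calculate_W2_difference on hypothetical toy values (not part of the experiment; the exact numbers depend on the POT version's convention, but the first value should exceed the second):

In [ ]:
# toy sanity check of the split criterion (illustrative values only)
X_toy = np.array([[0.1], [0.2], [0.8], [0.9]])
Y_toy = np.array([0., 0.1, 1., 1.1])
A_toy = [0, 1, 2, 3]
# splitting at 0.5 separates the two Y-clusters: large distance expected
print(calculate_W2_difference(A_toy, np.array([0]), [0, 0.5], X_toy, Y_toy, p = 2))
# splitting at 0.15 mixes them: smaller distance expected
print(calculate_W2_difference(A_toy, np.array([0]), [0, 0.15], X_toy, Y_toy, p = 2))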
In [3]:
# W2_split: over the candidate directions Mtry, pick the split maximizing the
# Wasserstein distance between the Y-distributions of the two children
def W2_split(A,Mtry,X,Y,p = 2):
    min_sample_each_node = 1
    theta_list = []
    for split_direction in Mtry:
        Xtry = np.sort(X[A,:][:,split_direction])
        # candidate split values: midpoints between consecutive order statistics
        candidate_current_direction = (Xtry[min_sample_each_node:] + Xtry[:-min_sample_each_node])*.5
        theta_list += [[split_direction,candidate_current_direction[i]] for i in range(len(A) - min_sample_each_node)]

    W2_dist = np.zeros(len(theta_list))
    for i in range(len(theta_list)):
        W2_dist[i] = calculate_W2_difference(A, Mtry, theta_list[i], X, Y, p = p)

    # optional: subsample the candidates to accelerate the splitting
    # mtry_bootstrap = 1000
    # W2_dist = np.zeros(mtry_bootstrap)
    # _i = 0
    # for i in np.random.choice(range(len(theta_list)), mtry_bootstrap):
    #     W2_dist[_i] = calculate_W2_difference(A, Mtry, theta_list[i], X, Y, p = p)
    #     _i += 1
    return theta_list[np.argmax(W2_dist)]
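
W2_split enumerates, for each direction in Mtry, the midpoints between consecutive order statistics of the points in A and keeps the candidate maximizing the criterion above. A minimal illustrative call on hypothetical toy data (only the first coordinate drives Y here, so the returned split should lie along direction 0, roughly at the midpoint 0.5):

In [ ]:
# illustrative call of W2_split on toy data (expected result close to [0, 0.5])
X_demo = np.array([[0.1, 0.7], [0.2, 0.1], [0.8, 0.6], [0.9, 0.2]])
Y_demo = np.array([0., 0.1, 1., 1.1])
print(W2_split([0, 1, 2, 3], np.array([0, 1]), X_demo, Y_demo, p = 2))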


    
In [4]:
class node:
    def __init__(self, left = None, right = None, parent = None, split = None, neighbours = None):
        self.left = left              # left child
        self.right = right            # right child
        self.parent = parent          # parent node
        self.split = split            # [split_direction, split_value], None for a leaf
        self.neighbours = neighbours  # indices of the training points falling in this node
        
class DecisionTree:
    def __init__(self,
                 mtry = 1,
                 nodesize = 5,
                 subsample = 0.8,
                 bootstrap = True,
                 p = 2,
                 nodes = None):
        self.nodes = nodes  # root node of the fitted tree
        # parameters
        self.mtry = mtry
        self.nodesize = nodesize
        self.subsample = subsample 
        self.bootstrap = bootstrap 
        self.p = p 
    def fit(self,X,Y):
        N,d = X.shape
        # indices of the (sub)sample used to grow this tree
        S_b = np.random.choice(range(N), int(N*self.subsample), replace = self.bootstrap)
        self.nodes = node(neighbours = S_b)
        P = [self.nodes]  # nodes still to be processed
        while P:
            # A is the current node
            A = P[0]
            if len(A.neighbours) < self.nodesize:
                del P[0]
            else:
                Mtry = np.random.choice(range(d),self.mtry,replace = False)
                theta_star = W2_split(A.neighbours,Mtry,X,Y,self.p)
                A.split = theta_star
                #theta_star_list += [theta_star] 
                AL,AR = getALAR(A.neighbours,theta_star,X)
                del P[0]
                A.left = node(neighbours = AL, parent  = A)
                A.right = node(neighbours = AR, parent  = A)
                P += [A.left,A.right]
        
    def _predict(self,x,Y):
        # route x down the tree, then estimate E[Y|X=x] by the leaf average
        current_node = self.nodes
        while current_node.split:
            direction,value = current_node.split
            if x[direction] < value:
                current_node = current_node.left
            else:
                current_node = current_node.right
        return np.mean(Y[current_node.neighbours])
    def predict(self,x,Y):
        return np.apply_along_axis(lambda xi: self._predict(xi,Y), 1, x)
        
class WassersteinRandomForest:
    def __init__(self,
                 mtry = 1,
                 nodesize = 5,
                 subsample = 0.8,
                 bootstrap = True,
                 n_estimators = 10,
                 p = 2):
        # parameters
        self.mtry = mtry
        self.nodesize = nodesize
        self.subsample = subsample 
        self.bootstrap = bootstrap 
        self.n_estimators = n_estimators
        self.p = p
        self.ListLearners = []
        self.Y = None
    def fit(self,X,Y):
        self.Y = Y
        for i in range(self.n_estimators):
            BaseLearner = DecisionTree(mtry = self.mtry,
                                       nodesize = self.nodesize,
                                       subsample = self.subsample,
                                       bootstrap = self.bootstrap,
                                       p = self.p)
            BaseLearner.fit(X,Y)
            self.ListLearners += [BaseLearner]
    def predict(self,x):
        prediction = np.zeros(x.shape[0])
        for i in range(self.n_estimators):
            prediction += self.ListLearners[i].predict(x,self.Y)
        prediction /= float(self.n_estimators)
        return prediction
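
A single tree can also be fitted and queried on its own; here is a minimal sketch on hypothetical toy data (the exact predictions are noise-dependent):

In [ ]:
# minimal single-tree sketch (illustrative, not part of the experiment below)
X_demo = np.random.uniform(0, 1, (50, 2))
Y_demo = X_demo[:, 0] + np.random.normal(0, 0.1, 50)
tree = DecisionTree(mtry = 1, nodesize = 5, subsample = 0.8, bootstrap = False, p = 2)
tree.fit(X_demo, Y_demo)
print(tree.predict(X_demo[:3], Y_demo))  # leaf-average estimates of E[Y|X]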
        

Synthetic data

In [5]:
N_total = 2000
N_train = 1000 
X = np.random.uniform(0,1,(N_total,3))
def obj_func(x):
    """
    conditional expectation
    """
    #return x[1]+x[2] +2. +np.sin(x[0])
    return x[1] + x[2]
def obj_func2(x):
    """
    conditional variance
    """
    return np.abs(x[0]+x[1])*.2 
#Y = np.random.normal(0,1.,N_total) + np.apply_along_axis(obj_func,1,X)
Y = np.zeros(N_total)
for i in range(N_total):
    # Y | X = x ~ N(obj_func(x), obj_func2(x))
    Y[i] = np.random.normal(obj_func(X[i]), np.sqrt(obj_func2(X[i])))
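
As a quick check on the simulated model (assuming the local neighbourhood below contains enough points), the empirical mean of Y over points close to a reference x should be near obj_func(x):

In [ ]:
# hedged sanity check of the simulation (tolerance depends on N_total)
x0 = np.array([0.5, 0.5, 0.5])
mask = np.all(np.abs(X - x0) < 0.2, axis = 1)
print("local empirical mean:", Y[mask].mean(), " target:", obj_func(x0))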
In [6]:
reg = WassersteinRandomForest(nodesize = 5,
                             bootstrap = False,
                             subsample = 0.05,
                             n_estimators = 100,
                             mtry = 2,
                             p = 1)
#reg = DecisionTree()
reg.fit(X[:N_train],Y[:N_train])
In [7]:
# test
#print(reg.predict(X[-10].reshape(1,3)))
# print(reg.predict(X[-10:]))
# print("Memory usage:", process.memory_info().rss/1024/1024,"MB") 
In [8]:
from sklearn.metrics import mean_squared_error, r2_score
print("Wasserstein RF")
print("R2:",r2_score(np.apply_along_axis(obj_func,1,X[N_train:]),reg.predict(X[N_train:])))
print("MSE:",
        np.sqrt(mean_squared_error(np.apply_along_axis(obj_func,1,X[N_train:]),reg.predict(X[N_train:])))
     )
Wasserstein RF
R2: 0.8783131104113783
RMSE: 0.14264850858126085

Comparison with classical RF

In [9]:
from sklearn.ensemble import RandomForestRegressor

reg2 = RandomForestRegressor(n_estimators = 500,
                            min_samples_split = 5,
                            bootstrap = True)
reg2.fit(X[:N_train],Y[:N_train])
Out[9]:
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
                      max_features='auto', max_leaf_nodes=None,
                      min_impurity_decrease=0.0, min_impurity_split=None,
                      min_samples_leaf=1, min_samples_split=5,
                      min_weight_fraction_leaf=0.0, n_estimators=500,
                      n_jobs=None, oob_score=False, random_state=None,
                      verbose=0, warm_start=False)
In [10]:
print("Classical RF")
print("R2:",r2_score(np.apply_along_axis(obj_func,1,X[N_train:]),reg2.predict(X[N_train:])))
print("MSE:",
        np.sqrt(mean_squared_error(np.apply_along_axis(obj_func,1,X[N_train:]),reg2.predict(X[N_train:])))
     )
Classical RF
R2: 0.8032508432460227
RMSE: 0.18138517153516898

Visualization

In [11]:
plt.figure(figsize=(15,10))
plt.subplot(211)
plt.plot(np.apply_along_axis(obj_func,1,X[N_train:N_train+100]), color = "darkred", label="ref")
plt.plot(reg.predict(X[N_train:N_train+100]), color = "grey", label="pred")
plt.title("Estimation by Wasserstein RF")
plt.grid()
plt.legend()

plt.subplot(212)
plt.plot(np.apply_along_axis(obj_func,1,X[N_train:N_train+100]), color = "darkred", label="ref")
plt.plot(reg2.predict(X[N_train:N_train+100]), color = "grey", label="pred")
plt.title("Estimation by Classical RF")
plt.grid()
plt.legend()
Out[11]:
<matplotlib.legend.Legend at 0x7fda9f1086a0>