# Example: Regression with XGBoost

Superconductivity data set: predict the critical temperature of a superconductor from 81 material features.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

In [None]:
# Local copy of the training data. To read it straight from the course web
# server instead, set `filename` to:
#   https://www.physi.uni-heidelberg.de/~marks/ml_einfuehrung/Beispiele/train_critical_temp.csv
filename = "train_critical_temp.csv"
# The file is read with default options (comma separator, no regex or
# skipfooter), which the default C parser supports and parses much faster
# than engine='python'.
df = pd.read_csv(filename)

In [None]:
# Show the first five rows to check the column names and typical values.
df.head(n=5)

In [None]:
# Target: the critical-temperature column as a NumPy array.
y = df["critical_temp"].to_numpy()
# Features: every other column, in the original column order.
X = df.drop(columns="critical_temp")

In [None]:
# Hold out 30 % of the samples for testing. A fixed random_state makes the
# split, and therefore every RMSE reported below, reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, shuffle=True, random_state=1
)

In [None]:
import xgboost as xgb
import time

# Gradient-boosted trees with XGBoost's default hyperparameters. For a larger
# ensemble try e.g. n_estimators=1000, n_jobs=-1, random_state=1
# (`n_jobs` / `random_state` are the scikit-learn-style names for the older
# `nthread` / `seed` arguments).
XGBreg = xgb.XGBRegressor()

# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() follows the wall clock and can jump if it is adjusted.
start_time = time.perf_counter()
XGBreg.fit(X_train, y_train)
run_time = time.perf_counter() - start_time

print(f"XGBRegressor: fit time {run_time:.2f} s")

In [None]:
# Predict critical temperatures for the held-out test samples with the fitted XGBoost model.
y_pred = XGBreg.predict(X_test)

In [None]:
# Predicted vs. true critical temperature on the test set. Points on the
# dashed diagonal are perfect predictions.
# An explicit figure keeps this plot from being drawn onto a figure left over
# from an earlier cell.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(y_test, y_pred, s=2)

# y = x reference line spanning the full range of true and predicted values.
t_min = min(y_test.min(), y_pred.min())
t_max = max(y_test.max(), y_pred.max())
ax.plot([t_min, t_max], [t_min, t_max], "r--", linewidth=1, label="perfect prediction")

ax.set_xlabel("true critical temperature (K)", fontsize=14)
ax.set_ylabel("predicted critical temperature (K)", fontsize=14)
ax.legend()

# bbox_inches="tight" stops the large axis labels from being clipped in the PDF.
fig.savefig("critical_temperature.pdf", bbox_inches="tight")

In [None]:
# Root-mean-square error of the XGBoost predictions on the test set.
mse = mean_squared_error(y_test, y_pred)
rms = np.sqrt(mse)
print(f"root mean square error {rms:.2f}")

In [None]:
# compare with other regressors

from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Fixed random_state (the MLP already used random_state=1) so that repeated
# runs give the same models and comparable errors.
rfr = RandomForestRegressor(random_state=1)
gbr = GradientBoostingRegressor(random_state=1)

# Multilayer perceptrons are sensitive to feature scaling, and the raw
# material features are not rescaled anywhere else in this notebook, so
# standardize them first. Inside a pipeline the scaler is fitted on the
# training data only, so the test set does not influence the scaling.
mlpr = make_pipeline(
    StandardScaler(),
    MLPRegressor(hidden_layer_sizes=(50, 50), activation='relu', random_state=1, max_iter=5000),
)

In [None]:
import time

regressors = [rfr, gbr, mlpr]


def _regressor_name(model):
    """Return a short label for *model*, listing the steps if it is a Pipeline."""
    steps = getattr(model, "steps", None)
    if steps is None:
        return type(model).__name__
    return " -> ".join(type(step).__name__ for _, step in steps)


# Fit each regressor on the same training split, time the fit, and report the
# test-set RMSE so the results can be compared with the XGBoost model.
# NOTE: this rebinds y_pred / rms, so re-running the plot cell afterwards
# shows the last regressor's predictions, not XGBoost's.
for reg in regressors:
    name = _regressor_name(reg)

    start_time = time.perf_counter()
    reg.fit(X_train, y_train)
    run_time = time.perf_counter() - start_time

    y_pred = reg.predict(X_test)
    rms = np.sqrt(mean_squared_error(y_test, y_pred))

    print(f"{name}: fit time {run_time:.2f} s")
    print(f"root mean square error {rms:.2f}\n")
