GPU Plugin: Add bosch demo, update build instructions (dmlc#1872)
RAMitchell authored and tqchen committed Dec 15, 2016
1 parent edc356f commit d943720
Showing 5 changed files with 88 additions and 4 deletions.
42 changes: 42 additions & 0 deletions demo/gpu_acceleration/README.md
@@ -0,0 +1,42 @@
# GPU Acceleration Demo

This demo shows how to perform cross-validation on the Kaggle Bosch dataset with GPU acceleration. The Bosch numerical dataset has over 1 million rows and 968 features, making it time-consuming to process.

This demo requires the [GPU plug-in](https://github.com/dmlc/xgboost/tree/master/plugin/updater_gpu) to be built and installed.

The dataset is available from:
https://www.kaggle.com/c/bosch-production-line-performance/data

Copy train_numeric.csv into xgboost/demo/data.

The subsample parameter can be changed so that you can first run the script on a small portion of the data. Processing the entire dataset can take a long time and requires about 8GB of device memory. The parameter is initially set to 0.4, which uses about 2650MB of 3380MB on a GTX 970.

```python
subsample = 0.4
```
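
As a rough illustration of how a subsample fraction can be turned into the list of rows to skip when reading the CSV, here is a minimal Python 3 sketch. The helper name `rows_to_skip` and the small row count in the usage line are assumptions made for illustration; the demo script does the equivalent inline with Python 2's `xrange`.

```python
import random


def rows_to_skip(n_rows, subsample, seed):
    """Return the sorted 1-based row numbers to skip so that about
    `subsample` of the data rows are kept.

    Row 0 is the CSV header, so it is never a candidate for skipping.
    `rows_to_skip` is a hypothetical helper, not part of the demo script.
    """
    keep = int(n_rows * subsample)
    rng = random.Random(seed)  # private generator, so the result is reproducible
    return sorted(rng.sample(range(1, n_rows + 1), n_rows - keep))


# Hypothetical small example: keep roughly 40% of 1000 data rows.
skip = rows_to_skip(n_rows=1000, subsample=0.4, seed=9)
print(len(skip), "rows would be skipped")
```

The returned list is in the form that the `skiprows` argument of `pandas.read_csv` accepts, which is how the demo script uses it.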

Parameters are set as usual, except that we set silent to 0 to see how much memory is allocated on the GPU and change 'updater' to 'grow_gpu' to activate the GPU plugin.

```python
param['silent'] = 0
param['updater'] = 'grow_gpu'
```

We use the scikit-learn cross-validation splitter instead of the xgboost cv function because the xgboost cv function tries to fit all folds in GPU memory at the same time.

With the scikit-learn splitter, each fold runs separately, which lets a very large dataset fit on the GPU.

Also note the line:
```python
del bst
```

This drops the reference to the booster for the current fold so that Python can free it before the next fold begins. Without this line, Python may keep 'bst' from the previous fold alive, using up precious GPU memory.
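
The effect of `del bst` can be sketched without a GPU. The example below is an illustration only: `FakeBooster` is a hypothetical stand-in for an object that holds device memory, and the result relies on CPython's reference counting, which frees an object as soon as its last reference goes away.

```python
class FakeBooster:
    """Hypothetical stand-in for a booster that holds GPU memory."""

    live = 0  # instances currently alive
    peak = 0  # highest number of instances alive at the same time

    def __init__(self):
        FakeBooster.live += 1
        FakeBooster.peak = max(FakeBooster.peak, FakeBooster.live)

    def __del__(self):
        FakeBooster.live -= 1


def peak_live_boosters(n_folds, delete_between_folds):
    """Run a fold loop and report how many boosters were alive at once."""
    FakeBooster.live = 0
    FakeBooster.peak = 0
    for _ in range(n_folds):
        # Without the del below, the previous fold's object is still bound
        # to `bst` while this new one is being constructed.
        bst = FakeBooster()
        if delete_between_folds:
            del bst
    return FakeBooster.peak


print(peak_live_boosters(5, delete_between_folds=True))
print(peak_live_boosters(5, delete_between_folds=False))
```

Under these assumptions, the loop with `del` never has more than one booster alive, while the loop without it briefly holds two: the old one and the one being built.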

You can change the updater parameter to run the equivalent algorithm for the CPU:
```python
param['updater'] = 'grow_colmaker'
```

Expect some minor variations in accuracy between the two versions.

42 changes: 42 additions & 0 deletions demo/gpu_acceleration/bosch.py
@@ -0,0 +1,42 @@
import numpy as np
import pandas as pd
import xgboost as xgb
import time
import random
from sklearn.cross_validation import StratifiedKFold

#For sub sampling rows from input file
random_seed = 9
subsample = 0.4

n_rows = 1183747;
train_rows = int(n_rows * subsample)
random.seed(random_seed)
skip = sorted(random.sample(xrange(1,n_rows + 1),n_rows-train_rows))
data = pd.read_csv("../data/train_numeric.csv", index_col=0, dtype=np.float32, skiprows=skip)
y = data['Response'].values
del data['Response']
X = data.values

param = {}
param['objective'] = 'binary:logistic'
param['eval_metric'] = 'auc'
param['max_depth'] = 5
param['eta'] = 0.3
param['silent'] = 0
param['updater'] = 'grow_gpu'
#param['updater'] = 'grow_colmaker'

num_round = 20

cv = StratifiedKFold(y, n_folds=5)

for i, (train, test) in enumerate(cv):
    dtrain = xgb.DMatrix(X[train], label=y[train])
    tmp = time.time()
    bst = xgb.train(param, dtrain, num_round)
    boost_time = time.time() - tmp
    res = bst.eval(xgb.DMatrix(X[test], label=y[test]))
    print("Fold {}: {}, Boost Time {}".format(i, res, str(boost_time)))
    del bst

6 changes: 4 additions & 2 deletions plugin/updater_gpu/README.md
@@ -39,7 +39,7 @@ A CUDA capable GPU with at least compute capability >= 3.5 (the algorithm depend

 Building the plug-in requires CUDA Toolkit 7.5 or later.

-The plugin also depends on CUB 1.5.2 - https://github.com/NVlabs/cub/tree/1.5.2
+The plugin also depends on CUB 1.6.4 - https://nvlabs.github.io/cub/

 CUB is a header only cuda library which provides sort/reduce/scan primitives.

@@ -70,6 +70,8 @@ The build process generates an xgboost library and executable as normal but cont
 ## Author
 Rory Mitchell

-Report any bugs to r.a.mitchell.nz at google mail.
+Please report bugs to the xgboost/issues page. You can tag me with @RAMitchell.
+
+Otherwise I can be contacted at r.a.mitchell.nz at gmail.


1 change: 0 additions & 1 deletion plugin/updater_gpu/src/device_helpers.cuh
@@ -414,7 +414,6 @@ class bulk_allocator {
 }

 _size = get_size_bytes(args...);
-std::cout << "trying to allocate: " << _size << "\n";

 safe_cuda(cudaMalloc(&d_ptr, _size));

1 change: 0 additions & 1 deletion plugin/updater_gpu/src/gpu_builder.cu
@@ -311,7 +311,6 @@ float GPUBuilder::GetSubsamplingRate(MetaInfo info) {
 uint32_t max_nodes_level = 1 << param.max_depth;
 size_t required = 10 * info.num_row + 40 * info.num_nonzero
 + 64 * max_nodes + 76 * max_nodes_level * info.num_col;
-std::cout << "required: " << required << "\n";
 size_t available = dh::available_memory();
 while (available < required) {
 subsample -= 0.05;
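
To make the memory estimate in the hunk above concrete, here is a small Python sketch of the same expression. It is an illustration, not the plugin's code: `required_bytes` is a hypothetical name, and `max_nodes` is taken as an input because its definition is not visible in this hunk. The input values in the usage line are made up.

```python
def required_bytes(num_row, num_nonzero, num_col, max_nodes, max_depth):
    """Estimate device memory in bytes, mirroring the `required`
    expression in GetSubsamplingRate.

    `max_nodes` is passed in by the caller because the hunk does not
    show how the plugin computes it.
    """
    max_nodes_level = 1 << max_depth  # same as the C++ line: 1 << param.max_depth
    return (10 * num_row
            + 40 * num_nonzero
            + 64 * max_nodes
            + 76 * max_nodes_level * num_col)


# Hypothetical inputs, chosen only to exercise the arithmetic.
print(required_bytes(num_row=100, num_nonzero=1000, num_col=10,
                     max_nodes=63, max_depth=5))
```

The plugin compares this figure against the available device memory and lowers the subsampling rate in steps of 0.05 while the estimate is too large; the rest of that loop is outside the hunk shown here.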
