Classification

Example

This example trains models on the scikit-learn digits dataset, which is a copy of the test set of the UCI ML hand-written digits dataset. It demonstrates classification on a table with a single array feature column. You could do something similar with a vector column.

-- load the sklearn digits dataset
SELECT pgml.load_dataset('digits');
-- view the dataset
SELECT left(image::text, 40) || ',...}', target FROM pgml.digits LIMIT 10;
-- train a simple model to classify the data
SELECT * FROM pgml.train('Handwritten Digits', 'classification', 'pgml.digits', 'target');
-- check out the predictions
SELECT target, pgml.predict('Handwritten Digits', image) AS prediction
FROM pgml.digits
LIMIT 10;
-- view raw class probabilities
SELECT target, pgml.predict_proba('Handwritten Digits', image) AS probabilities
FROM pgml.digits
LIMIT 10;
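
The queries above only show a few rows. As a rough sketch, the same pgml.predict call can be aggregated over the whole table to get a quick sanity check. This assumes the value returned by pgml.predict is a numeric class label that compares directly with target. Because pgml.digits also contains the rows the model was trained on, the resulting figure is optimistic and not a held-out score.

```sql
-- Fraction of rows where the predicted class equals the label.
-- Assumption: pgml.predict returns a numeric label comparable to target.
SELECT avg((target = pgml.predict('Handwritten Digits', image))::int) AS accuracy
FROM pgml.digits;

-- Count of each (actual, predicted) pair, i.e. a simple confusion table.
SELECT target AS actual,
       pgml.predict('Handwritten Digits', image) AS predicted,
       count(*) AS n
FROM pgml.digits
GROUP BY 1, 2
ORDER BY 1, 2;
```

Rows in the second result where actual and predicted differ show which digits the model confuses with each other.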

Algorithms

We currently support classification algorithms from scikit-learn, XGBoost, LightGBM and CatBoost.

Gradient Boosting

Algorithm               Reference
xgboost                 XGBClassifier
xgboost_random_forest   XGBRFClassifier
lightgbm                LGBMClassifier
catboost                CatBoostClassifier

Examples

SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost_random_forest', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 1}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'catboost', hyperparams => '{"n_estimators": 1}');
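
Each hyperparams value above carries a single key. As a minimal sketch, and assuming that additional keys are forwarded to the referenced estimator in the same way as n_estimators, several settings could be combined in one call. The max_depth and learning_rate keys are XGBClassifier parameter names and do not appear elsewhere on this page, and the values are illustrative only.

```sql
-- Hypothetical settings: max_depth and learning_rate are assumed to be
-- accepted alongside n_estimators; the values are not tuned.
SELECT * FROM pgml.train(
    'Handwritten Digits',
    algorithm => 'xgboost',
    hyperparams => '{"n_estimators": 25, "max_depth": 4, "learning_rate": 0.1}'
);
```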

Scikit Ensembles

Algorithm                 Reference
ada_boost                 AdaBoostClassifier
bagging                   BaggingClassifier
extra_trees               ExtraTreesClassifier
gradient_boosting_trees   GradientBoostingClassifier
random_forest             RandomForestClassifier
hist_gradient_boosting    HistGradientBoostingClassifier

Examples

SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'ada_boost');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'bagging');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'extra_trees', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'gradient_boosting_trees', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'random_forest', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'hist_gradient_boosting', hyperparams => '{"max_iter": 2}');

Support Vector Machines

Algorithm    Reference
svm          SVC
nu_svm       NuSVC
linear_svm   LinearSVC

Examples

SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'svm');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'nu_svm');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear_svm');

Linear Models

Algorithm                     Reference
linear                        LogisticRegression
ridge                         RidgeClassifier
stochastic_gradient_descent   SGDClassifier
perceptron                    Perceptron
passive_aggressive            PassiveAggressiveClassifier

Examples

SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'ridge');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'stochastic_gradient_descent');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'perceptron');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'passive_aggressive');
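
The table for this group also lists linear, which has no example in the list above. A minimal sketch, assuming it is selected by name in the same way as the other algorithms:

```sql
-- 'linear' corresponds to LogisticRegression in the table above.
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear');
```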

Other

Algorithm          Reference
gaussian_process   GaussianProcessClassifier

Examples

SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'gaussian_process', hyperparams => '{"max_iter_predict": 100, "warm_start": true}');