This example trains models on the sklearn digits dataset, which is a copy of the test set of the UCI ML hand-written digits dataset. It demonstrates classification using a table with a single array feature column; you could do something similar with a vector column.
-- Example: classify handwritten digits
-- Load the sklearn digits dataset; the queries below read it from pgml.digits.
SELECT pgml.load_dataset('digits');

-- View the dataset. The image array is cast to text and cut to its first
-- 40 characters so each row stays readable; the alias keeps the column from
-- being reported as ?column?.
SELECT
    left(image::text, 40) || ',...}' AS image,
    target
FROM pgml.digits
LIMIT 10;

-- Train a simple model to classify the data.
-- Arguments: project name, task, training relation, label column.
SELECT * FROM pgml.train('Handwritten Digits', 'classification', 'pgml.digits', 'target');

-- Check out the predictions next to the true labels.
SELECT
    target,
    pgml.predict('Handwritten Digits', image) AS prediction
FROM pgml.digits
LIMIT 10;

-- View raw class probabilities. These are aliased separately from the
-- predicted class above so the two outputs are not confused.
SELECT
    target,
    pgml.predict_proba('Handwritten Digits', image) AS probabilities
FROM pgml.digits
LIMIT 10;
We currently support classification algorithms from scikit-learn, XGBoost, LightGBM and CatBoost.
-- XGBoost, LightGBM and CatBoost
-- Each call trains another model under the 'Handwritten Digits' project,
-- passing only the project name plus an algorithm and its hyperparameters.
-- NOTE(review): this assumes the project was already created by a pgml.train
-- call that named the task, relation and label column. Confirm that it ran first.
-- The small estimator counts are presumably chosen to keep the demo fast.
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'xgboost',
    hyperparams => '{"n_estimators": 10}'
);
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'xgboost_random_forest',
    hyperparams => '{"n_estimators": 10}'
);
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'lightgbm',
    hyperparams => '{"n_estimators": 1}'
);
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'catboost',
    hyperparams => '{"n_estimators": 1}'
);
-- scikit-learn ensemble methods
-- Train one model per scikit-learn ensemble algorithm in the existing
-- 'Handwritten Digits' project.
-- NOTE(review): assumes the project's task and training data were set up by
-- an earlier pgml.train call.
-- ada_boost and bagging pass no hyperparams, so whatever defaults pgml
-- applies are used for them.
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'ada_boost');
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'bagging');
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'extra_trees',
    hyperparams => '{"n_estimators": 10}'
);
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'gradient_boosting_trees',
    hyperparams => '{"n_estimators": 10}'
);
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'random_forest',
    hyperparams => '{"n_estimators": 10}'
);
-- hist_gradient_boosting is capped by max_iter rather than n_estimators.
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'hist_gradient_boosting',
    hyperparams => '{"max_iter": 2}'
);
-- Support vector machines
-- Support vector machine variants, each trained with no explicit hyperparams.
-- NOTE(review): assumes the 'Handwritten Digits' project already exists.
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'svm');
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'nu_svm');
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'linear_svm');
-- Linear models
-- Linear classifiers, each trained with no explicit hyperparams.
-- NOTE(review): assumes the 'Handwritten Digits' project already exists.
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'ridge');
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'stochastic_gradient_descent');
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'perceptron');
SELECT *
FROM pgml.train('Handwritten Digits', algorithm => 'passive_aggressive');
-- Gaussian process
-- Gaussian process classifier.
-- max_iter_predict and warm_start are presumably forwarded to the underlying
-- scikit-learn estimator. Verify against the pgml hyperparameter documentation.
-- NOTE(review): assumes the 'Handwritten Digits' project already exists.
SELECT *
FROM pgml.train(
    'Handwritten Digits',
    algorithm   => 'gaussian_process',
    hyperparams => '{"max_iter_predict": 100, "warm_start": true}'
);