The pgml.predict_batch() function is a performance optimization that returns predictions for multiple rows in a single function call. It works the same way as pgml.predict() in all other respects.
Many machine learning algorithms benefit from calculating predictions in one operation instead of many. For large datasets, batch predictions can be 3-6 times faster than pgml.predict().
The API for batch predictions is very similar to individual predictions, and only requires two arguments: the project name and the aggregated features used for predictions.
pgml.predict_batch(
    project_name TEXT,
    features REAL[]
)
| Parameter | Example | Description |
|---|---|---|
| project_name | My first PostgresML project | The project name used to train models in pgml.train(). |
| features | array_agg(image) | An aggregate of feature vectors used to predict novel data points. |
SELECT pgml.predict_batch(
    'My First PostgresML Project',
    array_agg(ARRAY[0.1, 2.0, 5.0])
) AS prediction
FROM pgml.digits;
Note that we pass the result of array_agg() to the function because we want Postgres to accumulate all the features first and hand them to PostgresML in a single function call.
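For comparison, here is a rough sketch of the per-row form that pgml.predict_batch() is meant to replace. It assumes the 'My Classification Project' project and the pgml.digits table used elsewhere on this page, and that the image column can be passed directly as the features argument.

```sql
-- Sketch only: the project name and table are taken from the examples on this page.
-- Per-row form: pgml.predict() is called once for every row returned by the scan.
SELECT
    pgml.predict(
        'My Classification Project',
        image
    ) AS prediction
FROM pgml.digits
LIMIT 10;
```

The batch form differs only in wrapping the feature column in array_agg(), so that all rows reach the model in one call.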
Batch predictions have to be fetched in a subquery or a CTE because they use the array_agg() aggregate. To return the results in an easily usable form, pgml.predict_batch() returns a setof result instead of a normal array, which can then be built into a table:
WITH predictions AS (
    SELECT
        pgml.predict_batch(
            'My Classification Project',
            array_agg(image)
        ) AS prediction,
        unnest(
            array_agg(target)
        ) AS target
    FROM pgml.digits
    WHERE target = 0
)
SELECT prediction, target FROM predictions
LIMIT 10;
prediction | target
------------+--------
0 | 0
0 | 0
0 | 0
0 | 0
0 | 0
0 | 0
0 | 0
0 | 0
0 | 0
0 | 0
(10 rows)
Since we're using aggregates, limiting predicates must be placed in the WHERE clause of the CTE. For example, we used WHERE target = 0 to batch predict only the images labeled as class 0.
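To see why the placement matters, here is a sketch of the same query with the predicate moved to the outer query. It reuses the project and table names from the example above and is meant as an illustration of what to avoid, not as a recommended pattern.

```sql
-- Sketch of the pattern to avoid: the CTE has no WHERE clause, so every
-- row in pgml.digits is aggregated and passed to pgml.predict_batch().
WITH predictions AS (
    SELECT
        pgml.predict_batch(
            'My Classification Project',
            array_agg(image)
        ) AS prediction,
        unnest(
            array_agg(target)
        ) AS target
    FROM pgml.digits
)
-- The filter only discards rows after all predictions have been computed.
SELECT prediction, target
FROM predictions
WHERE target = 0
LIMIT 10;
```

The rows that come back still satisfy the filter, but the model does the work for the whole table first.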
To perform a join on batch predictions, each row needs a uniquely identifiable join column. As the example above shows, any column can be passed through the aggregation by combining unnest() and array_agg().
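The pass-through pattern can be tried on plain PostgreSQL without a trained model. The sketch below uses a hypothetical inline table named sample with made-up values; only unnest() and array_agg() are involved, so it shows how aggregated columns are expanded back into one row per input row.

```sql
-- Hypothetical inline data: "sample", "user_id" and "score" are made up for illustration.
WITH sample(user_id, score) AS (
    VALUES (1, 0.1), (2, 0.7), (3, 0.4)
)
SELECT
    -- Each column is collapsed into a single array, then expanded again.
    -- Both aggregates read the same input rows, so the expanded values
    -- stay paired row by row.
    unnest(array_agg(user_id)) AS user_id,
    unnest(array_agg(score))   AS score
FROM sample;
```

In the batch prediction queries on this page, pgml.predict_batch() takes the place of one of the unnest(array_agg(...)) expressions, and the other carries the identifier along.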
Example
WITH predictions AS (
    SELECT
        -- Prediction
        pgml.predict_batch(
            'My Bot Detector',
            array_agg(ARRAY[account_age, city, last_login])
        ) AS prediction,
        -- The pass-through unique identifier for each row,
        -- aliased as user_id so the join below can reference it
        unnest(
            array_agg(user_id)
        ) AS user_id
    FROM users
    -- Filter which rows to pass to pgml.predict_batch()
    WHERE last_login > NOW() - INTERVAL '1 minute'
)
SELECT prediction, email, ip_address
FROM users
INNER JOIN predictions
    ON users.user_id = predictions.user_id;