@thomvolker
Last active February 10, 2025 13:10

TabPFN in R with reticulate

Generate some example data.

X1 <- runif(200, 0, 10)
X2 <- sin(X1) + rnorm(200, 0, 0.5)
Y <- 3 + 0.5 * X1 + X2 + rnorm(200, 0, 1)

Import Python libraries.

import numpy
import pandas as pd
import tabpfn

Transfer R objects to Python.

Xpy <- reticulate::r_to_py(cbind(X1, X2))
Ypy <- reticulate::r_to_py(Y)
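
On the Python side, reticulate's documented type conversion turns an R matrix into a NumPy array and a multi-element numeric vector into a Python list. A minimal sketch of a sanity check that could run before fitting, assuming NumPy is available; `check_xy` is a hypothetical helper introduced here, not part of reticulate or TabPFN:

```python
import numpy as np

def check_xy(X, y):
    """Coerce inputs to NumPy arrays and verify they line up row by row.

    Hypothetical helper for illustration only.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).ravel()  # a Python list becomes a 1-D array
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (rows x features), got {X.ndim}-D")
    if X.shape[0] != y.shape[0]:
        raise ValueError(f"X has {X.shape[0]} rows but y has {y.shape[0]} values")
    return X, y
```

In a Python chunk this could be called as `X, y = check_xy(r["Xpy"], r["Ypy"])` before passing `X` and `y` to `fit`.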

Fit TabPFN model on synthetic data.

reg = tabpfn.TabPFNRegressor()
reg.fit(r["Xpy"], r["Ypy"])
C:\Users\5868777\ONEDRI~1\DOCUME~1\VIRTUA~1\r-keras\lib\site-packages\sklearn\base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7. Use `sklearn.utils.validation.validate_data` instead. This function becomes public and is part of the scikit-learn developer API.
  warnings.warn(
C:\Users\5868777\ONEDRI~1\DOCUME~1\VIRTUA~1\r-keras\lib\site-packages\sklearn\utils\deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
C:\Users\5868777\ONEDRI~1\DOCUME~1\VIRTUA~1\r-keras\lib\site-packages\sklearn\utils\deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
TabPFNRegressor()

Obtain predicted values in Python.

pred = reg.predict(r["Xpy"])
C:\Users\5868777\ONEDRI~1\DOCUME~1\VIRTUA~1\r-keras\lib\site-packages\sklearn\base.py:474: FutureWarning: `BaseEstimator._validate_data` is deprecated in 1.6 and will be removed in 1.7. Use `sklearn.utils.validation.validate_data` instead. This function becomes public and is part of the scikit-learn developer API.
  warnings.warn(
C:\Users\5868777\ONEDRI~1\DOCUME~1\VIRTUA~1\r-keras\lib\site-packages\sklearn\utils\deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.
  warnings.warn(
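
The predictions above are made on the same rows the model was fitted on, so they describe in-sample fit. A minimal sketch of a held-out check, assuming only that the model exposes scikit-learn style `fit` and `predict` methods (as `TabPFNRegressor` does in the calls above) and that NumPy is available; `holdout_r2` and its parameters are hypothetical names introduced for illustration:

```python
import numpy as np

def holdout_r2(model, X, y, test_fraction=0.25, seed=0):
    """Fit on a random subset of rows and return R^2 on the remaining rows.

    Hypothetical helper; works with any object that has fit() and predict().
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    n = X.shape[0]
    n_test = max(1, int(round(test_fraction * n)))

    # Shuffle row indices reproducibly, then split into test and train parts.
    idx = np.random.default_rng(seed).permutation(n)
    test, train = idx[:n_test], idx[n_test:]

    model.fit(X[train], y[train])
    pred = np.asarray(model.predict(X[test]), dtype=float)

    # R^2 = 1 - residual sum of squares / total sum of squares on held-out rows.
    ss_res = np.sum((y[test] - pred) ** 2)
    ss_tot = np.sum((y[test] - y[test].mean()) ** 2)
    return 1.0 - ss_res / ss_tot
```

With the objects from this gist, a call might look like `holdout_r2(tabpfn.TabPFNRegressor(), r["Xpy"], r["Ypy"])`.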

Transfer the predicted values back to R.

r["pred"] = pred

Use predictions in R.

pred <- as.data.frame(pred)
plot(pred$pred, Y)

[Figure: scatter plot of the TabPFN predictions (pred$pred) against Y]

Compare the TabPFN predictions with those of lm(), which fits the "true" linear model used to generate Y.

var(pred$pred) / var(Y)
[1] 0.6748082
fit <- lm(Y ~ X1 + X2)
summary(fit)$r.squared
[1] 0.6865773
cor(fit$fitted.values, Y)
[1] 0.8285996
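
For a least-squares fit with an intercept, the ratio var(fitted) / var(Y), the R-squared, and the squared correlation between fitted values and Y all coincide, which is why 0.8285996 squared is approximately 0.6865773 above. For the TabPFN predictions the variance ratio is only a rough analogue of R-squared, since that identity is specific to least squares. A small numerical check of the identity in Python, assuming NumPy; the data mimic the simulation at the top of this gist but use a different random generator, so the numbers will not match the R output:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200
x1 = rng.uniform(0, 10, n)
x2 = np.sin(x1) + rng.normal(0, 0.5, n)
y = 3 + 0.5 * x1 + x2 + rng.normal(0, 1, n)

# Ordinary least squares with an intercept, the analogue of lm(Y ~ X1 + X2).
A = np.column_stack([np.ones(n), x1, x2])
coef, *_ = np.linalg.lstsq(A, y, rcond=None)
fitted = A @ coef

var_ratio = np.var(fitted, ddof=1) / np.var(y, ddof=1)
r_squared = 1 - np.sum((y - fitted) ** 2) / np.sum((y - y.mean()) ** 2)
corr_squared = np.corrcoef(fitted, y)[0, 1] ** 2

# The three quantities agree up to floating-point rounding.
print(var_ratio, r_squared, corr_squared)
```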