PyTorch Frame 0.1.0
We are excited to announce the initial release of PyTorch Frame 🎉🎉🎉
PyTorch Frame is a deep learning extension for PyTorch, designed for heterogeneous tabular data with different column types, including numerical, categorical, time, text, and images.
To get started, please refer to:
- our `README.md` for an overview of PyTorch Frame,
- the "Introduction by Example" tutorial and its code at `examples/tutorial.py` to get started with using PyTorch Frame, and
- the "Modular Design of Deep Tabular Models" tutorial in our documentation and the existing implementations in the `torch_frame/nn/models/` directory to create your own PyTorch Frame model for tabular data.
Highlights
Models, datasets and examples
In our initial release, we introduce 6 models, 9 feature encoders, 5 table convolution layers, 3 decoders, and 14 datasets.
- Models
  - `Trompt`: "Trompt: Towards a Better Deep Neural Network for Tabular Data" (`examples/trompt.py`)
  - `FTTransformer`: "Revisiting Deep Learning Models for Tabular Data" (`examples/ft_transformer_text.py`)
  - `ExcelFormer`: "ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data" (`examples/excelformer.py`)
  - `TabNet`: "TabNet: Attentive Interpretable Tabular Learning" (`examples/tabnet.py`)
  - `ResNet`: "Revisiting Deep Learning Models for Tabular Data" (`examples/revisiting.py`)
  - `TabTransformer`: "TabTransformer: Tabular Data Modeling Using Contextual Embeddings" (`examples/tab_transformer.py`)
- Encoders
  - `FeatureEncoder`
  - `StypeWiseFeatureEncoder`
  - `StypeEncoder`
  - `EmbeddingEncoder`
  - `LinearEncoder`
  - `LinearBucketEncoder`: "On Embeddings for Numerical Features in Tabular Deep Learning"
  - `LinearPeriodicEncoder`: "On Embeddings for Numerical Features in Tabular Deep Learning"
  - `ExcelFormerEncoder`: "ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data"
  - `StackEncoder`
- Table Convolution Layers
  - `TableConv`
  - `FTTransformerConvs`: "Revisiting Deep Learning Models for Tabular Data"
  - `TromptConv`: "Trompt: Towards a Better Deep Neural Network for Tabular Data"
  - `ExcelFormerConv`: "ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data"
  - `TabTransformerConv`: "TabTransformer: Tabular Data Modeling Using Contextual Embeddings"
- Decoders
- Datasets: `AdultCensusIncome`, `BankMarketing`, `DataFrameBenchmark`, `Dota2`, `FakeDataset`, `ForestCoverType`, `KDDCensusIncome`, `Mercari`, `MultimodalTextBenchmark`, `Mushroom`, `PokerHand`, `TabularBenchmark`, `Titanic`, `Yandex`
Benchmarks
With our initial set of models and datasets under `torch_frame.nn` and `torch_frame.datasets`, we benchmarked their performance on binary classification and regression tasks. Each row is a model and each column is a dataset `idx`. Each cell reports the mean and standard deviation of the model's performance, followed in parentheses by the total time spent in seconds, including Optuna-based hyper-parameter search and final model training.
Note
- For the latest benchmark scripts and results, see the `benchmark/` directory.
- To find which dataset each column `idx` corresponds to, see the `torch_frame.datasets.DataFrameBenchmark` dataset docs.
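The cell format used in the tables (mean ± standard deviation, then total time in seconds) can be parsed with a few lines of Python if you want to post-process the results. The sketch below is not part of PyTorch Frame; `Cell` and `parse_cell` are hypothetical helper names introduced here for illustration, and the pattern assumes cells look exactly like the ones in these tables.

```python
import re
from typing import NamedTuple, Optional


class Cell(NamedTuple):
    """One benchmark cell (hypothetical helper type, not a PyTorch Frame class)."""
    mean: float
    std: float
    seconds: int


# Matches cells such as "0.931±0.000 (41s)".
_CELL_RE = re.compile(r"^\s*([0-9.]+)±([0-9.]+)\s*\((\d+)s\)\s*$")


def parse_cell(text: str) -> Optional[Cell]:
    """Return the parsed cell, or None for cells without a score (e.g. "OOM")."""
    match = _CELL_RE.match(text)
    if match is None:
        return None
    mean, std, seconds = match.groups()
    return Cell(float(mean), float(std), int(seconds))


if __name__ == "__main__":
    print(parse_cell("0.931±0.000 (41s)"))
    print(parse_cell("OOM"))
```

Returning `None` for unparseable cells keeps the caller in charge of how to treat missing results, rather than silently substituting a number.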
Benchmark on small-scale binary classification tasks
Metric: ROC-AUC, higher is better.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 | dataset_13 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
XGBoost | 0.931±0.000 (41s) | 1.000±0.000 (4s) | 0.935±0.000 (16s) | 0.946±0.000 (26s) | 0.881±0.000 (10s) | 0.951±0.000 (16s) | 0.862±0.000 (26s) | 0.780±0.000 (11s) | 0.983±0.000 (584s) | 0.763±0.000 (240s) | 0.795±0.000 (11s) | 0.950±0.000 (479s) | 0.999±0.000 (148s) | 0.926±0.000 (3042s) |
CatBoost | 0.930±0.000 (152s) | 1.000±0.000 (9s) | 0.938±0.000 (164s) | 0.924±0.000 (29s) | 0.881±0.000 (27s) | 0.963±0.000 (48s) | 0.861±0.000 (12s) | 0.772±0.000 (10s) | 0.930±0.000 (91s) | 0.628±0.000 (10s) | 0.796±0.000 (15s) | 0.948±0.000 (46s) | 0.998±0.000 (38s) | 0.926±0.000 (115s) |
Trompt | 0.919±0.000 (9627s) | 1.000±0.000 (5341s) | 0.945±0.000 (14679s) | 0.942±0.001 (2752s) | 0.881±0.001 (2640s) | 0.964±0.001 (5173s) | 0.855±0.002 (4249s) | 0.778±0.002 (8789s) | 0.933±0.001 (9353s) | 0.686±0.008 (3105s) | 0.793±0.002 (8255s) | 0.952±0.001 (4876s) | 1.000±0.000 (3558s) | 0.916±0.001 (30002s) |
ResNet | 0.917±0.000 (615s) | 1.000±0.000 (71s) | 0.937±0.001 (787s) | 0.938±0.002 (230s) | 0.865±0.001 (183s) | 0.960±0.001 (349s) | 0.828±0.001 (248s) | 0.768±0.002 (205s) | 0.925±0.002 (958s) | 0.665±0.006 (140s) | 0.794±0.002 (76s) | 0.946±0.002 (145s) | 1.000±0.000 (93s) | 0.911±0.001 (880s) |
FTTransformerBucket | 0.915±0.001 (690s) | 0.999±0.001 (354s) | 0.936±0.002 (1705s) | 0.939±0.002 (484s) | 0.876±0.002 (321s) | 0.960±0.001 (746s) | 0.857±0.000 (549s) | 0.771±0.003 (654s) | 0.909±0.002 (1177s) | 0.636±0.012 (244s) | 0.788±0.002 (710s) | 0.950±0.001 (510s) | 0.999±0.000 (634s) | 0.913±0.001 (1164s) |
ExcelFormer | 0.918±0.001 (1587s) | 1.000±0.000 (634s) | 0.939±0.001 (1827s) | 0.939±0.002 (378s) | 0.878±0.003 (251s) | 0.969±0.000 (678s) | 0.833±0.011 (435s) | 0.780±0.002 (938s) | 0.921±0.005 (1131s) | 0.649±0.008 (519s) | 0.794±0.003 (683s) | 0.950±0.001 (405s) | 0.999±0.000 (1169s) | 0.919±0.001 (1798s) |
FTTransformer | 0.918±0.001 (871s) | 1.000±0.000 (571s) | 0.940±0.001 (1371s) | 0.936±0.001 (458s) | 0.874±0.002 (200s) | 0.959±0.001 (622s) | 0.828±0.001 (339s) | 0.773±0.002 (521s) | 0.909±0.002 (1488s) | 0.635±0.011 (392s) | 0.790±0.001 (556s) | 0.949±0.002 (374s) | 1.000±0.000 (713s) | 0.912±0.000 (1855s) |
TabNet | 0.911±0.001 (150s) | 1.000±0.000 (35s) | 0.931±0.005 (254s) | 0.937±0.003 (125s) | 0.864±0.002 (52s) | 0.944±0.001 (116s) | 0.828±0.001 (79s) | 0.771±0.005 (93s) | 0.913±0.005 (177s) | 0.606±0.014 (65s) | 0.790±0.003 (41s) | 0.936±0.003 (104s) | 1.000±0.000 (64s) | 0.910±0.001 (294s) |
TabTransformer | 0.910±0.001 (2044s) | 1.000±0.000 (1321s) | 0.928±0.001 (2519s) | 0.918±0.003 (134s) | 0.829±0.002 (64s) | 0.928±0.001 (105s) | 0.816±0.002 (99s) | 0.757±0.003 (645s) | 0.885±0.001 (1167s) | 0.652±0.006 (282s) | 0.780±0.002 (112s) | 0.937±0.001 (117s) | 0.996±0.000 (76s) | 0.905±0.001 (2283s) |
Benchmark on small-scale regression tasks
Metric: RMSE, lower is better.
Model | dataset_0 | dataset_1 | dataset_2 | dataset_3 | dataset_4 | dataset_5 | dataset_6 | dataset_7 | dataset_8 | dataset_9 | dataset_10 | dataset_11 | dataset_12 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
XGBoost | 0.247±0.000 (516s) | 0.077±0.000 (14s) | 0.167±0.000 (423s) | 1.119±0.000 (1063s) | 0.328±0.000 (2044s) | 1.024±0.000 (47s) | 0.292±0.000 (844s) | 0.606±0.000 (1765s) | 0.876±0.000 (2288s) | 0.023±0.000 (1170s) | 0.697±0.000 (248s) | 0.865±0.000 (8s) | 0.435±0.000 (22s) |
CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.128±0.000 (97s) | 0.336±0.000 (103s) | 0.346±0.000 (110s) | 0.443±0.000 (97s) | 0.375±0.000 (46s) | 0.273±0.000 (693s) | 0.881±0.000 (660s) | 0.040±0.000 (80s) | 0.756±0.000 (44s) | 0.876±0.000 (110s) | 0.439±0.000 (101s) |
Trompt | 0.261±0.003 (8390s) | 0.015±0.005 (3792s) | 0.118±0.001 (3836s) | 0.262±0.001 (10037s) | 0.323±0.001 (9255s) | 0.418±0.003 (9071s) | 0.329±0.009 (2977s) | 0.312±0.002 (21967s) | OOM | 0.008±0.001 (1889s) | 0.779±0.006 (775s) | 0.874±0.004 (3723s) | 0.424±0.005 (3185s) |
ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.124±0.001 (135s) | 0.268±0.001 (330s) | 0.335±0.001 (471s) | 0.434±0.004 (345s) | 0.325±0.012 (178s) | 0.324±0.004 (365s) | 0.895±0.005 (142s) | 0.036±0.002 (172s) | 0.794±0.006 (120s) | 0.875±0.004 (122s) | 0.468±0.004 (303s) |
FTTransformerBucket | 0.325±0.008 (619s) | 0.096±0.005 (290s) | 0.360±0.354 (332s) | 0.284±0.005 (768s) | 0.342±0.004 (757s) | 0.441±0.003 (835s) | 0.345±0.007 (191s) | 0.339±0.003 (3321s) | OOM | 0.105±0.011 (199s) | 0.807±0.010 (156s) | 0.885±0.008 (820s) | 0.468±0.006 (706s) |
ExcelFormer | 0.302±0.003 (703s) | 0.099±0.003 (490s) | 0.145±0.003 (587s) | 0.382±0.011 (504s) | 0.344±0.002 (1096s) | 0.411±0.005 (469s) | 0.359±0.016 (207s) | 0.336±0.008 (5522s) | OOM | 0.192±0.014 (317s) | 0.794±0.005 (189s) | 0.890±0.003 (1186s) | 0.445±0.005 (550s) |
FTTransformer | 0.335±0.010 (338s) | 0.161±0.022 (370s) | 0.140±0.002 (244s) | 0.277±0.004 (516s) | 0.335±0.003 (973s) | 0.445±0.003 (599s) | 0.361±0.018 (286s) | 0.345±0.005 (2443s) | OOM | 0.106±0.012 (150s) | 0.826±0.005 (121s) | 0.896±0.007 (832s) | 0.461±0.003 (647s) |
TabNet | 0.279±0.003 (68s) | 0.224±0.016 (53s) | 0.141±0.010 (34s) | 0.275±0.002 (61s) | 0.348±0.003 (110s) | 0.451±0.007 (82s) | 0.355±0.030 (49s) | 0.332±0.004 (168s) | 0.992±0.182 (53s) | 0.015±0.002 (57s) | 0.805±0.014 (27s) | 0.885±0.013 (46s) | 0.544±0.011 (112s) |
TabTransformer | 0.624±0.003 (1225s) | 0.229±0.003 (1200s) | 0.369±0.005 (52s) | 0.340±0.004 (163s) | 0.388±0.002 (1137s) | 0.539±0.003 (100s) | 0.619±0.005 (73s) | 0.351±0.001 (125s) | 0.893±0.005 (389s) | 0.431±0.001 (489s) | 0.819±0.002 (52s) | 0.886±0.005 (46s) | 0.545±0.004 (95s) |
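
When comparing rows in the regression table, note that cells marked `OOM` have no score and should be skipped rather than treated as zero. The following is a minimal sketch of picking the lowest-RMSE model per dataset from pipe-delimited rows. It uses only an excerpt of the table (four models, three columns), and `mean_of` and `best_by_rmse` are hypothetical helper names, not PyTorch Frame functions.

```python
from typing import Dict, List, Optional, Tuple

# Excerpt of the regression table: columns dataset_0, dataset_1, dataset_8.
ROWS = """\
XGBoost | 0.247±0.000 (516s) | 0.077±0.000 (14s) | 0.876±0.000 (2288s) |
CatBoost | 0.265±0.000 (116s) | 0.062±0.000 (129s) | 0.881±0.000 (660s) |
Trompt | 0.261±0.003 (8390s) | 0.015±0.005 (3792s) | OOM |
ResNet | 0.288±0.006 (220s) | 0.018±0.003 (187s) | 0.895±0.005 (142s) |
"""
COLUMNS = ["dataset_0", "dataset_1", "dataset_8"]


def mean_of(cell: str) -> Optional[float]:
    """Return the mean before '±', or None for cells without a score (e.g. 'OOM')."""
    if "±" not in cell:
        return None
    return float(cell.split("±", 1)[0])


def best_by_rmse(rows: str, columns: List[str]) -> Dict[str, str]:
    """Map each column name to the model with the lowest mean RMSE."""
    best: Dict[str, Tuple[float, str]] = {}
    for line in rows.strip().splitlines():
        # Drop the trailing pipe, then split into the model name and its cells.
        parts = [p.strip() for p in line.strip().rstrip("|").split("|")]
        model, cells = parts[0], parts[1:]
        for column, cell in zip(columns, cells):
            score = mean_of(cell)
            if score is None:
                continue  # skip cells such as OOM
            if column not in best or score < best[column][0]:
                best[column] = (score, model)
    return {column: model for column, (_, model) in best.items()}


if __name__ == "__main__":
    print(best_by_rmse(ROWS, COLUMNS))
```

The result only reflects the excerpted rows. This comparison also ignores the standard deviation and the time budget, both of which matter when means are close.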
Full Changelog
Full Changelog: 5b5525f...0.1.0