TransTabModel
- class transtab.modeling_transtab.TransTabModel(categorical_columns=None, numerical_columns=None, binary_columns=None, feature_extractor=None, hidden_dim=128, num_layer=2, num_attention_head=8, hidden_dropout_prob=0.1, ffn_dim=256, activation='relu', device='cuda:0', **kwargs)
The base transtab model for downstream tasks such as contrastive learning and binary classification. All task models subclass this base model and usually override the forward function. Refer to the source code of transtab.modeling_transtab.TransTabClassifier or transtab.modeling_transtab.TransTabForCL for implementation details.
- Parameters
categorical_columns (list) – a list of categorical feature names.
numerical_columns (list) – a list of numerical feature names.
binary_columns (list) – a list of binary feature names; accepts binary indicators such as (yes, no), (true, false), and (0, 1).
feature_extractor (TransTabFeatureExtractor) – a feature extractor that tokenizes the input tables. If not passed, the model builds its own.
hidden_dim (int) – the dimension of hidden embeddings.
num_layer (int) – the number of transformer layers used in the encoder.
num_attention_head (int) – the number of heads in the multi-head self-attention layer of each transformer layer.
hidden_dropout_prob (float) – the dropout ratio in the transformer encoder.
ffn_dim (int) – the dimension of the feed-forward layer in each transformer layer.
activation (str) – the name of the activation function; supports "relu", "gelu", "selu", and "leakyrelu".
device (str) – the device, e.g. "cpu" or "cuda:0".
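Before constructing the model, it can help to sanity-check the arguments. The sketch below is hypothetical and not part of transtab: check_model_args, is_binary_column, and the constant names are illustrative only. It assumes the encoder uses standard multi-head attention, which splits hidden_dim evenly across the heads (so the defaults give 128 / 8 = 16 dimensions per head), and it treats a column as binary when its distinct values fall within one of the indicator pairs listed above.

```python
# Hypothetical argument checks; none of these names come from transtab.
SUPPORTED_ACTIVATIONS = {"relu", "gelu", "selu", "leakyrelu"}

# Value pairs the documentation lists as accepted binary indicators.
BINARY_PAIRS = [{"yes", "no"}, {"true", "false"}, {"0", "1"}]


def check_model_args(hidden_dim=128, num_attention_head=8,
                     activation="relu", device="cuda:0"):
    """Raise ValueError on inconsistent arguments; return the per-head dimension."""
    if activation not in SUPPORTED_ACTIVATIONS:
        raise ValueError(f"unsupported activation: {activation!r}")
    # Assumption: multi-head attention splits hidden_dim evenly across the heads.
    if hidden_dim % num_attention_head != 0:
        raise ValueError("hidden_dim must be divisible by num_attention_head")
    if device != "cpu" and not device.startswith("cuda"):
        raise ValueError(f"unexpected device string: {device!r}")
    return hidden_dim // num_attention_head


def is_binary_column(values):
    """True if the distinct values all fall within one accepted indicator pair."""
    # Normalize so that "Yes", True and 1 compare as "yes", "true" and "1".
    distinct = {str(v).strip().lower() for v in values}
    return bool(distinct) and any(distinct <= pair for pair in BINARY_PAIRS)


# Usage
print(check_model_args())                  # 16
print(is_binary_column(["Yes", "no"]))     # True
print(is_binary_column([0, 1, 1]))         # True
print(is_binary_column(["red", "blue"]))   # False
```

Normalizing to lowercase strings is a design choice of this sketch so that mixed spellings of the same indicator are grouped together.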
- Return type
A TransTabModel model.
- forward(x, y=None)
Extract the embeddings based on input tables.
- Parameters
x (pd.DataFrame) – a batch of samples stored in pd.DataFrame.
y (pd.Series) – the corresponding labels for each sample in x. Ignored by the base model.
- Returns
final_cls_embedding – the [CLS] embedding at the output of the transformer encoder.
- Return type
torch.Tensor
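To make the subclassing pattern and the [CLS] output concrete, here is a toy sketch in plain Python; it is not transtab's implementation. ToyBaseModel and ToyClassifier are hypothetical, the "encoder" is a mean over positions standing in for the transformer, and the classifier head has fixed weights. It shows only the idea: prepend a [CLS] vector to the token embeddings, encode the sequence, return position 0, and let a subclass override forward to map that vector to class logits.

```python
class ToyBaseModel:
    """Hypothetical stand-in for a base model whose forward returns [CLS]."""

    def __init__(self, hidden_dim=4):
        self.hidden_dim = hidden_dim
        # Learnable in a real model; a fixed zero vector here for clarity.
        self.cls_token = [0.0] * hidden_dim

    def encode(self, sequence):
        # Toy encoder: every position becomes the mean over all positions,
        # so the [CLS] slot ends up summarizing the whole row.
        n = len(sequence)
        mean = [sum(col) / n for col in zip(*sequence)]
        return [list(mean) for _ in range(n)]

    def forward(self, token_embeddings, y=None):
        # y is ignored, mirroring the base model described above.
        sequence = [self.cls_token] + token_embeddings
        return self.encode(sequence)[0]  # the [CLS] position


class ToyClassifier(ToyBaseModel):
    """Hypothetical subclass that overrides forward to add a linear head."""

    def __init__(self, hidden_dim=4, num_class=2):
        super().__init__(hidden_dim)
        # One weight row per class; all ones to keep the arithmetic obvious.
        self.weights = [[1.0] * hidden_dim for _ in range(num_class)]

    def forward(self, token_embeddings, y=None):
        cls = super().forward(token_embeddings)
        # Logit for each class is the dot product of its row with [CLS].
        return [sum(w * h for w, h in zip(row, cls)) for row in self.weights]


tokens = [[2.0] * 4, [4.0] * 4]
print(ToyBaseModel().forward(tokens))   # [2.0, 2.0, 2.0, 2.0]
print(ToyClassifier().forward(tokens))  # [8.0, 8.0]
```

Here the [CLS] vector is the mean of [0, 2, 4] in every dimension, i.e. 2.0, and each logit sums four products of 1.0 and 2.0.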
- load(ckpt_dir)
Load the model state_dict and feature_extractor configuration from ckpt_dir.
- Parameters
ckpt_dir (str) – the directory path to load.
- Return type
None
- save(ckpt_dir)
Save the model state_dict and feature_extractor configuration to ckpt_dir.
- Parameters
ckpt_dir (str) – the directory path to save.
- Return type
None
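The save and load methods can be pictured as a round trip through ckpt_dir. The sketch below covers only the feature-extractor configuration, stored as JSON; the file name extractor_config.json and both helper functions are hypothetical and need not match transtab's on-disk layout. A PyTorch model would typically persist its weights alongside it, for example with torch.save(model.state_dict(), path) and model.load_state_dict(torch.load(path)).

```python
import json
import os
import tempfile

# Hypothetical file name; transtab's real checkpoint layout may differ.
CONFIG_NAME = "extractor_config.json"


def save_extractor_config(ckpt_dir, config):
    """Write the column-map configuration into ckpt_dir as JSON."""
    os.makedirs(ckpt_dir, exist_ok=True)  # create the directory if needed
    with open(os.path.join(ckpt_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)


def load_extractor_config(ckpt_dir):
    """Read back the configuration written by save_extractor_config."""
    with open(os.path.join(ckpt_dir, CONFIG_NAME), encoding="utf-8") as f:
        return json.load(f)


# Usage: round-trip a configuration through a temporary checkpoint directory.
cfg = {"cat": ["job"], "num": ["age"], "bin": ["owns_car"]}
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "ckpt")
    save_extractor_config(ckpt, cfg)
    print(load_extractor_config(ckpt) == cfg)  # True
```

Keeping the configuration as plain JSON means it can be inspected or edited without loading any model weights.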
- update(config)
Update the feature extractor's column map for categorical, numerical, and binary columns, or update the number of classes of the output classifier layer.
- Parameters
config (dict) – a dict of configurations: the keys cat (list), num (list), and bin (list) specify the new column names; the key num_class (int) specifies the number of classes when finetuning on a new dataset.
- Return type
None
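The following sketch shows what applying such a config might look like, using a plain dict as a stand-in for the model state. apply_update is hypothetical. The description above does not say whether transtab replaces the existing column lists or adds to them; this sketch assumes new names are appended to the existing lists without duplicates, while num_class simply replaces the old value.

```python
def apply_update(state, config):
    """Return a new state dict with config applied; state is not mutated.

    Hypothetical semantics: "cat", "num" and "bin" names are appended to the
    existing lists (duplicates skipped); "num_class" replaces the old value.
    """
    allowed = {"cat", "num", "bin", "num_class"}
    unknown = set(config) - allowed
    if unknown:
        raise KeyError(f"unsupported keys: {sorted(unknown)}")

    new_state = dict(state)  # shallow copy; updated lists are rebuilt below
    for key in ("cat", "num", "bin"):
        if key in config:
            merged = list(state.get(key, []))
            for col in config[key]:
                if col not in merged:
                    merged.append(col)
            new_state[key] = merged
    if "num_class" in config:
        new_state["num_class"] = int(config["num_class"])
    return new_state


state = {"cat": ["job"], "num": ["age"], "bin": [], "num_class": 2}
updated = apply_update(state, {"num": ["income", "age"], "num_class": 3})
print(updated["num"])        # ['age', 'income']
print(updated["num_class"])  # 3
print(state["num"])          # ['age'] (original left unchanged)
```

Returning a new dict rather than mutating in place is a choice made for this sketch so that the before and after states are easy to compare.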