TransTabModel

class transtab.modeling_transtab.TransTabModel(categorical_columns=None, numerical_columns=None, binary_columns=None, feature_extractor=None, hidden_dim=128, num_layer=2, num_attention_head=8, hidden_dropout_prob=0.1, ffn_dim=256, activation='relu', device='cuda:0', **kwargs)[source]

The base TransTab model for downstream tasks such as contrastive learning and binary classification. All other models subclass this base model and usually override the forward function. Refer to the source code of transtab.modeling_transtab.TransTabClassifier or transtab.modeling_transtab.TransTabForCL for implementation details.

Parameters
  • categorical_columns (list) – a list of categorical feature names.

  • numerical_columns (list) – a list of numerical feature names.

  • binary_columns (list) – a list of binary feature names; accepts binary indicators such as (yes, no), (true, false), and (0, 1).

  • feature_extractor (TransTabFeatureExtractor) – a feature extractor that tokenizes the input tables. If not passed, the model builds one itself.

  • hidden_dim (int) – the dimension of hidden embeddings.

  • num_layer (int) – the number of transformer layers used in the encoder.

  • num_attention_head (int) – the number of heads in the multi-head self-attention layer of each transformer layer.

  • hidden_dropout_prob (float) – the dropout ratio in the transformer encoder.

  • ffn_dim (int) – the dimension of feed-forward layer in the transformer layer.

  • activation (str) – the name of the activation function; supports "relu", "gelu", "selu", and "leakyrelu".

  • device (str) – the device, "cpu" or "cuda:0".

Returns

A TransTabModel model.

Return type

TransTabModel
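As a rough illustration of how the constructor arguments fit together, the sketch below collects them into a keyword dictionary and checks two constraints before a model is built. The helper build_model_kwargs is hypothetical and not part of transtab, the column names are made up, and the divisibility check assumes the encoder uses standard multi-head attention, where the hidden size is split evenly across heads.

```python
SUPPORTED_ACTIVATIONS = {"relu", "gelu", "selu", "leakyrelu"}  # as listed above


def build_model_kwargs(categorical_columns=None, numerical_columns=None,
                       binary_columns=None, hidden_dim=128,
                       num_attention_head=8, activation="relu",
                       device="cpu"):
    """Hypothetical helper: validate arguments and return constructor kwargs."""
    if activation not in SUPPORTED_ACTIVATIONS:
        raise ValueError(f"unsupported activation: {activation!r}")
    # Assumption: hidden_dim is split evenly across the attention heads.
    if hidden_dim % num_attention_head != 0:
        raise ValueError("hidden_dim must be divisible by num_attention_head")
    return {
        "categorical_columns": categorical_columns,
        "numerical_columns": numerical_columns,
        "binary_columns": binary_columns,
        "hidden_dim": hidden_dim,
        "num_attention_head": num_attention_head,
        "activation": activation,
        "device": device,
    }


# Illustrative column names only.
kwargs = build_model_kwargs(
    categorical_columns=["occupation"],
    numerical_columns=["age", "income"],
    binary_columns=["is_member"],
)
# With transtab installed, the dictionary would be passed as:
#   from transtab.modeling_transtab import TransTabModel
#   model = TransTabModel(**kwargs)
```

Validating before construction keeps configuration mistakes separate from errors raised while the model itself is being built.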

forward(x, y=None)[source]

Extract the embeddings based on input tables.

Parameters
  • x (pd.DataFrame) – a batch of samples stored in pd.DataFrame.

  • y (pd.Series) – the corresponding labels for each sample in x. Ignored by the base model.

Returns

final_cls_embedding – the [CLS] embedding at the output of the transformer encoder.

Return type

torch.Tensor
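To make the return value more concrete: assuming the [CLS] token sits at position 0 of the encoder output (an assumption, not stated above), selecting the final [CLS] embedding amounts to taking the first vector of each sample's token sequence. The sketch uses nested Python lists in place of a torch.Tensor so that it runs without PyTorch; select_cls_embedding is a hypothetical name, not a transtab function.

```python
def select_cls_embedding(encoder_output):
    """Return the first token vector of every sample.

    encoder_output: nested list shaped [batch][sequence][hidden].
    Assumption: index 0 along the sequence axis holds the [CLS] token.
    """
    return [sample[0] for sample in encoder_output]


# Two samples, three tokens each, hidden size 2 (toy values).
encoder_output = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
]
cls_embedding = select_cls_embedding(encoder_output)
# cls_embedding holds one vector per sample: shape [batch][hidden].
```

With a real tensor, the equivalent indexing would be encoder_output[:, 0, :].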

load(ckpt_dir)[source]

Load the model state_dict and feature_extractor configuration from the ckpt_dir.

Parameters

ckpt_dir (str) – the directory path to load.

Return type

None

save(ckpt_dir)[source]

Save the model state_dict and feature_extractor configuration to the ckpt_dir.

Parameters

ckpt_dir (str) – the directory path to save.

Return type

None
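Since save and load take the same ckpt_dir argument, a checkpoint round trip might be sketched as follows. StubModel is a stand-in that only mimics the documented save(ckpt_dir) and load(ckpt_dir) signatures so the example runs without transtab; its state.txt file is an invention of the stub and says nothing about the files transtab actually writes. The helper save_then_reload is likewise hypothetical.

```python
import os
import tempfile


def save_then_reload(model, ckpt_dir):
    """Save `model` to ckpt_dir, then load the checkpoint back into it."""
    os.makedirs(ckpt_dir, exist_ok=True)  # harmless if save() also creates it
    model.save(ckpt_dir)
    model.load(ckpt_dir)
    return model


class StubModel:
    """Stand-in exposing only the documented save/load signatures."""

    def __init__(self, value):
        self.value = value

    def save(self, ckpt_dir):
        # The file name is specific to this stub, not to transtab.
        with open(os.path.join(ckpt_dir, "state.txt"), "w") as f:
            f.write(str(self.value))

    def load(self, ckpt_dir):
        with open(os.path.join(ckpt_dir, "state.txt")) as f:
            self.value = int(f.read())


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "checkpoint")
    model = save_then_reload(StubModel(7), ckpt_dir)
    # A fresh object restores the same state from the same directory.
    restored = StubModel(0)
    restored.load(ckpt_dir)
```

Any object exposing the same two methods, including a real TransTabModel, could be passed to save_then_reload in place of the stub.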

update(config)[source]

Update the feature extractor's column map for categorical, numerical, and binary columns, or update the number of classes for the output classifier layer.

Parameters

config (dict) – a dict of configurations. The keys cat (list), num (list), and bin (list) specify the new column names; the key num_class (int) specifies the number of classes for fine-tuning on a new dataset.

Return type

None
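A config dictionary for update could be assembled as in the sketch below. The key names cat, num, bin, and num_class come from the description above; the helper make_update_config and the column names are hypothetical. Leaving out the keys that were not supplied is a choice made for this sketch, not documented behavior.

```python
def make_update_config(cat_cols=None, num_cols=None, bin_cols=None,
                       num_class=None):
    """Hypothetical helper: build a config dict with only the provided keys."""
    config = {}
    if cat_cols is not None:
        config["cat"] = list(cat_cols)
    if num_cols is not None:
        config["num"] = list(num_cols)
    if bin_cols is not None:
        config["bin"] = list(bin_cols)
    if num_class is not None:
        config["num_class"] = int(num_class)
    return config


# New dataset with one categorical column, one numerical column, 3 classes.
config = make_update_config(cat_cols=["city"], num_cols=["age"], num_class=3)
# With a model instance available, the call would be:
#   model.update(config)
```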