build_classifier
- transtab.build_classifier(categorical_columns=None, numerical_columns=None, binary_columns=None, feature_extractor=None, num_class=2, hidden_dim=128, num_layer=2, num_attention_head=8, hidden_dropout_prob=0, ffn_dim=256, activation='relu', device='cuda:0', checkpoint=None, **kwargs) → transtab.modeling_transtab.TransTabClassifier
Build a transtab.modeling_transtab.TransTabClassifier.
- Parameters
categorical_columns (list) – a list of categorical feature names.
numerical_columns (list) – a list of numerical feature names.
binary_columns (list) – a list of binary feature names; accepts binary indicators such as (yes, no), (true, false), and (0, 1).
feature_extractor (TransTabFeatureExtractor) – a feature extractor that tokenizes the input tables. If not passed, the model builds its own.
num_class (int) – number of output classes to be predicted.
hidden_dim (int) – the dimension of hidden embeddings.
num_layer (int) – the number of transformer layers used in the encoder.
num_attention_head (int) – the number of heads in the multi-head self-attention layer of each transformer layer.
hidden_dropout_prob (float) – the dropout ratio in the transformer encoder.
ffn_dim (int) – the dimension of the feed-forward layer in each transformer layer.
activation (str) – the name of the activation function; supports "relu", "gelu", "selu", and "leakyrelu".
device (str) – the device to run the model on, e.g., "cpu" or "cuda:0".
checkpoint (str) – the directory from which to load a pretrained TransTab model.
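The parameter descriptions above imply a few constraints that can be checked before building the model. The sketch below is a hypothetical helper, `check_classifier_args`, that is not part of transtab. Its divisibility check assumes the encoder uses standard multi-head attention, which splits `hidden_dim` evenly across heads (with the defaults, 128 / 8 = 16 dimensions per head). The device check is likewise an assumption that only "cpu" and "cuda…" strings are intended.

```python
# Hypothetical pre-flight check mirroring the documented parameters of
# transtab.build_classifier; this helper is NOT part of the transtab API.
SUPPORTED_ACTIVATIONS = {"relu", "gelu", "selu", "leakyrelu"}


def check_classifier_args(hidden_dim=128, num_attention_head=8,
                          activation="relu", device="cuda:0",
                          hidden_dropout_prob=0.0):
    """Return a list of problems; an empty list means the arguments look valid."""
    problems = []
    # Only the activations listed in the parameter description are accepted.
    if activation not in SUPPORTED_ACTIVATIONS:
        problems.append(f"unsupported activation: {activation!r}")
    # Assumption: standard multi-head attention needs hidden_dim % heads == 0.
    if hidden_dim % num_attention_head != 0:
        problems.append(
            f"hidden_dim={hidden_dim} is not divisible by "
            f"num_attention_head={num_attention_head}"
        )
    # Assumption: device strings look like "cpu" or "cuda:<index>".
    if not (device == "cpu" or device.startswith("cuda")):
        problems.append(f"unexpected device string: {device!r}")
    # A dropout probability must lie in [0, 1].
    if not 0.0 <= hidden_dropout_prob <= 1.0:
        problems.append(f"hidden_dropout_prob out of range: {hidden_dropout_prob}")
    return problems


if __name__ == "__main__":
    print(check_classifier_args())                   # documented defaults
    print(check_classifier_args(activation="tanh"))  # unsupported activation
```

With the documented defaults the helper returns an empty list, while `activation="tanh"` is reported as unsupported.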
- Returns
A TransTabClassifier model.
- Return type
transtab.modeling_transtab.TransTabClassifier
Warning
If none of categorical_columns, numerical_columns, or binary_columns is specified, the model treats ALL columns as categorical, which may significantly degrade performance.
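To avoid this fallback, the three column lists can be derived explicitly before calling build_classifier. The sketch below is a hypothetical helper, `infer_column_types`, that is not part of transtab and may not match how transtab itself interprets columns. It assumes the table is a list of dicts, and its heuristic is one possible choice: a column is binary if its lowercased values fall within one of the documented indicator pairs, numerical if every value is an int or float, and categorical otherwise.

```python
# Hypothetical helper (not part of transtab) that splits columns into the
# three lists accepted by transtab.build_classifier.
BINARY_PAIRS = [{"yes", "no"}, {"true", "false"}, {"0", "1"}]


def infer_column_types(rows):
    """Return (categorical, numerical, binary) column-name lists for a list of dicts."""
    columns = list(rows[0].keys()) if rows else []
    categorical, numerical, binary = [], [], []
    for col in columns:
        values = [row[col] for row in rows]
        lowered = {str(v).strip().lower() for v in values}
        # Check binary first so that 0/1 columns count as binary indicators.
        if any(lowered <= pair for pair in BINARY_PAIRS):
            binary.append(col)
        # bool is a subclass of int, so exclude it from the numeric test.
        elif all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
            numerical.append(col)
        else:
            categorical.append(col)
    return categorical, numerical, binary


if __name__ == "__main__":
    toy_rows = [  # illustrative toy rows, not real data
        {"age": 39, "workclass": "State-gov", "smoker": "yes", "income": 77516.0},
        {"age": 50, "workclass": "Private", "smoker": "no", "income": 83311.0},
    ]
    cat, num, bin_ = infer_column_types(toy_rows)
    print(cat, num, bin_)
    # The lists would then be passed explicitly, e.g.:
    # transtab.build_classifier(categorical_columns=cat,
    #                           numerical_columns=num, binary_columns=bin_)
```

Checking binary before numeric is a deliberate choice: an integer column holding only 0 and 1 is treated as a binary indicator rather than a number.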