build_extractor

transtab.build_extractor(categorical_columns=None, numerical_columns=None, binary_columns=None, ignore_duplicate_cols=False, disable_tokenizer_parallel=False, checkpoint=None, **kwargs) → transtab.modeling_transtab.TransTabFeatureExtractor

Build a feature extractor for the TransTab model.

Parameters
  • categorical_columns (list) – a list of categorical feature names.

  • numerical_columns (list) – a list of numerical feature names.

  • binary_columns (list) – a list of binary feature names. Accepts binary indicators such as (yes, no), (true, false), and (0, 1).

  • ignore_duplicate_cols (bool) – if a column is assigned to more than one type (e.g., the feature age is listed as both a categorical and a binary column), the model raises an error. Set this to True to suppress the error; the duplicated feature is then ignored.

  • disable_tokenizer_parallel (bool) – if the returned feature extractor is used inside the collate function of a DataLoader, set this to True when the DataLoader raises errors. Such errors occur because the DataLoader spawns multiple worker processes while the tokenizer also runs its own parallel workers.

  • checkpoint (str) – the directory of a previously saved TransTabFeatureExtractor to load.
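The check controlled by ignore_duplicate_cols can be sketched in plain Python. The helper below, resolve_column_types, is hypothetical (it is not part of transtab); it only mirrors the documented behavior: raise an error when a column name appears under more than one type, or drop that column from every list when the flag is set. TransTab's internal handling of the duplicate may differ.

```python
from collections import Counter
from typing import List, Optional, Tuple


def resolve_column_types(
    categorical: Optional[List[str]] = None,
    numerical: Optional[List[str]] = None,
    binary: Optional[List[str]] = None,
    ignore_duplicate_cols: bool = False,
) -> Tuple[List[str], List[str], List[str]]:
    """Hypothetical helper: reject or drop columns assigned to more than one type."""
    groups = [list(categorical or []), list(numerical or []), list(binary or [])]
    # Count how many *types* each column name is assigned to; set() ignores
    # a name that is merely repeated within the same list.
    counts = Counter(col for cols in groups for col in set(cols))
    duplicates = sorted(col for col, n in counts.items() if n > 1)
    if duplicates and not ignore_duplicate_cols:
        raise ValueError(f"columns assigned to more than one type: {duplicates}")
    # With ignore_duplicate_cols=True, drop each conflicting column from all lists.
    dup = set(duplicates)
    cat, num, bin_ = ([c for c in cols if c not in dup] for cols in groups)
    return cat, num, bin_


print(resolve_column_types(["gender", "age"], ["income"], ["age", "smoker"],
                           ignore_duplicate_cols=True))
# (['gender'], ['income'], ['smoker'])
```

Dropping the column everywhere, rather than keeping it under one of its types, is a design choice of this sketch that follows the wording "the model will ignore this duplicate feature".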

Returns

  A TransTabFeatureExtractor module.

Return type

  transtab.modeling_transtab.TransTabFeatureExtractor

The returned feature extractor takes a pd.DataFrame as input and returns the encoded features as a dict.

import pandas as pd
import transtab

# build the feature extractor
extractor = transtab.build_extractor(categorical_columns=['gender'], numerical_columns=['age'])

# build a table for inputs
df = pd.DataFrame({'age':[1,2], 'gender':['male','female']})

# extract the outputs
outputs = extractor(df)

print(outputs)

'''
    {
    'x_num': tensor([[1.],[2.]], dtype=torch.float64),
    'num_col_input_ids': tensor([[2287]]),
    'x_cat_input_ids': tensor([[5907, 3287], [5907, 2931]]),
    'x_bin_input_ids': None,
    'num_att_mask': tensor([[1]]),
    'cat_att_mask': tensor([[1, 1], [1, 1]])
    }
'''
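To make the layout of the returned dict easier to read, here is a minimal pure-Python sketch that reproduces its structure for the example above. It assumes a hypothetical toy vocabulary and whitespace tokenizer (VOCAB, tokenize, and extract are illustrative names, not transtab APIs), so the ids differ from the real tokenizer's output. It assumes, consistent with the printed example, that numerical values stay numeric while the numerical column names are tokenized once, and that each row's categorical features become one "column value" string that is tokenized and padded. Binary columns are not modeled.

```python
from typing import Any, Dict, List, Optional

# Hypothetical toy vocabulary; transtab uses a real subword tokenizer instead.
VOCAB = {"[PAD]": 0, "age": 1, "gender": 2, "male": 3, "female": 4}


def tokenize(texts: List[str]) -> Dict[str, List[List[int]]]:
    """Whitespace-split each text, map words to ids, right-pad to equal length."""
    ids = [[VOCAB[w] for w in t.lower().split()] for t in texts]
    width = max(len(x) for x in ids)
    return {
        "input_ids": [x + [VOCAB["[PAD]"]] * (width - len(x)) for x in ids],
        "attention_mask": [[1] * len(x) + [0] * (width - len(x)) for x in ids],
    }


def extract(rows: List[Dict[str, Any]], categorical: List[str],
            numerical: List[str]) -> Dict[str, Optional[list]]:
    """Hypothetical extractor that mimics the key layout of the printed dict."""
    out: Dict[str, Optional[list]] = dict.fromkeys(
        ["x_num", "num_col_input_ids", "num_att_mask",
         "x_cat_input_ids", "cat_att_mask", "x_bin_input_ids"])
    if numerical:
        # Values stay numeric: one row per record, one column per feature.
        out["x_num"] = [[float(r[c]) for c in numerical] for r in rows]
        # Column *names* are tokenized once: one row per numerical column.
        enc = tokenize(numerical)
        out["num_col_input_ids"] = enc["input_ids"]
        out["num_att_mask"] = enc["attention_mask"]
    if categorical:
        # Each record becomes one "column value ..." string, then is tokenized.
        texts = [" ".join(f"{c} {r[c]}" for c in categorical) for r in rows]
        enc = tokenize(texts)
        out["x_cat_input_ids"] = enc["input_ids"]
        out["cat_att_mask"] = enc["attention_mask"]
    # No binary columns are handled here, so x_bin_input_ids stays None.
    return out


rows = [{"age": 1, "gender": "male"}, {"age": 2, "gender": "female"}]
print(extract(rows, categorical=["gender"], numerical=["age"]))
```

As in the printed output, num_col_input_ids has one row (for the single numerical column), while x_cat_input_ids has one row per record with two tokens each (the column name followed by the value).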