import os
import math
import collections
import json
from typing import Dict, Optional, Any, Union, Callable, List
from loguru import logger
from transformers import BertTokenizer, BertTokenizerFast
import torch
from torch import nn
from torch import Tensor
import torch.nn.init as nn_init
import torch.nn.functional as F
import numpy as np
import pandas as pd
from . import constants
class TransTabWordEmbedding(nn.Module):
r'''
Encode tokens drawn from column names, categorical and binary features.
'''
def __init__(self,
vocab_size,
hidden_dim,
padding_idx=0,
hidden_dropout_prob=0,
layer_norm_eps=1e-5,
) -> None:
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_dim, padding_idx)
nn_init.kaiming_normal_(self.word_embeddings.weight)
self.norm = nn.LayerNorm(hidden_dim, eps=layer_norm_eps)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, input_ids) -> Tensor:
embeddings = self.word_embeddings(input_ids)
embeddings = self.norm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class TransTabNumEmbedding(nn.Module):
r'''
Encode tokens drawn from column names and the corresponding numerical features.
'''
def __init__(self, hidden_dim) -> None:
super().__init__()
self.norm = nn.LayerNorm(hidden_dim)
self.num_bias = nn.Parameter(Tensor(1, 1, hidden_dim)) # add bias
nn_init.uniform_(self.num_bias, a=-1/math.sqrt(hidden_dim), b=1/math.sqrt(hidden_dim))
def forward(self, num_col_emb, x_num_ts, num_mask=None) -> Tensor:
'''args:
num_col_emb: numerical column embedding, (# numerical columns, emb_dim)
x_num_ts: numerical features, (bs, # numerical columns)
num_mask: the mask for NaN numerical features, (bs, # numerical columns)
'''
num_col_emb = num_col_emb.unsqueeze(0).expand((x_num_ts.shape[0],-1,-1))
num_feat_emb = num_col_emb * x_num_ts.unsqueeze(-1).float() + self.num_bias
return num_feat_emb
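# Shape sketch for TransTabNumEmbedding (illustrative only): each numerical cell value scales its
# column-name embedding, then a shared learnable bias is added.
#
#   num_col_emb: (# numerical columns, hidden_dim) -> broadcast to (bs, # numerical columns, hidden_dim)
#   x_num_ts:    (bs, # numerical columns)         -> unsqueezed to (bs, # numerical columns, 1)
#   output:      num_col_emb * x_num_ts + num_bias, shape (bs, # numerical columns, hidden_dim)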
class TransTabFeatureExtractor:
r'''
Process an input dataframe into token indices for the transtab encoder,
usually used inside the collate function of a DataLoader for parallel loading.
'''
def __init__(self,
categorical_columns=None,
numerical_columns=None,
binary_columns=None,
disable_tokenizer_parallel=False,
ignore_duplicate_cols=False,
**kwargs,
) -> None:
'''args:
categorical_columns: a list of categorical feature names
numerical_columns: a list of numerical feature names
binary_columns: a list of yes/no feature names, accepts binary indicators like
(yes,no); (true,false); (0,1).
disable_tokenizer_parallel: set True when the extractor is used inside the collate function of a torch DataLoader
ignore_duplicate_cols: check whether one column is assigned to more than one of cat/num/bin;
if set `True`, the duplicate columns are handled automatically (effectively ignored during training), else an error is thrown.
'''
if os.path.exists('./transtab/tokenizer'):
self.tokenizer = BertTokenizerFast.from_pretrained('./transtab/tokenizer')
else:
self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
self.tokenizer.save_pretrained('./transtab/tokenizer')
self.tokenizer.model_max_length = 512
if disable_tokenizer_parallel: # disable tokenizer parallel
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.vocab_size = self.tokenizer.vocab_size
self.pad_token_id = self.tokenizer.pad_token_id
self.categorical_columns = categorical_columns
self.numerical_columns = numerical_columns
self.binary_columns = binary_columns
self.ignore_duplicate_cols = ignore_duplicate_cols
if categorical_columns is not None:
self.categorical_columns = list(set(categorical_columns))
if numerical_columns is not None:
self.numerical_columns = list(set(numerical_columns))
if binary_columns is not None:
self.binary_columns = list(set(binary_columns))
# check if column exists overlap
col_no_overlap, duplicate_cols = self._check_column_overlap(self.categorical_columns, self.numerical_columns, self.binary_columns)
if not self.ignore_duplicate_cols:
for col in duplicate_cols:
logger.error(f'Find duplicate cols named `{col}`, please process the raw data or set `ignore_duplicate_cols` to True!')
assert col_no_overlap, 'The assigned categorical_columns, numerical_columns, binary_columns should not have overlap! Please check your input.'
else:
self._solve_duplicate_cols(duplicate_cols)
def __call__(self, x, shuffle=False) -> Dict:
'''
Parameters
----------
x: pd.DataFrame
with column names and features.
shuffle: bool
whether to shuffle the column order during training.
Returns
-------
encoded_inputs: a dict with {
'x_num': tensor contains numerical features,
'num_col_input_ids': tensor contains numerical column tokenized ids,
'x_cat_input_ids': tensor contains categorical column + feature ids,
'x_bin_input_ids': tensor contains binary column + feature ids,
}
'''
encoded_inputs = {
'x_num':None,
'num_col_input_ids':None,
'x_cat_input_ids':None,
'x_bin_input_ids':None,
}
col_names = x.columns.tolist()
cat_cols = [c for c in col_names if c in self.categorical_columns] if self.categorical_columns is not None else []
num_cols = [c for c in col_names if c in self.numerical_columns] if self.numerical_columns is not None else []
bin_cols = [c for c in col_names if c in self.binary_columns] if self.binary_columns is not None else []
if len(cat_cols+num_cols+bin_cols) == 0:
# take all columns as categorical columns!
cat_cols = col_names
if shuffle:
np.random.shuffle(cat_cols)
np.random.shuffle(num_cols)
np.random.shuffle(bin_cols)
# TODO:
# mask out NaN values like done in binary columns
if len(num_cols) > 0:
x_num = x[num_cols]
x_num = x_num.fillna(0) # fill Nan with zero
x_num_ts = torch.tensor(x_num.values, dtype=float)
num_col_ts = self.tokenizer(num_cols, padding=True, truncation=True, add_special_tokens=False, return_tensors='pt')
encoded_inputs['x_num'] = x_num_ts
encoded_inputs['num_col_input_ids'] = num_col_ts['input_ids']
encoded_inputs['num_att_mask'] = num_col_ts['attention_mask'] # mask out attention
if len(cat_cols) > 0:
x_cat = x[cat_cols]
x_mask = (~pd.isna(x_cat)).astype(int) # build the NaN mask before casting values to strings
x_cat = x_cat.fillna('').astype(str)
x_cat = x_cat.apply(lambda col: col.name + ' ' + col) * x_mask # mask out NaN features
x_cat_str = x_cat.agg(' '.join, axis=1).values.tolist()
x_cat_ts = self.tokenizer(x_cat_str, padding=True, truncation=True, add_special_tokens=False, return_tensors='pt')
encoded_inputs['x_cat_input_ids'] = x_cat_ts['input_ids']
encoded_inputs['cat_att_mask'] = x_cat_ts['attention_mask']
if len(bin_cols) > 0:
x_bin = x[bin_cols] # x_bin should already be integral (binary values in 0 & 1)
x_bin_str = x_bin.apply(lambda x: x.name + ' ') * x_bin
x_bin_str = x_bin_str.agg(' '.join, axis=1).values.tolist()
x_bin_ts = self.tokenizer(x_bin_str, padding=True, truncation=True, add_special_tokens=False, return_tensors='pt')
if x_bin_ts['input_ids'].shape[1] > 0: # not all false
encoded_inputs['x_bin_input_ids'] = x_bin_ts['input_ids']
encoded_inputs['bin_att_mask'] = x_bin_ts['attention_mask']
return encoded_inputs
def save(self, path):
'''save the feature extractor configuration to local dir.
'''
save_path = os.path.join(path, constants.EXTRACTOR_STATE_DIR)
if not os.path.exists(save_path):
os.makedirs(save_path)
# save tokenizer
tokenizer_path = os.path.join(save_path, constants.TOKENIZER_DIR)
self.tokenizer.save_pretrained(tokenizer_path)
# save other configurations
coltype_path = os.path.join(save_path, constants.EXTRACTOR_STATE_NAME)
col_type_dict = {
'categorical': self.categorical_columns,
'binary': self.binary_columns,
'numerical': self.numerical_columns,
}
with open(coltype_path, 'w', encoding='utf-8') as f:
f.write(json.dumps(col_type_dict))
def load(self, path):
'''load the feature extractor configuration from local dir.
'''
tokenizer_path = os.path.join(path, constants.TOKENIZER_DIR)
coltype_path = os.path.join(path, constants.EXTRACTOR_STATE_NAME)
self.tokenizer = BertTokenizerFast.from_pretrained(tokenizer_path)
with open(coltype_path, 'r', encoding='utf-8') as f:
col_type_dict = json.loads(f.read())
self.categorical_columns = col_type_dict['categorical']
self.numerical_columns = col_type_dict['numerical']
self.binary_columns = col_type_dict['binary']
logger.info(f'load feature extractor from {coltype_path}')
def update(self, cat=None, num=None, bin=None):
'''update cat/num/bin column maps.
'''
if cat is not None:
self.categorical_columns.extend(cat)
self.categorical_columns = list(set(self.categorical_columns))
if num is not None:
self.numerical_columns.extend(num)
self.numerical_columns = list(set(self.numerical_columns))
if bin is not None:
self.binary_columns.extend(bin)
self.binary_columns = list(set(self.binary_columns))
col_no_overlap, duplicate_cols = self._check_column_overlap(self.categorical_columns, self.numerical_columns, self.binary_columns)
if not self.ignore_duplicate_cols:
for col in duplicate_cols:
logger.error(f'Find duplicate cols named `{col}`, please process the raw data or set `ignore_duplicate_cols` to True!')
assert col_no_overlap, 'The assigned categorical_columns, numerical_columns, binary_columns should not have overlap! Please check your input.'
else:
self._solve_duplicate_cols(duplicate_cols)
def _check_column_overlap(self, cat_cols=None, num_cols=None, bin_cols=None):
all_cols = []
if cat_cols is not None: all_cols.extend(cat_cols)
if num_cols is not None: all_cols.extend(num_cols)
if bin_cols is not None: all_cols.extend(bin_cols)
org_length = len(all_cols)
if org_length == 0:
logger.warning('No cat/num/bin cols specified, will take ALL columns as categorical! Ignore this warning if you specify the `checkpoint` to load the model.')
return True, []
unq_length = len(list(set(all_cols)))
duplicate_cols = [item for item, count in collections.Counter(all_cols).items() if count > 1]
return org_length == unq_length, duplicate_cols
def _solve_duplicate_cols(self, duplicate_cols):
for col in duplicate_cols:
logger.warning(f'Found duplicate column named `{col}`, will ignore it during training!')
if col in self.categorical_columns:
self.categorical_columns.remove(col)
self.categorical_columns.append(f'[cat]{col}')
if col in self.numerical_columns:
self.numerical_columns.remove(col)
self.numerical_columns.append(f'[num]{col}')
if col in self.binary_columns:
self.binary_columns.remove(col)
self.binary_columns.append(f'[bin]{col}')
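# Minimal usage sketch for TransTabFeatureExtractor (illustrative; the column names below are hypothetical):
#
#   extractor = TransTabFeatureExtractor(
#       categorical_columns=['workclass'],
#       numerical_columns=['age'],
#       binary_columns=['over_50k'],
#   )
#   batch = extractor(df)            # df: pd.DataFrame holding the columns above
#   batch['x_num']                   # (bs, # numerical columns) values, NaN filled with 0
#   batch['num_col_input_ids']       # token ids of the numerical column names
#   batch['x_cat_input_ids']         # token ids of "column_name value" strings per row
#   batch['x_bin_input_ids']         # token ids of binary column names whose value is 1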
class TransTabFeatureProcessor(nn.Module):
r'''
Process inputs from feature extractor to map them to embeddings.
'''
def __init__(self,
vocab_size=None,
hidden_dim=128,
hidden_dropout_prob=0,
pad_token_id=0,
device='cuda:0',
) -> None:
'''args:
vocab_size: the vocabulary size of the tokenizer used by the feature extractor
hidden_dim: the dimension of the output embeddings
hidden_dropout_prob: the dropout ratio applied to the word embeddings
pad_token_id: the id of the padding token, used as padding_idx of the word embedding
device: the device on which the embeddings are computed
'''
super().__init__()
self.word_embedding = TransTabWordEmbedding(
vocab_size=vocab_size,
hidden_dim=hidden_dim,
hidden_dropout_prob=hidden_dropout_prob,
padding_idx=pad_token_id
)
self.num_embedding = TransTabNumEmbedding(hidden_dim)
self.align_layer = nn.Linear(hidden_dim, hidden_dim, bias=False)
self.device = device
def _avg_embedding_by_mask(self, embs, att_mask=None):
if att_mask is None:
return embs.mean(1)
else:
att_mask = att_mask.to(embs.device) # keep the mask on the same device as the embeddings
embs[att_mask==0] = 0
embs = embs.sum(1) / att_mask.sum(1, keepdim=True)
return embs
def forward(self,
x_num=None,
num_col_input_ids=None,
num_att_mask=None,
x_cat_input_ids=None,
cat_att_mask=None,
x_bin_input_ids=None,
bin_att_mask=None,
**kwargs,
) -> Tensor:
'''args:
x_num: numerical feature values, (bs, # numerical columns)
num_col_input_ids, num_att_mask: tokenized numerical column names and their attention mask
x_cat_input_ids, cat_att_mask: tokenized categorical "column value" strings and their attention mask
x_bin_input_ids, bin_att_mask: tokenized binary column names and their attention mask
'''
num_feat_embedding = None
cat_feat_embedding = None
bin_feat_embedding = None
if x_num is not None and num_col_input_ids is not None:
num_col_emb = self.word_embedding(num_col_input_ids.to(self.device)) # (# numerical columns, # tokens, embedding dim)
x_num = x_num.to(self.device)
num_col_emb = self._avg_embedding_by_mask(num_col_emb, num_att_mask)
num_feat_embedding = self.num_embedding(num_col_emb, x_num)
num_feat_embedding = self.align_layer(num_feat_embedding)
if x_cat_input_ids is not None:
cat_feat_embedding = self.word_embedding(x_cat_input_ids.to(self.device))
cat_feat_embedding = self.align_layer(cat_feat_embedding)
if x_bin_input_ids is not None:
if x_bin_input_ids.shape[1] == 0: # all false, pad zero
x_bin_input_ids = torch.zeros(x_bin_input_ids.shape[0],dtype=int)[:,None]
bin_feat_embedding = self.word_embedding(x_bin_input_ids.to(self.device))
bin_feat_embedding = self.align_layer(bin_feat_embedding)
# concat all embeddings
emb_list = []
att_mask_list = []
if num_feat_embedding is not None:
emb_list += [num_feat_embedding]
att_mask_list += [torch.ones(num_feat_embedding.shape[0], num_feat_embedding.shape[1])]
if cat_feat_embedding is not None:
emb_list += [cat_feat_embedding]
att_mask_list += [cat_att_mask]
if bin_feat_embedding is not None:
emb_list += [bin_feat_embedding]
att_mask_list += [bin_att_mask]
if len(emb_list) == 0: raise Exception('no feature found belonging to numerical, categorical, or binary columns, please check your data!')
all_feat_embedding = torch.cat(emb_list, 1).float()
attention_mask = torch.cat(att_mask_list, 1).to(all_feat_embedding.device)
return {'embedding': all_feat_embedding, 'attention_mask': attention_mask}
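# Minimal sketch of how the processor consumes the extractor output (assumes `extractor` and `df`
# from the example above; hidden_dim and device values are illustrative):
#
#   processor = TransTabFeatureProcessor(
#       vocab_size=extractor.vocab_size,
#       pad_token_id=extractor.pad_token_id,
#       hidden_dim=128,
#       device='cpu',
#   )
#   out = processor(**extractor(df))
#   out['embedding']        # (bs, n_tokens, hidden_dim) concatenated num/cat/bin embeddings
#   out['attention_mask']   # (bs, n_tokens), 1 = attend, 0 = padding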
def _get_activation_fn(activation):
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
elif activation == 'selu':
return F.selu
elif activation == 'leakyrelu':
return F.leaky_relu
raise RuntimeError("activation should be relu/gelu/selu/leakyrelu, not {}".format(activation))
class TransTabTransformerLayer(nn.Module):
__constants__ = ['batch_first', 'norm_first']
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
layer_norm_eps=1e-5, batch_first=True, norm_first=False,
device=None, dtype=None, use_layer_norm=True) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=batch_first,
**factory_kwargs)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
# Implementation of gates
self.gate_linear = nn.Linear(d_model, 1, bias=False)
self.gate_act = nn.Sigmoid()
self.norm_first = norm_first
self.use_layer_norm = use_layer_norm
if self.use_layer_norm:
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
# Legacy string support for activation function.
if isinstance(activation, str):
self.activation = _get_activation_fn(activation)
else:
self.activation = activation
# self-attention block
def _sa_block(self, x: Tensor,
attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
src = x
key_padding_mask = ~key_padding_mask.bool()
x = self.self_attn(x, x, x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
g = self.gate_act(self.gate_linear(x))
h = self.linear1(x)
h = h * g # add gate
h = self.linear2(self.dropout(self.activation(h)))
return self.dropout2(h)
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super().__setstate__(state)
def forward(self, src, src_mask= None, src_key_padding_mask= None, is_causal=None, **kwargs) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
# see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
x = src
if self.use_layer_norm:
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
else: # do not use layer norm
x = x + self._sa_block(x, src_mask, src_key_padding_mask)
x = x + self._ff_block(x)
return x
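# The feed-forward block of this layer is gated: a per-token scalar sigmoid gate scales the first
# linear projection before the usual Transformer FFN. A rough functional equivalent of `_ff_block`
# (illustrative, ignoring dropout):
#
#   g = sigmoid(gate_linear(x))                 # (bs, seq, 1)
#   out = linear2(activation(linear1(x) * g))
#
# Note that `_sa_block` flips the padding mask: the model passes masks with 1 = attend, while
# nn.MultiheadAttention expects True = ignore, hence `~key_padding_mask.bool()`.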
class TransTabInputEncoder(nn.Module):
'''
Build an input encoder that maps input tabular samples to embeddings by chaining a
feature extractor (tokenization) with a feature processor (embedding lookup).
Parameters
----------
feature_extractor: TransTabFeatureExtractor
the feature extractor that tokenizes the input dataframe.
feature_processor: TransTabFeatureProcessor
the feature processor that maps the tokenized inputs to embeddings.
device: str
the device, ``"cpu"`` or ``"cuda:0"``.
'''
def __init__(self,
feature_extractor,
feature_processor,
device='cuda:0',
):
super().__init__()
self.feature_extractor = feature_extractor
self.feature_processor = feature_processor
self.device = device
self.to(device)
def forward(self, x):
'''
Encode input tabular samples into embeddings.
Parameters
----------
x: pd.DataFrame
with column names and features.
'''
tokenized = self.feature_extractor(x)
embeds = self.feature_processor(**tokenized)
return embeds
def load(self, ckpt_dir):
# load feature extractor
self.feature_extractor.load(os.path.join(ckpt_dir, constants.EXTRACTOR_STATE_DIR))
# load embedding layer
model_name = os.path.join(ckpt_dir, constants.INPUT_ENCODER_NAME)
state_dict = torch.load(model_name, map_location='cpu')
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
logger.info(f'missing keys: {missing_keys}')
logger.info(f'unexpected keys: {unexpected_keys}')
logger.info(f'load model from {ckpt_dir}')
class TransTabEncoder(nn.Module):
def __init__(self,
hidden_dim=128,
num_layer=2,
num_attention_head=2,
hidden_dropout_prob=0,
ffn_dim=256,
activation='relu',
):
super().__init__()
self.transformer_encoder = nn.ModuleList(
[
TransTabTransformerLayer(
d_model=hidden_dim,
nhead=num_attention_head,
dropout=hidden_dropout_prob,
dim_feedforward=ffn_dim,
batch_first=True,
layer_norm_eps=1e-5,
norm_first=False,
use_layer_norm=True,
activation=activation,)
]
)
if num_layer > 1:
encoder_layer = TransTabTransformerLayer(d_model=hidden_dim,
nhead=num_attention_head,
dropout=hidden_dropout_prob,
dim_feedforward=ffn_dim,
batch_first=True,
layer_norm_eps=1e-5,
norm_first=False,
use_layer_norm=True,
activation=activation,
)
stacked_transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layer-1)
self.transformer_encoder.append(stacked_transformer)
def forward(self, embedding, attention_mask=None, **kwargs) -> Tensor:
'''args:
embedding: bs, num_token, hidden_dim
'''
outputs = embedding
for i, mod in enumerate(self.transformer_encoder):
outputs = mod(outputs, src_key_padding_mask=attention_mask)
return outputs
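# Note on the stack above (descriptive, not a behavior change): the first layer is a standalone
# TransTabTransformerLayer; when num_layer > 1, the remaining num_layer - 1 layers are wrapped in an
# nn.TransformerEncoder built from the same custom layer class, so every layer applies the gated
# feed-forward block and the mask convention of `_sa_block` (1 = attend, flipped internally).
#
#   outputs: (bs, num_token, hidden_dim) -> shape unchanged after each layer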
class TransTabLinearClassifier(nn.Module):
def __init__(self,
num_class,
hidden_dim=128) -> None:
super().__init__()
if num_class <= 2:
self.fc = nn.Linear(hidden_dim, 1)
else:
self.fc = nn.Linear(hidden_dim, num_class)
self.norm = nn.LayerNorm(hidden_dim)
def forward(self, x) -> Tensor:
x = x[:,0,:] # take the cls token embedding
x = self.norm(x)
logits = self.fc(x)
return logits
class TransTabProjectionHead(nn.Module):
def __init__(self,
hidden_dim=128,
projection_dim=128):
super().__init__()
self.dense = nn.Linear(hidden_dim, projection_dim, bias=False)
def forward(self, x) -> Tensor:
h = self.dense(x)
return h
class TransTabCLSToken(nn.Module):
'''prepend a learnable cls token embedding at the start of each sequence.
'''
def __init__(self, hidden_dim) -> None:
super().__init__()
self.weight = nn.Parameter(Tensor(hidden_dim))
nn_init.uniform_(self.weight, a=-1/math.sqrt(hidden_dim),b=1/math.sqrt(hidden_dim))
self.hidden_dim = hidden_dim
def expand(self, *leading_dimensions):
new_dims = (1,) * (len(leading_dimensions)-1)
return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)
def forward(self, embedding, attention_mask=None, **kwargs) -> Tensor:
embedding = torch.cat([self.expand(len(embedding), 1), embedding], dim=1)
outputs = {'embedding': embedding}
if attention_mask is not None:
attention_mask = torch.cat([torch.ones(attention_mask.shape[0],1).to(attention_mask.device), attention_mask], 1)
outputs['attention_mask'] = attention_mask
return outputs
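# Shape sketch for TransTabCLSToken (illustrative): the learned token is prepended, so the sequence
# grows by one and the classifier later reads position 0.
#
#   embedding:      (bs, n_tokens, hidden_dim)
#   after forward:  (bs, n_tokens + 1, hidden_dim), with outputs['embedding'][:, 0, :] the CLS slot
#   attention_mask: a column of ones is prepended so the CLS position is always attended to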
class TransTabModel(nn.Module):
'''The base transtab model for downstream tasks like contrastive learning, binary classification, etc.
All models subclass this basemodel and usually rewrite the ``forward`` function. Refer to the source code of
:class:`transtab.modeling_transtab.TransTabClassifier` or :class:`transtab.modeling_transtab.TransTabForCL` for the implementation details.
Parameters
----------
categorical_columns: list
a list of categorical feature names.
numerical_columns: list
a list of numerical feature names.
binary_columns: list
a list of binary feature names, accept binary indicators like (yes,no); (true,false); (0,1).
feature_extractor: TransTabFeatureExtractor
a feature extractor to tokenize the input tables. if not passed the model will build itself.
hidden_dim: int
the dimension of hidden embeddings.
num_layer: int
the number of transformer layers used in the encoder.
num_attention_head: int
the number of heads of the multi-head self-attention layer in the transformer.
hidden_dropout_prob: float
the dropout ratio in the transformer encoder.
ffn_dim: int
the dimension of feed-forward layer in the transformer layer.
activation: str
the name of used activation functions, support ``"relu"``, ``"gelu"``, ``"selu"``, ``"leakyrelu"``.
device: str
the device, ``"cpu"`` or ``"cuda:0"``.
Returns
-------
A TransTabModel model.
'''
def __init__(self,
categorical_columns=None,
numerical_columns=None,
binary_columns=None,
feature_extractor=None,
hidden_dim=128,
num_layer=2,
num_attention_head=8,
hidden_dropout_prob=0.1,
ffn_dim=256,
activation='relu',
device='cuda:0',
**kwargs,
) -> None:
super().__init__()
self.categorical_columns=categorical_columns
self.numerical_columns=numerical_columns
self.binary_columns=binary_columns
if categorical_columns is not None:
self.categorical_columns = list(set(categorical_columns))
if numerical_columns is not None:
self.numerical_columns = list(set(numerical_columns))
if binary_columns is not None:
self.binary_columns = list(set(binary_columns))
if feature_extractor is None:
feature_extractor = TransTabFeatureExtractor(
categorical_columns=self.categorical_columns,
numerical_columns=self.numerical_columns,
binary_columns=self.binary_columns,
**kwargs,
)
feature_processor = TransTabFeatureProcessor(
vocab_size=feature_extractor.vocab_size,
pad_token_id=feature_extractor.pad_token_id,
hidden_dim=hidden_dim,
hidden_dropout_prob=hidden_dropout_prob,
device=device,
)
self.input_encoder = TransTabInputEncoder(
feature_extractor=feature_extractor,
feature_processor=feature_processor,
device=device,
)
self.encoder = TransTabEncoder(
hidden_dim=hidden_dim,
num_layer=num_layer,
num_attention_head=num_attention_head,
hidden_dropout_prob=hidden_dropout_prob,
ffn_dim=ffn_dim,
activation=activation,
)
self.cls_token = TransTabCLSToken(hidden_dim=hidden_dim)
self.device = device
self.to(device)
def forward(self, x, y=None):
'''Extract the embeddings based on input tables.
Parameters
----------
x: pd.DataFrame
a batch of samples stored in pd.DataFrame.
y: pd.Series
the corresponding labels for each sample in ``x``. ignored for the basemodel.
Returns
-------
final_cls_embedding: torch.Tensor
the [CLS] embedding at the end of transformer encoder.
'''
embeded = self.input_encoder(x)
embeded = self.cls_token(**embeded)
# go through transformers, get final cls embedding
encoder_output = self.encoder(**embeded)
# get cls token
final_cls_embedding = encoder_output[:,0,:]
return final_cls_embedding
def load(self, ckpt_dir):
'''Load the model state_dict and feature_extractor configuration
from the ``ckpt_dir``.
Parameters
----------
ckpt_dir: str
the directory path to load.
Returns
-------
None
'''
# load model weight state dict
model_name = os.path.join(ckpt_dir, constants.WEIGHTS_NAME)
state_dict = torch.load(model_name, map_location='cpu')
missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False)
logger.info(f'missing keys: {missing_keys}')
logger.info(f'unexpected keys: {unexpected_keys}')
logger.info(f'load model from {ckpt_dir}')
# load feature extractor
self.input_encoder.feature_extractor.load(os.path.join(ckpt_dir, constants.EXTRACTOR_STATE_DIR))
self.binary_columns = self.input_encoder.feature_extractor.binary_columns
self.categorical_columns = self.input_encoder.feature_extractor.categorical_columns
self.numerical_columns = self.input_encoder.feature_extractor.numerical_columns
def save(self, ckpt_dir):
'''Save the model state_dict and feature_extractor configuration
to the ``ckpt_dir``.
Parameters
----------
ckpt_dir: str
the directory path to save.
Returns
-------
None
'''
# save model weight state dict
if not os.path.exists(ckpt_dir): os.makedirs(ckpt_dir, exist_ok=True)
state_dict = self.state_dict()
torch.save(state_dict, os.path.join(ckpt_dir, constants.WEIGHTS_NAME))
if self.input_encoder.feature_extractor is not None:
self.input_encoder.feature_extractor.save(ckpt_dir)
# save the input encoder separately
state_dict_input_encoder = self.input_encoder.state_dict()
torch.save(state_dict_input_encoder, os.path.join(ckpt_dir, constants.INPUT_ENCODER_NAME))
return None
def update(self, config):
'''Update the configuration of feature extractor's column map for cat, num, and bin cols.
Or update the number of classes for the output classifier layer.
Parameters
----------
config: dict
a dict of configurations: keys cat:list, num:list, bin:list are to specify the new column names;
key num_class:int is to specify the number of classes for finetuning on a new dataset.
Returns
-------
None
'''
col_map = {}
for k,v in config.items():
if k in ['cat','num','bin']: col_map[k] = v
self.input_encoder.feature_extractor.update(**col_map)
self.binary_columns = self.input_encoder.feature_extractor.binary_columns
self.categorical_columns = self.input_encoder.feature_extractor.categorical_columns
self.numerical_columns = self.input_encoder.feature_extractor.numerical_columns
if 'num_class' in config:
num_class = config['num_class']
self._adapt_to_new_num_class(num_class)
return None
def _check_column_overlap(self, cat_cols=None, num_cols=None, bin_cols=None):
all_cols = []
if cat_cols is not None: all_cols.extend(cat_cols)
if num_cols is not None: all_cols.extend(num_cols)
if bin_cols is not None: all_cols.extend(bin_cols)
org_length = len(all_cols)
unq_length = len(list(set(all_cols)))
duplicate_cols = [item for item, count in collections.Counter(all_cols).items() if count > 1]
return org_length == unq_length, duplicate_cols
def _solve_duplicate_cols(self, duplicate_cols):
for col in duplicate_cols:
logger.warning(f'Found duplicate column named `{col}`, will ignore it during training!')
if col in self.categorical_columns:
self.categorical_columns.remove(col)
self.categorical_columns.append(f'[cat]{col}')
if col in self.numerical_columns:
self.numerical_columns.remove(col)
self.numerical_columns.append(f'[num]{col}')
if col in self.binary_columns:
self.binary_columns.remove(col)
self.binary_columns.append(f'[bin]{col}')
def _adapt_to_new_num_class(self, num_class):
if num_class != self.num_class:
self.num_class = num_class
self.clf = TransTabLinearClassifier(num_class, hidden_dim=self.cls_token.hidden_dim)
self.clf.to(self.device)
if self.num_class > 2:
self.loss_fn = nn.CrossEntropyLoss(reduction='none')
else:
self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
logger.info(f'Built a new classifier with {num_class} output classes; further finetuning is required before use.')
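# Minimal usage sketch for the base model (illustrative; column names are hypothetical, device shown as CPU):
#
#   model = TransTabModel(
#       categorical_columns=['workclass'],
#       numerical_columns=['age'],
#       binary_columns=['over_50k'],
#       device='cpu',
#   )
#   cls_emb = model(df)        # (bs, hidden_dim) [CLS] embedding per row of df
#   model.save('./ckpt')       # writes model weights, input encoder weights, and extractor config
#   model.load('./ckpt')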
class TransTabClassifier(TransTabModel):
'''The classifier model subclass from :class:`transtab.modeling_transtab.TransTabModel`.
Parameters
----------
categorical_columns: list
a list of categorical feature names.
numerical_columns: list
a list of numerical feature names.
binary_columns: list
a list of binary feature names, accept binary indicators like (yes,no); (true,false); (0,1).
feature_extractor: TransTabFeatureExtractor
a feature extractor to tokenize the input tables. if not passed the model will build itself.
num_class: int
number of output classes to be predicted.
hidden_dim: int
the dimension of hidden embeddings.
num_layer: int
the number of transformer layers used in the encoder.
num_attention_head: int
the number of heads of the multi-head self-attention layer in the transformer.
hidden_dropout_prob: float
the dropout ratio in the transformer encoder.
ffn_dim: int
the dimension of feed-forward layer in the transformer layer.
activation: str
the name of used activation functions, support ``"relu"``, ``"gelu"``, ``"selu"``, ``"leakyrelu"``.
device: str
the device, ``"cpu"`` or ``"cuda:0"``.
Returns
-------
A TransTabClassifier model.
'''
def __init__(self,
categorical_columns=None,
numerical_columns=None,
binary_columns=None,
feature_extractor=None,
num_class=2,
hidden_dim=128,
num_layer=2,
num_attention_head=8,
hidden_dropout_prob=0,
ffn_dim=256,
activation='relu',
device='cuda:0',
**kwargs,
) -> None:
super().__init__(
categorical_columns=categorical_columns,
numerical_columns=numerical_columns,
binary_columns=binary_columns,
feature_extractor=feature_extractor,
hidden_dim=hidden_dim,
num_layer=num_layer,
num_attention_head=num_attention_head,
hidden_dropout_prob=hidden_dropout_prob,
ffn_dim=ffn_dim,
activation=activation,
device=device,
**kwargs,
)
self.num_class = num_class
self.clf = TransTabLinearClassifier(num_class=num_class, hidden_dim=hidden_dim)
if self.num_class > 2:
self.loss_fn = nn.CrossEntropyLoss(reduction='none')
else:
self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')
self.to(device)
def forward(self, x, y=None):
'''Make forward pass given the input feature ``x`` and label ``y`` (optional).
Parameters
----------
x: pd.DataFrame or dict
pd.DataFrame: a batch of raw tabular samples; dict: the output of TransTabFeatureExtractor.
y: pd.Series
the corresponding labels for each sample in ``x``. if label is given, the model will return
the classification loss by ``self.loss_fn``.
Returns
-------
logits: torch.Tensor
the classification logits computed from the [CLS] embedding of the transformer encoder.
loss: torch.Tensor or None
the classification loss.
'''
if isinstance(x, dict):
# input is the pre-tokenized encoded inputs
inputs = x
elif isinstance(x, pd.DataFrame):
# input is dataframe
inputs = self.input_encoder.feature_extractor(x)
else:
raise ValueError(f'TransTabClassifier takes inputs with dict or pd.DataFrame, find {type(x)}.')
outputs = self.input_encoder.feature_processor(**inputs)
outputs = self.cls_token(**outputs)
# go through transformers, get the first cls embedding
encoder_output = self.encoder(**outputs) # bs, seqlen+1, hidden_dim
# classifier
logits = self.clf(encoder_output)
if y is not None:
# compute classification loss
if self.num_class == 2:
y_ts = torch.tensor(y.values).to(self.device).float()
loss = self.loss_fn(logits.flatten(), y_ts)
else:
y_ts = torch.tensor(y.values).to(self.device).long()
loss = self.loss_fn(logits, y_ts)
loss = loss.mean()
else:
loss = None
return logits, loss
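# Minimal usage sketch for the classifier (illustrative; df and y are hypothetical, y a pd.Series of labels):
#
#   clf = TransTabClassifier(
#       categorical_columns=['workclass'],
#       numerical_columns=['age'],
#       binary_columns=['over_50k'],
#       num_class=2,
#       device='cpu',
#   )
#   logits, loss = clf(df, y)      # binary case: logits has shape (bs, 1), loss is BCE-with-logits
#   logits, _ = clf(df)            # inference: loss is None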
class TransTabForCL(TransTabModel):
'''The contrastive learning model subclass from :class:`transtab.modeling_transtab.TransTabModel`.
Parameters
----------
categorical_columns: list
a list of categorical feature names.
numerical_columns: list
a list of numerical feature names.
binary_columns: list
a list of binary feature names, accept binary indicators like (yes,no); (true,false); (0,1).
feature_extractor: TransTabFeatureExtractor
a feature extractor to tokenize the input tables. if not passed the model will build itself.
hidden_dim: int
the dimension of hidden embeddings.
num_layer: int
the number of transformer layers used in the encoder.
num_attention_head: int
the number of heads of the multi-head self-attention layer in the transformer.
hidden_dropout_prob: float
the dropout ratio in the transformer encoder.
ffn_dim: int
the dimension of feed-forward layer in the transformer layer.
projection_dim: int
the dimension of projection head on the top of encoder.
overlap_ratio: float
the overlap ratio of columns of different partitions when doing subsetting.
num_partition: int
the number of partitions made for vertical-partition contrastive learning.
supervised: bool
whether or not to take supervised VPCL, otherwise take self-supervised VPCL.
temperature: float
temperature used to compute logits for contrastive learning.
base_temperature: float
base temperature used to normalize the temperature.
activation: str
the name of used activation functions, support ``"relu"``, ``"gelu"``, ``"selu"``, ``"leakyrelu"``.
device: str
the device, ``"cpu"`` or ``"cuda:0"``.
Returns
-------
A TransTabForCL model.
'''
def __init__(self,
categorical_columns=None,
numerical_columns=None,
binary_columns=None,
feature_extractor=None,
hidden_dim=128,
num_layer=2,
num_attention_head=8,
hidden_dropout_prob=0,
ffn_dim=256,
projection_dim=128,
overlap_ratio=0.1,
num_partition=2,
supervised=True,
temperature=10,
base_temperature=10,
activation='relu',
device='cuda:0',
**kwargs,
) -> None:
super().__init__(
categorical_columns=categorical_columns,
numerical_columns=numerical_columns,
binary_columns=binary_columns,
feature_extractor=feature_extractor,
hidden_dim=hidden_dim,
num_layer=num_layer,
num_attention_head=num_attention_head,
hidden_dropout_prob=hidden_dropout_prob,
ffn_dim=ffn_dim,
activation=activation,
device=device,
**kwargs,
)
assert num_partition > 0, f'number of contrastive subsets must be greater than 0, got {num_partition}'
assert isinstance(num_partition,int), f'number of contrastive subsets must be int, got {type(num_partition)}'
assert overlap_ratio >= 0 and overlap_ratio < 1, f'overlap_ratio must be in [0, 1), got {overlap_ratio}'
self.projection_head = TransTabProjectionHead(hidden_dim, projection_dim)
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.temperature = temperature
self.base_temperature = base_temperature
self.num_partition = num_partition
self.overlap_ratio = overlap_ratio
self.supervised = supervised
self.device = device
self.to(device)
def forward(self, x, y=None):
'''Make forward pass given the input feature ``x`` and label ``y`` (optional).
Parameters
----------
x: pd.DataFrame or dict
pd.DataFrame: a batch of raw tabular samples; dict: the output of TransTabFeatureExtractor.
y: pd.Series
the corresponding labels for each sample in ``x``. if labels are given and ``supervised=True``,
the supervised VPCL loss is computed, otherwise the self-supervised VPCL loss is used.
Returns
-------
logits: None
this CL model does NOT return logits.
loss: torch.Tensor
the supervised or self-supervised VPCL loss.
'''
# do positive sampling
feat_x_list = []
if isinstance(x, pd.DataFrame):
sub_x_list = self._build_positive_pairs(x, self.num_partition)
for sub_x in sub_x_list:
# encode two subset feature samples
feat_x = self.input_encoder(sub_x)
feat_x = self.cls_token(**feat_x)
feat_x = self.encoder(**feat_x)
feat_x_proj = feat_x[:,0,:] # take cls embedding
feat_x_proj = self.projection_head(feat_x_proj) # bs, projection_dim
feat_x_list.append(feat_x_proj)
elif isinstance(x, dict):
# pretokenized inputs
for input_x in x['input_sub_x']:
feat_x = self.input_encoder.feature_processor(**input_x)
feat_x = self.cls_token(**feat_x)
feat_x = self.encoder(**feat_x)
feat_x_proj = feat_x[:, 0, :]
feat_x_proj = self.projection_head(feat_x_proj)
feat_x_list.append(feat_x_proj)
else:
raise ValueError(f'expect input x to be pd.DataFrame or dict(pretokenized), get {type(x)} instead')
feat_x_multiview = torch.stack(feat_x_list, dim=1) # bs, n_view, emb_dim
if y is not None and self.supervised:
# take supervised loss
y = torch.tensor(y.values, device=feat_x_multiview.device)
loss = self.supervised_contrastive_loss(feat_x_multiview, y)
else:
# compute cl loss (multi-view InfoNCE loss)
loss = self.self_supervised_contrastive_loss(feat_x_multiview)
return None, loss
def _build_positive_pairs(self, x, n):
x_cols = x.columns.tolist()
sub_col_list = np.array_split(np.array(x_cols), n)
len_cols = len(sub_col_list[0])
overlap = int(np.ceil(len_cols * (self.overlap_ratio)))
sub_x_list = []
for i, sub_col in enumerate(sub_col_list):
if overlap > 0 and i < n-1:
sub_col = np.concatenate([sub_col, sub_col_list[i+1][:overlap]])
elif overlap >0 and i == n-1:
sub_col = np.concatenate([sub_col, sub_col_list[i-1][-overlap:]])
sub_x = x.copy()[sub_col]
sub_x_list.append(sub_x)
return sub_x_list
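# Worked example of the partitioning above (illustrative): with columns [c1, c2, c3, c4],
# num_partition=2 and overlap_ratio=0.5, np.array_split yields [c1, c2] and [c3, c4];
# overlap = ceil(2 * 0.5) = 1, so the two positive views are x[[c1, c2, c3]] and x[[c3, c4, c2]].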
def cos_sim(self, a, b):
if not isinstance(a, torch.Tensor):
a = torch.tensor(a)
if not isinstance(b, torch.Tensor):
b = torch.tensor(b)
if len(a.shape) == 1:
a = a.unsqueeze(0)
if len(b.shape) == 1:
b = b.unsqueeze(0)
a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
return torch.mm(a_norm, b_norm.transpose(0, 1))
def self_supervised_contrastive_loss(self, features):
'''Compute the self-supervised VPCL loss.
Parameters
----------
features: torch.Tensor
the encoded features of multiple partitions of input tables, with shape ``(bs, n_partition, proj_dim)``.
Returns
-------
loss: torch.Tensor
the computed self-supervised VPCL loss.
'''
batch_size = features.shape[0]
labels = torch.arange(batch_size, dtype=torch.long, device=self.device).view(-1,1)
mask = torch.eq(labels, labels.T).float().to(labels.device)
contrast_count = features.shape[1]
# [[0,1],[2,3]] -> [0,2,1,3]
contrast_feature = torch.cat(torch.unbind(features,dim=1),dim=0)
anchor_feature = contrast_feature
anchor_count = contrast_count
anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
mask = mask.repeat(anchor_count, contrast_count)
logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(features.device), 0)
mask = mask * logits_mask
# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()
return loss
def supervised_contrastive_loss(self, features, labels):
'''Compute the supervised VPCL loss.
Parameters
----------
features: torch.Tensor
the encoded features of multiple partitions of input tables, with shape ``(bs, n_partition, proj_dim)``.
labels: torch.Tensor
the class labels to be used for building positive/negative pairs in VPCL.
Returns
-------
loss: torch.Tensor
the computed VPCL loss.
'''
labels = labels.contiguous().view(-1,1)
batch_size = features.shape[0]
mask = torch.eq(labels, labels.T).float().to(labels.device)
contrast_count = features.shape[1]
contrast_feature = torch.cat(torch.unbind(features,dim=1),dim=0)
# contrast_mode == 'all'
anchor_feature = contrast_feature
anchor_count = contrast_count
# compute logits
anchor_dot_contrast = torch.div(
torch.matmul(anchor_feature, contrast_feature.T),
self.temperature)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
# tile mask
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
torch.ones_like(mask),
1,
torch.arange(batch_size * anchor_count).view(-1, 1).to(features.device),
0,
)
mask = mask * logits_mask
# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.view(anchor_count, batch_size).mean()
return loss
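# Minimal pretraining sketch for TransTabForCL (illustrative; df is a hypothetical dataframe,
# shown here with self-supervised VPCL):
#
#   model = TransTabForCL(
#       categorical_columns=['workclass'],
#       numerical_columns=['age'],
#       num_partition=2,
#       supervised=False,
#       device='cpu',
#   )
#   _, loss = model(df)      # builds 2 overlapping column views of df and returns the VPCL loss
#   loss.backward()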