train
- transtab.train(model, trainset, valset=None, num_epoch=10, batch_size=64, eval_batch_size=256, lr=0.0001, weight_decay=0, patience=5, warmup_ratio=None, warmup_steps=None, eval_metric='auc', output_dir='./ckpt', collate_fn=None, num_workers=0, balance_sample=False, load_best_at_last=True, ignore_duplicate_cols=False, eval_less_is_better=False, **kwargs)
The shared train function for all TransTabModel based models.
- Parameters
model (TransTabModel or a subclass) – the model to train. Its forward pass must return both the logits and the loss, e.g., logit, loss = model(x, y).
trainset (list or tuple) – a list of trainsets, or a single trainset given as a tuple (x, y), where x is a pd.DataFrame or dict and y is a pd.Series.
valset (list or tuple) – a list of valsets, or a single valset given as a tuple (x, y).
num_epoch (int) – number of training epochs.
batch_size (int) – training batch size.
eval_batch_size (int) – evaluation batch size.
lr (float) – training learning rate.
weight_decay (float) – training weight decay.
patience (int) – early stopping patience, i.e., how many evaluations without improvement are tolerated before training stops; only used when valset is given.
warmup_ratio (float) – the fraction of training steps used for learning rate warmup; ignored if warmup_steps is set.
warmup_steps (int) – the number of training steps for learning rate warmup.
eval_metric (str) – the evaluation metric monitored during training for early stopping; one of "acc", "auc", "mse", or "val_loss".
output_dir (str) – the directory where the trained model weights and feature extractor configurations are saved.
collate_fn (function) – a custom collate function for training, needed when the task is not standard supervised learning, e.g., contrastive learning.
num_workers (int) – the number of workers for the dataloader.
balance_sample (bool) – whether to bootstrap samples so that the classes within each batch are balanced; only supports binary classification.
load_best_at_last (bool) – whether to load the best checkpoint after training completes.
ignore_duplicate_cols (bool) – whether to ignore columns that are assigned to more than one of the cat/num/bin column types.
eval_less_is_better (bool) – whether a lower value of eval_metric is better. Set it to True when eval_metric is "val_loss".
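The precedence between warmup_steps and warmup_ratio can be illustrated with a small sketch. It assumes one optimizer step per training batch and that a final partial batch still counts as a step; resolve_warmup_steps is a hypothetical helper, not part of transtab, and the library's own step count may differ.

```python
import math
from typing import Optional


def resolve_warmup_steps(
    num_samples: int,
    batch_size: int,
    num_epoch: int,
    warmup_ratio: Optional[float] = None,
    warmup_steps: Optional[int] = None,
) -> int:
    """Hypothetical helper: number of LR warmup steps under the documented precedence.

    Assumes one optimizer step per batch and that the last partial batch is kept
    (hence ceil); neither detail is confirmed for transtab.train.
    """
    # warmup_steps takes precedence: warmup_ratio is ignored when it is set.
    if warmup_steps is not None:
        return warmup_steps
    if warmup_ratio is None:
        return 0  # no warmup requested
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * num_epoch
    # warmup_ratio is a fraction of all training steps.
    return int(total_steps * warmup_ratio)
```

For example, 1000 training rows with batch_size=64 and num_epoch=10 give 16 steps per epoch and 160 steps in total, so warmup_ratio=0.1 yields 16 warmup steps; passing warmup_steps=50 as well would override the ratio.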
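To make the interplay of patience, eval_metric, and eval_less_is_better concrete, here is a minimal stdlib-only sketch of a generic early-stopping rule. EarlyStoppingSketch and run are hypothetical names; this is not transtab's internal implementation, and checking the score once per epoch is an assumption made for the illustration.

```python
from typing import List, Optional, Tuple


class EarlyStoppingSketch:
    """Hypothetical early-stopping rule illustrating patience and eval_less_is_better."""

    def __init__(self, patience: int = 5, less_is_better: bool = False) -> None:
        self.patience = patience
        self.less_is_better = less_is_better
        self.best_score: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.counter = 0  # evaluations since the last improvement

    def _improved(self, score: float) -> bool:
        # The first score always counts as an improvement.
        if self.best_score is None:
            return True
        # The direction of "better" is set by less_is_better (True for val_loss).
        if self.less_is_better:
            return score < self.best_score
        return score > self.best_score

    def step(self, epoch: int, score: float) -> bool:
        """Record one validation score; return True when training should stop."""
        if self._improved(score):
            self.best_score = score
            self.best_epoch = epoch  # weights a load_best_at_last-style option would restore
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


def run(scores: List[float], patience: int, less_is_better: bool) -> Tuple[int, Optional[int]]:
    """Feed per-epoch scores; return (epoch at which training ends, best epoch)."""
    stopper = EarlyStoppingSketch(patience, less_is_better)
    for epoch, score in enumerate(scores):
        if stopper.step(epoch, score):
            return epoch, stopper.best_epoch
    return len(scores) - 1, stopper.best_epoch
```

With the val_loss sequence 0.9, 0.7, 0.72, 0.71, 0.8 (epochs numbered from 0) and patience=2, setting less_is_better=True stops after epoch 3 and keeps epoch 1 as the best. Leaving it False would stop after epoch 2 and keep epoch 0, the epoch with the highest loss, which is why the flag should be True for val_loss.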
- Returns
- Return type
None
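One way to realize the class-balanced bootstrapping that balance_sample describes is to draw each batch half from each class, sampling with replacement so the minority class is oversampled. The sketch below shows that general technique under this assumption, not transtab's code; balanced_bootstrap_indices is a hypothetical helper, and, like balance_sample, it only accepts binary labels.

```python
import random
from typing import List, Optional, Sequence


def balanced_bootstrap_indices(
    labels: Sequence[int], batch_size: int, seed: Optional[int] = None
) -> List[int]:
    """Hypothetical helper: draw one batch of row indices with equal class counts.

    Sampling is with replacement (bootstrapping). Only binary 0/1 labels are
    supported; for an odd batch_size the negative class gets the extra slot.
    """
    if set(labels) - {0, 1}:
        raise ValueError("balanced sampling here supports binary labels only")
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    if not pos or not neg:
        raise ValueError("both classes must be present")
    n_pos = batch_size // 2
    n_neg = batch_size - n_pos
    # Draw with replacement from each class, then shuffle the combined batch.
    batch = rng.choices(pos, k=n_pos) + rng.choices(neg, k=n_neg)
    rng.shuffle(batch)
    return batch
```

With 95 negatives and 5 positives, a batch of 64 drawn this way contains 32 positive rows, each of the 5 positive rows appearing several times on average.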