train

transtab.train(model, trainset, valset=None, num_epoch=10, batch_size=64, eval_batch_size=256, lr=0.0001, weight_decay=0, patience=5, warmup_ratio=None, warmup_steps=None, eval_metric='auc', output_dir='./ckpt', collate_fn=None, num_workers=0, balance_sample=False, load_best_at_last=True, ignore_duplicate_cols=False, eval_less_is_better=False, **kwargs)

The shared training function for all TransTabModel-based models.

Parameters
  • model (TransTabModel or a subclass) – the base model or one of its subclasses. Its forward pass should return both the logits and the loss, e.g., logit, loss = model(x, y).

  • trainset (list or tuple) – either a list of training sets or a single training set given as a tuple (x, y), where x is a pd.DataFrame or dict and y is a pd.Series.

  • valset (list or tuple) – either a list of validation sets or a single validation set given as a tuple (x, y). Optional; early stopping is disabled when it is None.

  • num_epoch (int) – number of training epochs.

  • batch_size (int) – training batch size.

  • eval_batch_size (int) – evaluation batch size.

  • lr (float) – learning rate used for training.

  • weight_decay (float) – weight decay used for training.

  • patience (int) – early-stopping patience: the number of consecutive evaluations without improvement in eval_metric before training stops. Only used when valset is given.

  • warmup_ratio (float) – the fraction of total training steps used for learning-rate warmup. Ignored if warmup_steps is set.

  • warmup_steps (int) – the number of training steps for learning rate warmup.

  • eval_metric (str) – the metric evaluated during training and used for early stopping; one of "acc", "auc", "mse", or "val_loss".

  • output_dir (str) – the directory where the trained model weights and the feature extractor configuration are saved.

  • collate_fn (function) – a custom collate function for training, needed when the task is not standard supervised learning (e.g., contrastive learning).

  • num_workers (int) – the number of workers for the dataloader.

  • balance_sample (bool) – whether to bootstrap the training data so that the samples in each batch are balanced across classes. Only supports binary classification.

  • load_best_at_last (bool) – whether to load the best checkpoint after training completes.

  • ignore_duplicate_cols (bool) – whether to ignore conflicts in which the same column is listed under more than one of the categorical, numerical, and binary (cat/num/bin) column types.

  • eval_less_is_better (bool) – whether lower values of eval_metric are better. Set it to True for "val_loss" (and likewise for "mse").
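Both trainset and valset accept either a single (x, y) pair or a list of such pairs. The sketch below shows one way to normalize that input into a list. The helper name as_dataset_list is hypothetical and not part of the transtab API, and plain dicts and lists stand in for pd.DataFrame and pd.Series so the example runs without extra dependencies.

```python
from typing import Any, List, Optional, Tuple, Union

Pair = Tuple[Any, Any]  # (x, y), e.g. (pd.DataFrame, pd.Series)


def as_dataset_list(data: Optional[Union[Pair, List[Pair]]]) -> List[Pair]:
    """Hypothetical helper: wrap a single (x, y) pair in a list.

    Assumption: a tuple means "one dataset" and a list means "several datasets".
    """
    if data is None:
        return []  # e.g. valset=None: nothing to evaluate on
    if isinstance(data, tuple):
        return [data]
    if isinstance(data, list):
        for item in data:
            if not (isinstance(item, tuple) and len(item) == 2):
                raise ValueError("each dataset must be an (x, y) tuple")
        return data
    raise TypeError("expected an (x, y) tuple or a list of (x, y) tuples")


# One dataset passed directly, and two datasets passed as a list.
x1, y1 = {"age": [25, 40]}, [0, 1]
x2, y2 = {"income": [3.2, 5.1]}, [1, 0]
print(len(as_dataset_list((x1, y1))))              # 1
print(len(as_dataset_list([(x1, y1), (x2, y2)])))  # 2
```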
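warmup_ratio and warmup_steps both determine how many optimizer steps the learning rate spends warming up. The following is a minimal sketch of how the two might be reconciled, assuming one optimizer step per training batch and that an explicit warmup_steps takes precedence, as the parameter descriptions state. The function name resolve_warmup_steps and the step-counting formula are illustrative assumptions, not transtab's own code.

```python
import math
from typing import Optional


def resolve_warmup_steps(
    num_samples: int,
    batch_size: int,
    num_epoch: int,
    warmup_ratio: Optional[float] = None,
    warmup_steps: Optional[int] = None,
) -> int:
    """Hypothetical helper: number of warmup steps (0 means no warmup)."""
    # Assumption: one optimizer step per batch, with the last partial batch kept.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    total_steps = steps_per_epoch * num_epoch
    if warmup_steps is not None:
        return warmup_steps  # an explicit step count overrides warmup_ratio
    if warmup_ratio is not None:
        return int(warmup_ratio * total_steps)
    return 0


# 1000 samples with batch_size=64 -> 16 steps per epoch, 160 steps over 10 epochs.
print(resolve_warmup_steps(1000, 64, 10, warmup_ratio=0.1))                   # 16
print(resolve_warmup_steps(1000, 64, 10, warmup_ratio=0.1, warmup_steps=50))  # 50
```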
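patience, eval_metric, and eval_less_is_better work together during early stopping: training stops once the validation score has failed to improve for patience consecutive evaluations, and eval_less_is_better decides which direction counts as an improvement. The class below is a minimal sketch of that logic. EarlyStopper and its methods are hypothetical names, not transtab's implementation, and the validation losses are made-up numbers for illustration.

```python
from typing import Optional


class EarlyStopper:
    """Hypothetical early-stopping tracker driven by one validation score."""

    def __init__(self, patience: int = 5, less_is_better: bool = False) -> None:
        self.patience = patience
        self.less_is_better = less_is_better  # True for val_loss or mse
        self.best_score: Optional[float] = None
        self.best_epoch: Optional[int] = None
        self.bad_evals = 0  # consecutive evaluations without improvement

    def _improved(self, score: float) -> bool:
        if self.best_score is None:
            return True
        if self.less_is_better:
            return score < self.best_score
        return score > self.best_score

    def step(self, epoch: int, score: float) -> bool:
        """Record one evaluation; return True when training should stop."""
        if self._improved(score):
            self.best_score = score
            self.best_epoch = epoch  # a real trainer would save a checkpoint here
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience


# Illustrative validation losses; lower is better, so less_is_better=True.
stopper = EarlyStopper(patience=2, less_is_better=True)
for epoch, val_loss in enumerate([0.70, 0.62, 0.65, 0.64, 0.60]):
    if stopper.step(epoch, val_loss):
        print(f"stopped after epoch {epoch}")  # stopped after epoch 3
        break
print(stopper.best_epoch, stopper.best_score)  # 1 0.62
```

With load_best_at_last=True, the checkpoint recorded at best_epoch is the one that would be restored when training ends. If eval_less_is_better were left at False here, a rising loss would be treated as an improvement, which is why the parameter list asks for True with "val_loss".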
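balance_sample bootstraps the training data so that batches are class-balanced. One common way to do this for binary labels is to keep every majority-class row and resample the minority class with replacement until both classes have the same count, then shuffle. This oversampling scheme and the helper balanced_bootstrap_indices are assumptions made for illustration; transtab's own resampling may differ in detail.

```python
import random
from typing import List, Sequence


def balanced_bootstrap_indices(labels: Sequence[int], seed: int = 0) -> List[int]:
    """Hypothetical helper: row indices in which both binary classes are equally frequent."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    if len(pos) + len(neg) != len(labels):
        raise ValueError("balanced sampling here only supports binary 0/1 labels")
    if not pos or not neg:
        raise ValueError("both classes must be present to balance")
    minority, majority = (pos, neg) if len(pos) < len(neg) else (neg, pos)
    # Draw minority rows with replacement until they match the majority count.
    resampled = [rng.choice(minority) for _ in range(len(majority))]
    indices = majority + resampled
    rng.shuffle(indices)
    return indices


labels = [0, 0, 0, 0, 0, 0, 1, 1]
idx = balanced_bootstrap_indices(labels)
print(len(idx), sum(labels[i] for i in idx))  # 12 6
```

Batches drawn in order from the shuffled indices then contain, on average, as many positive as negative rows.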

Returns

None