predict

transtab.predict(clf, x_test, y_test=None, return_loss=False, eval_batch_size=256)

Make predictions with a TransTabClassifier.

Parameters
  • clf (TransTabClassifier) – the classifier model to make predictions.

  • x_test (pd.DataFrame) – input tabular data.

  • y_test (pd.Series) – target labels for x_test. Defaults to None; ignored if return_loss=False.

  • return_loss (bool) – if True, return the loss computed against y_test instead of the predictions; y_test must be given.

  • eval_batch_size (int) – the batch size for inference.
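The eval_batch_size parameter controls how many rows of x_test are run through the model at a time. The sketch below illustrates that kind of batching under stated assumptions: iter_batches is a hypothetical helper (not part of transtab), and a plain Python list stands in for the pd.DataFrame so the example needs no third-party packages.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(rows: Sequence[T], eval_batch_size: int = 256) -> Iterator[List[T]]:
    """Hypothetical helper: yield consecutive slices of at most eval_batch_size rows.

    A plain list stands in for x_test here; with a pd.DataFrame the same
    slicing would be done positionally (for example with .iloc).
    """
    if eval_batch_size <= 0:
        raise ValueError("eval_batch_size must be positive")
    # Step through the rows in strides of eval_batch_size; the final
    # batch holds whatever rows remain and may be smaller.
    for start in range(0, len(rows), eval_batch_size):
        yield list(rows[start:start + eval_batch_size])


if __name__ == "__main__":
    # 600 rows with the default batch size give batches of 256, 256 and 88 rows.
    print([len(batch) for batch in iter_batches(list(range(600)))])
```

A larger eval_batch_size means fewer forward passes but more memory per pass; the batch size affects only throughput and memory, not which rows are predicted.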

Returns

  • pred_all (np.ndarray) – if return_loss=False, the predictions made by the TransTabClassifier.

  • avg_loss (float) – if return_loss=True, the mean loss of the TransTabClassifier's predictions against y_test.
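Only one of pred_all or avg_loss is returned, depending on return_loss. The following is a minimal sketch of that control flow, not transtab's implementation: predict_sketch and toy_model are hypothetical, lists stand in for pd.DataFrame, pd.Series and np.ndarray, and the assumption that the mean loss is taken over per-batch losses is made for illustration only.

```python
from typing import List, Optional, Sequence, Tuple


def toy_model(x_batch: Sequence[float],
              y_batch: Optional[Sequence[float]]) -> Tuple[List[float], Optional[float]]:
    """Hypothetical stand-in for a classifier: returns (predictions, loss or None).

    It "predicts" each input unchanged and, when labels are given,
    reports the mean squared error of the batch.
    """
    preds = [float(x) for x in x_batch]
    if y_batch is None:
        return preds, None
    loss = sum((p - y) ** 2 for p, y in zip(preds, y_batch)) / len(preds)
    return preds, loss


def predict_sketch(model, x_test, y_test=None, return_loss=False, eval_batch_size=256):
    """Hypothetical sketch of the return_loss switch described above."""
    preds: List[float] = []
    losses: List[float] = []
    for start in range(0, len(x_test), eval_batch_size):
        x_batch = x_test[start:start + eval_batch_size]
        # Labels are sliced alongside the inputs only when they are given.
        y_batch = None if y_test is None else y_test[start:start + eval_batch_size]
        batch_preds, batch_loss = model(x_batch, y_batch)
        preds.extend(batch_preds)
        if batch_loss is not None:
            losses.append(batch_loss)
    if return_loss:
        if not losses:
            raise ValueError("return_loss=True requires y_test")
        # Assumption: average the per-batch losses.
        return sum(losses) / len(losses)
    # With return_loss=False, y_test does not change the output.
    return preds


if __name__ == "__main__":
    x, y = [0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 5.0]
    print(predict_sketch(toy_model, x, eval_batch_size=2))
    print(predict_sketch(toy_model, x, y, return_loss=True, eval_batch_size=2))
```

With eval_batch_size=2, the first batch has loss 0.0 and the second has loss 2.0, so the call with return_loss=True returns 1.0, while the call without labels returns the four predictions.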