predict
- transtab.predict(clf, x_test, y_test=None, return_loss=False, eval_batch_size=256)
Make predictions with a TransTabClassifier.
- Parameters
clf (TransTabClassifier) – the classifier model used to make predictions.
x_test (pd.DataFrame) – input tabular data.
y_test (pd.Series) – target labels for x_test; ignored if return_loss=False.
return_loss (bool) – set to True to return the loss instead of the predictions; requires y_test.
eval_batch_size (int) – the batch size for inference.
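eval_batch_size controls how many rows of x_test are passed through the model at once. The helper below is a hypothetical illustration, not part of transtab. It assumes the rows are processed in consecutive chunks of eval_batch_size, with a shorter final chunk when the row count is not a multiple of the batch size.

```python
def batch_slices(n_rows: int, eval_batch_size: int = 256) -> list:
    """Return (start, stop) row ranges covering n_rows in chunks of eval_batch_size.

    Hypothetical helper: it illustrates chunked inference and is not transtab's own code.
    """
    if eval_batch_size <= 0:
        raise ValueError("eval_batch_size must be positive")
    # Each chunk starts at a multiple of eval_batch_size; the last one is clipped to n_rows.
    return [
        (start, min(start + eval_batch_size, n_rows))
        for start in range(0, n_rows, eval_batch_size)
    ]


print(batch_slices(600))  # [(0, 256), (256, 512), (512, 600)]
```

A larger eval_batch_size means fewer forward passes but higher peak memory per pass.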
- Returns
pred_all (np.array) – if return_loss=False, the predictions made by the TransTabClassifier.
avg_loss (float) – if return_loss=True, the mean loss of the predictions made by the TransTabClassifier.
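The switch between the two return values can be sketched as follows. This is a minimal illustration under stated assumptions, not transtab's implementation: predict_sketch and toy_model are hypothetical, the model is assumed to be a callable that takes a batch of rows plus optional labels and returns (per-row predictions, batch loss or None), and plain lists stand in for pd.DataFrame, pd.Series and np.array.

```python
from statistics import mean


def predict_sketch(model_fn, x_test, y_test=None, return_loss=False, eval_batch_size=256):
    """Hypothetical stand-in for predict(): batch the inputs, then return predictions or mean loss."""
    preds, losses = [], []
    for start in range(0, len(x_test), eval_batch_size):
        stop = start + eval_batch_size
        # Labels are only sliced and passed along when they were supplied.
        batch_y = None if y_test is None else y_test[start:stop]
        batch_preds, batch_loss = model_fn(x_test[start:stop], batch_y)
        preds.extend(batch_preds)
        if batch_loss is not None:
            losses.append(batch_loss)
    if return_loss:
        if not losses:
            raise ValueError("return_loss=True requires y_test")
        # Assumption: the mean is taken over per-batch losses (each batch weighted equally).
        return mean(losses)
    return preds


def toy_model(rows, labels):
    """Hypothetical model: scores positive inputs 0.9 and others 0.1; loss is mean absolute error."""
    preds = [0.9 if r > 0 else 0.1 for r in rows]
    loss = None if labels is None else mean(abs(p - y) for p, y in zip(preds, labels))
    return preds, loss


x = [1, -1, 2, -2, 3]
y = [1, 0, 1, 1, 1]
print(predict_sketch(toy_model, x, eval_batch_size=2))  # [0.9, 0.1, 0.9, 0.1, 0.9]
print(predict_sketch(toy_model, x, y, return_loss=True, eval_batch_size=2))
```

Averaging per-batch losses weights a short final batch the same as a full one. Whether transtab weights the mean by batch size is not stated in this reference.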