Data notes
46 subscribers
59 photos
6 videos
2 files
123 links
My data science notes
Download Telegram
Forwarded from Aspiring Data Science (Anatoly Alekseev)
#optimization #ml #metrics #python #numba #codegems

В общем, sklearn-овские метрики оказались слишком медленными, пришлось их переписать на numba. Вот пример classification_report, который работает в тысячу раз быстрее и поддерживает почти всю функциональность (кроме весов и микровзвешивания). Также оптимизировал метрики auc (алгоритм взят из fastauc) и calibration (считаю бины предсказанные vs реальные, потом mae/std от их разностей). На 8M сэмплов всё работает за ~30 миллисекунд кроме auc, та ~300 мс. Для сравнения, scikit-learn-овские работают от нескольких секунд до нескольких десятков секунд.

@njit()
def fast_classification_report(
y_true: np.ndarray, y_pred: np.ndarray, nclasses: int = 2, zero_division: int = 0
):
"""Custom classification report, proof of concept."""

N_AVG_ARRAYS = 3 # precisions, recalls, f1s

# storage inits
weighted_averages = np.empty(N_AVG_ARRAYS, dtype=np.float64)
macro_averages = np.empty(N_AVG_ARRAYS, dtype=np.float64)
supports = np.zeros(nclasses, dtype=np.int64)
allpreds = np.zeros(nclasses, dtype=np.int64)
misses = np.zeros(nclasses, dtype=np.int64)
hits = np.zeros(nclasses, dtype=np.int64)

# count stats
for true_class, predicted_class in zip(y_true, y_pred):
supports[true_class] += 1
allpreds[predicted_class] += 1
if predicted_class == true_class:
hits[predicted_class] += 1
else:
misses[predicted_class] += 1

# main calcs
accuracy = hits.sum() / len(y_true)
balanced_accuracy = np.nan_to_num(hits / supports, copy=True, nan=zero_division).mean()

recalls = hits / supports
precisions = hits / allpreds
f1s = 2 * (precisions * recalls) / (precisions + recalls)

# fix nans & compute averages
for arr in (precisions, recalls, f1s):
np.nan_to_num(arr, copy=False, nan=zero_division)
weighted_averages[i] = (arr * supports).sum() / len(y_true)
macro_averages[i] = arr.mean()

return hits, misses, accuracy, balanced_accuracy, supports, precisions, recalls, f1s, macro_averages, weighted_averages