Create your Gitee Account
Explore and code with more than 6 million developers. Free private repositories! :)
Sign up
Clone or download
utils.py 3.22 KB
Copy Edit Web IDE Raw Blame History
Honglie Chen authored 2020-05-19 23:26 . minor clean
import csv
import os
import torch
from torch.optim import *
import torchvision
from torchvision.transforms import *
from scipy import stats
from sklearn import metrics
import numpy as np
class AverageMeter(object):
    """Keep the latest value of a metric together with its weighted running mean.

    Attributes (all reset to 0 by ``reset``):
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count`` as of the most recent update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Note: ``avg`` divides by the total weight, so a total weight of 0
        (e.g. a first call with ``n=0``) divides by zero.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
class Logger(object):
    """Write rows of named values to a tab-separated log file.

    On construction the file at ``path`` is opened for writing (truncating
    any existing content) and the ``header`` row is written. Each ``log``
    call writes one row, ordered by ``header``, and flushes it to disk.
    """

    def __init__(self, path, header):
        # The csv module requires file objects it writes to be opened with
        # newline='' so its own line terminator is not translated.
        self.log_file = open(path, 'w', newline='')
        self.logger = csv.writer(self.log_file, delimiter='\t')
        self.logger.writerow(header)
        self.header = header

    def close(self):
        """Close the underlying file; safe to call more than once."""
        # getattr: __init__ may have failed before log_file was assigned.
        log_file = getattr(self, 'log_file', None)
        if log_file is not None:
            log_file.close()

    # The original destructor was misspelled ``__del`` and so was never
    # invoked by Python; keep that (name-mangled) name as an alias.
    __del = close

    def __del__(self):
        self.close()

    def log(self, values):
        """Write one row containing ``values[col]`` for each ``col`` in header.

        Keys of ``values`` that are not in the header are ignored.

        Raises:
            KeyError: if any header column is missing from ``values``; no
                partial row is written in that case.
        """
        missing = [col for col in self.header if col not in values]
        if missing:
            raise KeyError('missing log columns: {}'.format(missing))
        self.logger.writerow([values[col] for col in self.header])
        self.log_file.flush()
def accuracy(output, target, topk=(1, 5)):
    """Compute top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: score tensor of shape (batch, num_classes); higher scores
            rank first.
        target: tensor of ``batch`` ground-truth class indices.
        topk: the values of k to evaluate.

    Returns:
        ``(res, pred)``. ``res[i]`` is a 0-d tensor with the percentage of
        samples whose target is among their ``topk[i]`` highest-scoring
        classes. ``pred`` is the (maxk, batch) tensor of predicted class
        indices, where row j holds each sample's (j+1)-th ranked class.
    """
    # Asking topk for more classes than exist raises; clamping lets any
    # k >= num_classes simply use every ranked prediction.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # ``correct`` inherits the transposed (non-contiguous) layout of
        # ``pred``, so ``view(-1)`` fails for k > 1 on recent PyTorch;
        # ``reshape`` copies only when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res, pred
def reverseTransform(img):
    """Undo ImageNet-style mean/std normalization on channel axis 1, in place.

    For each of the first three channels c, computes ``x * std[c] + mean[c]``
    and writes it back into ``img``. A rank-5 input is indexed as
    ``img[:, c, :, :, :]``; any other rank as ``img[:, c, :, :]``.

    Returns:
        The same (now modified) ``img`` object.
    """
    channel_means = [0.485, 0.456, 0.406]
    channel_stds = [0.229, 0.224, 0.225]
    five_dims = len(img.shape) == 5
    for c, (m, s) in enumerate(zip(channel_means, channel_stds)):
        if five_dims:
            img[:, c, :, :, :] = img[:, c, :, :, :] * s + m
        else:
            img[:, c, :, :] = img[:, c, :, :] * s + m
    return img
def d_prime(auc):
    """Convert ROC AUC to the sensitivity index d' = sqrt(2) * ppf(auc).

    ``ppf`` is the inverse CDF of the standard normal distribution, so
    an AUC of 0.5 maps to 0. Scalars and arrays are both accepted.
    """
    inverse_cdf = stats.norm().ppf
    return inverse_cdf(auc) * np.sqrt(2.0)
def calculate_stats(output, target, save_every_steps=1000):
    """Calculate per-class statistics: average precision, ROC AUC and curves.

    Args:
        output: 2d array of scores, (samples_num, classes_num).
        target: 2d array of labels, (samples_num, classes_num).
        save_every_steps: keep only every ``save_every_steps``-th point of
            each precision-recall / ROC curve to limit the result size
            (default 1000, the previously hard-coded value).

    Returns:
        A list with one dict per class (column), holding:
        'precisions' and 'recalls' (subsampled precision-recall curve),
        'AP' (average precision), 'fpr' and 'fnr' (subsampled ROC curve,
        with ``fnr = 1 - tpr``), and 'auc' (ROC AUC).

    Raises:
        ValueError: propagated from scikit-learn, e.g. when a class column
            of ``target`` contains only one label value.
    """
    classes_num = target.shape[-1]
    step = save_every_steps
    per_class_stats = []
    for k in range(classes_num):
        y_true = target[:, k]
        y_score = output[:, k]
        avg_precision = metrics.average_precision_score(
            y_true, y_score, average=None)
        auc = metrics.roc_auc_score(y_true, y_score, average=None)
        # Threshold arrays are not part of the result.
        precisions, recalls, _ = metrics.precision_recall_curve(y_true, y_score)
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
        per_class_stats.append({
            'precisions': precisions[0::step],
            'recalls': recalls[0::step],
            'AP': avg_precision,
            'fpr': fpr[0::step],
            'fnr': 1. - tpr[0::step],
            'auc': auc,
        })
    return per_class_stats

Comment ( 0 )

Sign in to post a comment