import torch
import numpy as np
import os
import random
import logging
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from PIL import Image
import io
from torchvision import transforms

# ~~~~~~~~~~~~~~~~~~~~~ helper functions ~~~~~~~~~~~~~~~~~~~~~ #


def setSeed(seed):
    """Sets the seed for the random, numpy, and torch modules.

    Args:
        seed: an int seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
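
# Illustrative usage (a sketch, not part of the module): call this before
# building the model and dataloaders so weight init and shuffling are
# reproducible. MyModel and build_loaders are hypothetical names.
#
#   setSeed(0)
#   model = MyModel()
#   train_loader = build_loaders()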


def incremental_folder_name(base_dir='run', folder='train', need_path=True):
    """Creates a folder and returns its path in base_dir/folder_{increment} format.

    Args:
        base_dir: base directory. this should be relative
            to the current python execution path.

        folder: the folder name that needs to be incremented,
            like folder_1, folder_2, folder_3.

        need_path: if True, return the full path; otherwise
            return only the incremented folder name.
    """
    i = 1
    while Path(f'{base_dir}/{folder}_{i}').resolve().exists():
        i += 1
    p = Path(f'{base_dir}/{folder}_{i}')
    p.mkdir(parents=True)
    if need_path:
        return str(p)
    else:
        return f'{folder}_{i}'
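
# Illustrative usage (a sketch): successive calls create run/train_1,
# run/train_2, ... relative to the current working directory.
#
#   save_dir = incremental_folder_name(base_dir='run', folder='train')
#   # -> 'run/train_1' on the first call, 'run/train_2' on the next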


def incremental_filename(baseDir='run', name='predictions', ext='png'):
    """Returns a file path in baseDir/name{increment}.ext format.

    The file itself is not created; the first non-existing path
    is resolved and returned.

    Args:
        baseDir: base directory. this should be relative
            to the current python execution path.

        name: the name of the file that needs to be incremented,
            like file1, file2, file3.

        ext: file extension, without the leading dot.
    """

    i = 1
    while Path(f'{baseDir}/{name}{i}.{ext}').resolve().exists():
        i += 1
    fileName = Path(f'{baseDir}/{name}{i}.{ext}').resolve()
    return str(fileName)
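
# Illustrative usage (a sketch): pick a non-clashing output path for a plot.
# Unlike incremental_folder_name, the directory is not created here, so
# baseDir should already exist before saving to the returned path.
#
#   out_file = incremental_filename(baseDir='run/train_1', name='predictions', ext='png')
#   plt.savefig(out_file)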


def setConsoleLogger(loggerName):
    """Attaches a console (stream) handler at INFO level to the named logger."""
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
    # tell the handler to use this format
    console.setFormatter(formatter)
    # add the handler to the logger
    logging.getLogger(loggerName).addHandler(console)


def setFileLogger(loggerName):
    """Attaches a file handler at INFO level, writing to logs/{loggerName}.log."""
    logFilePath = Path(f'logs/{loggerName}.log').resolve()
    if not logFilePath.exists():
        if not logFilePath.parent.exists():
            logFilePath.parent.mkdir(parents=True)
        logFilePath.touch()
    # format for file logging
    formatter = logging.Formatter(
        fmt='%(asctime)s:%(name)s:%(levelname)s:%(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p')
    # file handler
    handler = logging.FileHandler(str(logFilePath), 'w')
    handler.setFormatter(formatter)
    handler.setLevel(logging.INFO)
    logging.getLogger(loggerName).addHandler(handler)


# ~~~~~~~~~~~~~~~~~~~~~ logging functions ~~~~~~~~~~~~~~~~~~~~~ #


def set_logger(logger_name):
    """Creates a logger object with the given name.

    Args:
        logger_name: str object, suggested usage:
            set_logger(__name__).
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    setConsoleLogger(logger_name)
    setFileLogger(logger_name)
    return logger


logger = set_logger(__name__)
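
# Records emitted through this module-level logger go to the console and to
# logs/<logger name>.log. Both handlers are set to INFO, so DEBUG records are
# filtered out even though the logger level itself is DEBUG. For example:
#
#   logger.info('training started')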


# ~~~~~~~~~~~~~~~~~~~~~ network helper functions ~~~~~~~~~~~~~~~~~~~~~ #


def calculate_accuracy(y_hat, y):
    """Calculates the model accuracy for a batch.

    Args:
        y_hat: output of the model. argmax is taken
            over dimension 1.

        y: ground truth. a 1D tensor of class indices.
    """
    top_pred = y_hat.argmax(1, keepdim=True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc
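
# Illustrative check with dummy tensors: two of the three predictions below
# match the targets, so the returned accuracy is 2/3.
#
#   y_hat = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
#   y = torch.tensor([0, 1, 1])
#   calculate_accuracy(y_hat, y)   # tensor(0.6667)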


def save_model(model, current_loss, lowest_loss, D='weights/'):
    """Saves the model to the given directory as both last.pt and best.pt.

    best.pt is only overwritten when current_loss improves on lowest_loss,
    and the (possibly updated) lowest loss is returned.

    Args:
        model: the model to be saved (nn.Module object).

        current_loss: loss at the current state (float object).

        lowest_loss: lowest loss seen so far (float object).

        D: directory to save weights into.
            default is '{current python execution path}/weights/'
            (str object).
    """

    assert isinstance(D, str), ("expecting string object for path, " +
                                f'received object of {type(D)}')

    D = Path(D)
    if not D.exists():
        D.mkdir()
    last = str((D / 'last.pt').resolve())
    best = str((D / 'best.pt').resolve())

    # last.pt always reflects the most recent state
    torch.save(model.state_dict(), last)

    # best.pt is only updated when the loss improves
    if current_loss < lowest_loss:
        torch.save(model.state_dict(), best)
        lowest_loss = current_loss

    return lowest_loss
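
# Illustrative usage inside a training loop (a sketch; validate() is a
# hypothetical validation step):
#
#   lowest_loss = float('inf')
#   for epoch in range(epochs):
#       val_loss = validate(model)
#       lowest_loss = save_model(model, val_loss, lowest_loss, D='weights/')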


def create_confusion_matrix(y, y_hat, figsize=7):
    """Plots a row-normalized confusion matrix and returns it as an image tensor.

    Args:
        y: ground-truth labels.

        y_hat: predicted labels.

        figsize: width and height of the figure in inches.
    """
    fig = plt.figure(figsize=(figsize, figsize))
    ax = fig.add_subplot(1, 1, 1)
    cm = confusion_matrix(y, y_hat)
    # Normalize each row of the confusion matrix.
    cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis],
                   decimals=2)
    # display_labels assumes a 10-class problem
    cm = ConfusionMatrixDisplay(cm, display_labels=range(10))
    cm.plot(cmap='Greys', ax=ax)
    # render the figure to an in-memory PNG and convert it to a tensor
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    matrix = Image.open(buf)
    tr = transforms.ToTensor()
    matrix = tr(matrix)
    return matrix
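
# Illustrative usage (a sketch): the returned tensor is an image (C x H x W),
# so it can be logged with a TensorBoard SummaryWriter if one is in use
# (writer here is a hypothetical torch.utils.tensorboard.SummaryWriter).
# Note that the ground-truth labels come first.
#
#   matrix = create_confusion_matrix(y_true, y_pred)
#   writer.add_image('confusion_matrix', matrix, global_step=epoch)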


# ~~~~~~~~~~~~~~~~~~~~~ system checks ~~~~~~~~~~~~~~~~~~~~~ #


def select_device(opt):
    """Returns a torch.device object based on the input.

    Args:
        opt: a dictionary that may contain a 'device' key.
            for example: opt['device'] = '0' (or 0) for 'cuda:0'
            and '1' for 'cuda:1'; 'cpu' forces the cpu.

            if the key is missing or empty, 'cuda:0' is selected
            when available.

            in both cases 'cpu' is selected if cuda is
            not available.
    """
    # device = 'cpu' or '0' or '1'
    cpu = False
    s = 'cuda:0'
    if 'device' in opt.keys():  # device is in the opts
        device = opt['device']
        if device is not None:  # device is not none
            # accept both int and str values, e.g. 0 or '0'
            device = str(device).strip()
            cpu = device.lower() == 'cpu'  # device is cpu
        if device and not cpu:  # some device is requested and it's not cpu
            cuda = torch.cuda.is_available()

            if cuda and int(device) in range(0, 5):
                # set environment variable
                os.environ['CUDA_VISIBLE_DEVICES'] = device
                s = f'cuda:{device}'
            else:
                logger.info(
                    f'CUDA unavailable or invalid device: {device} requested')
        elif cpu:
            # force torch.cuda.is_available() = False
            os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    cuda = torch.cuda.is_available() and not cpu
    if not cuda:
        logger.info('using cpu for computation.')
    else:
        logger.info(f'using {s} for computation.')
    return torch.device(s if cuda else 'cpu')
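
# Illustrative usage (a sketch; the values are examples only):
#
#   device = select_device({'device': '0'})    # cuda:0 if available, else cpu
#   device = select_device({'device': 'cpu'})  # force the cpu
#   model = model.to(device)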