10000 initial_commit · caffeine-coder1/computer_vision@caa2b43 · GitHub
[go: up one dir, main page]

Skip to content

Commit caa2b43

Browse files
initial_commit
1 parent 51cff82 commit caa2b43

File tree

1 file changed

+228
-0
lines changed

1 file changed

+228
-0
lines changed

GAN/WGAN/general.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
import torch
2+
import numpy as np
3+
import os
4+
import random
5+
import logging
6+
from pathlib import Path
7+
import matplotlib.pyplot as plt
8+
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
9+
from PIL import Image
10+
import io
11+
from torchvision.transforms import transforms
12+
13+
# ~~~~~~~~~~~~~~~~~~~~~ helper functions ~~~~~~~~~~~~~~~~~~~~~ #
14+
15+
16+
def setSeed(seed):
17+
"""Sets seed for random, numpy, and torch modules.
18+
19+
Args:
20+
seed: a seed value. should be of type int.
21+
"""
22+
random.seed(seed)
23+
np.random.seed(seed)
24+
torch.manual_seed(seed)
25+
torch.cuda.manual_seed(seed)
26+
torch.backends.cudnn.deterministic = True
27+
28+
29+
def incremental_folder_name(base_dir='run', folder='train', need_path=True):
30+
"""Creates a Folder and returns its path in base_dir/folder_{increment} format.
31+
32+
Args:
33+
base_dir: base directory. this should be relative
34+
to current python execution path.
35+
36+
folder: the folder name that needs to be incremented.
37+
like folder1, folder2, folder3.
38+
"""
39+
i = 1
40+
while Path(f'{base_dir}/{folder}_{i}').resolve().exists():
41+
i += 1
42+
p = Path(f'{base_dir}/{folder}_{i}')
43+
p.mkdir(parents=True)
44+
if need_path:
45+
return str(p)
46+
else:
47+
return f'{folder}_{i}'
48+
49+
50+
def incremental_filename(baseDir='run', name='predictions', ext='png'):
51+
"""Returns a file path in base_dir/name{increment}.ext format.
52+
53+
Args:
54+
base_dir: base directory. this should be relative
55+
to current python execution path.
56+
57+
name: the name of the file that needs to be incremented.
58+
like file1, file2, file3.
59+
"""
60+
61+
i = 1
62+
while Path(f'{baseDir}/{name}{i}.{ext}').resolve().exists():
63+
i += 1
64+
fileName = Path(f'{baseDir}/{name}{i}.{ext}').resolve()
65+
return str(fileName)
66+
67+
68+
def setConsoleLogger(loggerName):
69+
console = logging.StreamHandler()
70+
console.setLevel(logging.INFO)
71+
formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
72+
# tell the handler to use this format
73+
console.setFormatter(formatter)
74+
# add the handler to the logger
75+
logging.getLogger(loggerName).addHandler(console)
76+
77+
78+
def setFileLogger(loggerName):
79+
logFilePath = Path(f'logs/{loggerName}.log').resolve()
80+
if not logFilePath.exists():
81+
if not logFilePath.parent.exists():
82+
logFilePath.parent.mkdir(parents=True)
83+
logFilePath.touch()
84+
# format for file logging
85+
formatter = logging.Formatter(
86+
fmt='%(asctime)s:%(name)s:%(levelname)s:%(message)s',
87+
datefmt='%m/%d/%Y %I:%M:%S %p')
88+
# file handler
89+
handler = logging.FileHandler(str(logFilePath), 'w')
90+
handler.setFormatter(formatter)
91+
handler.setLevel(logging.INFO)
92+
logging.getLogger(loggerName).addHandler(handler)
93+
94+
# ~~~~~~~~~~~~~~~~~~~~~ logging functions ~~~~~~~~~~~~~~~~~~~~~ #
95+
96+
97+
def set_logger(logger_name):
98+
"""Creates a logger object with given name.
99+
100+
Args:
101+
logger_name: str object, suggested usesage:
102+
set_logger(__name__).
103+
"""
B41A 104+
logger = logging.getLogger(logger_name)
105+
logger.setLevel(logging.DEBUG)
106+
setConsoleLogger(logger_name)
107+
setFileLogger(logger_name)
108+
return logger
109+
110+
111+
logger = set_logger(__name__)
112+
113+
114+
# ~~~~~~~~~~~~~~~~~~~~~ network helper functions ~~~~~~~~~~~~~~~~~~~~~ #
115+
116+
def calculate_accuracy(y_hat, y):
117+
"""Calculates the model accuracy.
118+
119+
Args:
120+
y_hat: output of the Model. Argmax will be taken
121+
in first dimension
122+
123+
y: ground Truth. a 1D array of actual values.
124+
"""
125+
top_pred = y_hat.argmax(1, keepdim=True)
126+
correct = top_pred.eq(y.view_as(top_pred)).sum()
127+
acc = correct.float() / y.shape[0]
128+
return acc
129+
130+
131+
def save_model(model, current_loss, lowest_loss, D='weights/'):
132+
"""Save the model to given directory. saves both last.pt and best.pt.
133+
134+
Args:
135+
model: The model that needs to be saved.(nn.Module object)
136+
137+
current_loss: loss at current state.(float object)
138+
139+
lowest_loss: lowest loss seen till now. (float object)
140+
141+
D: directory to save weights.
142+
default is '{current python execution}/weights/'
143+
(str object)
144+
"""
145+
146+
assert isinstance(D, str), ("expecting string object for path, " +
147+
f'received object of {type(D)}')
148+
149+
D = Path(D)
150+
if not D.exists():
151+
D.mkdir()
152+
last = str((D/'last.pt').resolve())
153+
best = str((D/'best.pt').resolve())
154+
155+
torch.save(model.state_dict(), last)
156+
157+
if current_loss < lowest_loss:
158+
torch.save(model.state_dict(), best)
159+
lowest_loss = current_loss
160+
161+
return lowest_loss
162+
163+
164+
def create_confusion_matrix(y, y_hat, figsize=7):
165+
166+
fig = plt.figure(figsize=(figsize, figsize))
167+
ax = fig.add_subplot(1, 1, 1)
168+
cm = confusion_matrix(y, y_hat)
169+
# Normalize the confusion matrix.
170+
cm = np.around(cm.astype('float') / cm.sum(axis=1)
171+
[:, np.newaxis], decimals=2)
172+
cm = ConfusionMatrixDisplay(cm, display_labels=range(10))
173+
cm.plot(cmap='Greys', ax=ax)
174+
buf = io.BytesIO()
175+
fig.savefig(buf, format='png')
176+
buf.seek(0)
177+
plt.close(fig)
178+
matrix = Image.open(buf)
179+
tr = transforms.ToTensor()
180+
matrix = tr(matrix)
181+
return matrix
182+
183+
184+
# ~~~~~~~~~~~~~~~~~~~~~ system checks ~~~~~~~~~~~~~~~~~~~~~ #
185+
186+
187+
def select_device(opt):
188+
"""Returns a torch.device object based on the input.
189+
190+
Args:
191+
opt: a dictionary. should have a key called device.
192+
for example: opt['device'] = 0 for 'cuda:0'
193+
and 1 for 'cuda:1'.
194+
195+
incase of empty dictionary 'cuda:0' will be selected
196+
if available.
197+
198+
in both the cases 'cpu' will be selected if cuda is
199+
not available.
200+
"""
201+
# device = 'cpu' or '0' or '1'
202+
cpu = False
203+
s = 'cuda:0'
204+
if 'device' in opt.keys(): # device is in the opts
205+
device = opt['device']
206+
if device: # device is not none
207+
cpu = device.lower() == 'cpu' # device is cpu
208+
if device and not cpu: # some device is requested and its not cpu
209+
cuda = torch.cuda.is_available()
210+
211+
if cuda and int(device) in range(0, 5):
212+
# set environment variable
213+
os.environ['CUDA_VISIBLE_DEVICES'] = device
214+
s = f'cuda:{device}'
215+
else:
216+
217+
logger.info(
218+
f'CUDA unavailable or invalid device: {device} requested')
219+
elif cpu:
220+
# force torch.cuda.is_available() = False
221+
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
222+
223+
cuda = torch.cuda.is_available() and not cpu
224+
if not cuda:
225+
logger.info('using cpu for computation.')
226+
else:
227+
logger.info(f'using {s} for computation.')
228+
return torch.device(s if cuda else 'cpu')

0 commit comments

Comments
 (0)
0