utils_.py
1    import matplotlib.pyplot as plt
2    import numpy as np
3    from numpy.linalg import norm
4    import warnings
5    from tensorly.tenalg import mode_dot
6    from sklearn.ensemble import BaggingClassifier
7    from sklearn.tree import DecisionTreeClassifier
8    from sklearn.model_selection import RepeatedStratifiedKFold
9    from sklearn.model_selection import GridSearchCV
10   from sklearn.metrics import classification_report
11   import tensorly as tl
12   
13   
14   
15   
16   # Setting up default parameters.
def sparam(rs=False):
    """Apply the module's default global settings.

    Side effects (all process-wide): every warning is filtered out, the
    matplotlib ``figure.dpi`` rcParam is set to 144, and NumPy's global
    random generator is seeded with 40.

    Parameters
    ----------
    rs : bool, optional
        When truthy, the seed value is handed back to the caller.

    Returns
    -------
    int or None
        ``40`` if ``rs`` is truthy, otherwise ``None``.
    """
    seed = 40
    warnings.filterwarnings("ignore")
    plt.rcParams['figure.dpi'] = 144
    np.random.seed(seed)
    return seed if rs else None
22   
def ndot(a, b, mode):
    """Short alias for ``tensorly.tenalg.mode_dot``.

    The three arguments are forwarded positionally, in order, and the
    result of ``mode_dot`` is returned unchanged.
    """
    product = mode_dot(a, b, mode)
    return product
24   
25   ####_________________ T2 plot functions _________________####
def plt_score(score1, title1, score2, title2, y, y_title):
    """Show two score matrices and a standardised ``y`` as three images.

    Parameters
    ----------
    score1, score2 : array-like
        Matrices drawn with ``plt.imshow`` in the first and second panel.
    title1, title2 : str
        Titles of the first and second panel.
    y : array-like
        Drawn in the third panel after subtracting its mean and dividing
        by its standard deviation, both taken along axis 0.
    y_title : str
        Title of the third panel.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    # Standardise along axis 0 before anything is drawn.
    y_centred = y - np.mean(y, axis=0)
    y_normd = y_centred / np.std(y, axis=0)

    plt.figure(figsize=(9, 5))
    panels = (
        (131, title1, score1),
        (132, title2, score2),
        (133, y_title, y_normd),
    )
    for position, heading, image in panels:
        plt.subplot(position)
        plt.title(heading)
        plt.imshow(image)
    plt.suptitle('Score Matrices', size=15)
    plt.show()
34   
def plt_emex(em, ex):
    """Plot emission and excitation profiles in two side-by-side panels.

    The x-axes are fixed: integers 250..450 (201 points) for ``em`` and
    240..300 (61 points) for ``ex``, both labelled 'nm'.

    Parameters
    ----------
    em : array-like
        Values plotted against 250..450; presumably has 201 entries along
        its first axis (verify against caller).
    ex : array-like
        Values plotted against 240..300; presumably has 61 entries along
        its first axis (verify against caller).

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(10, 3))
    panels = (
        (121, 'Emissions', np.arange(250, 451), em),
        (122, 'Excitations', np.arange(240, 301), ex),
    )
    for position, heading, wavelengths, profile in panels:
        plt.subplot(position)
        plt.title(heading)
        plt.plot(wavelengths, profile)
        plt.xlabel('nm')
    plt.show()
42   ####_____________________________________________________####
43   
44   
45   ####________________ T3 helper functions ________________####
def tk_rec(core, fact):
    """Multiply ``core`` by ``fact[0]``, ``fact[1]``, ``fact[2]`` along modes 0-2.

    Parameters
    ----------
    core : tensor
        Starting tensor; passed to ``ndot`` as the first operand.
    fact : sequence
        Indexable collection whose first three items are applied along
        modes 0, 1 and 2 respectively. Items beyond the third are ignored.

    Returns
    -------
    tensor
        The result after the three successive mode products.
    """
    X_rec = core
    # Exactly three modes, applied in increasing order.
    for mode in range(3):
        X_rec = ndot(X_rec, fact[mode], mode=mode)
    return X_rec
52   
def plt_sbs(img1, title1, img2, title2):
    """Display two images side by side in grayscale with the axes hidden.

    Parameters
    ----------
    img1, img2 : array-like
        Images for the left and right panel, drawn with ``cmap='gray'``.
    title1, title2 : str
        Titles of the left and right panel.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(6, 5))
    for position, heading, image in ((121, title1, img1),
                                     (122, title2, img2)):
        plt.subplot(position)
        plt.title(heading)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.show()
60   
def corth(core, perc):
    """Zero every entry of ``core`` except the largest-magnitude ``perc`` percent.

    The cutoff is the ``100 - perc`` percentile of ``abs(core)``. Only
    entries whose magnitude is strictly greater than the cutoff are kept,
    so entries exactly equal to it are zeroed as well.

    Parameters
    ----------
    core : array-like
        Values to threshold.
    perc : float
        Percentage (0-100) of entries to keep, ranked by magnitude.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``core`` with the discarded entries set to 0.
    """
    magnitude = np.abs(core)
    cutoff = np.percentile(magnitude, 100 - perc)
    keep = magnitude > cutoff
    return np.where(keep, core, 0)
65   
def err(X, core):
    """Relative error implied by the norm missing from ``core``.

    Computes ``sqrt(||X||^2 - ||core||^2) / ||X||`` using
    ``numpy.linalg.norm`` with its default settings (the 2-norm over all
    entries).

    NOTE(review): reading this as a reconstruction error presumably relies
    on the decomposition factors being orthonormal -- confirm with caller.

    Parameters
    ----------
    X : array-like
        Reference tensor.
    core : array-like
        Tensor whose norm is compared against that of ``X``.

    Returns
    -------
    numpy.float64
        The relative error. A negative squared gap is treated as zero, so
        the result is ``0.0`` rather than NaN when ``core`` carries all of
        the norm of ``X`` and rounding pushes the difference below zero.
    """
    x_norm = norm(X)
    gap = x_norm ** 2 - norm(core) ** 2
    # Rounding can leave the gap slightly negative; sqrt would then give NaN.
    return np.sqrt(np.maximum(gap, 0.0)) / x_norm
68   
def plt_th(X, fact, c1t, c2t, c3t, c4t, c5t):
    """Plot reconstructions from five thresholded cores in a 3x5 image grid.

    Each column corresponds to one core, in argument order. The three rows
    show slices 4, 9 and 14 along the third axis of that core's
    reconstruction, drawn in grayscale with a fixed 0-255 intensity range.

    Column titles are fixed by position ("Top 10% Kept" through
    "Top 50% Kept") and include the value of ``err(X, core)``; they do not
    inspect how the cores were actually thresholded.

    Parameters
    ----------
    X : array-like
        Reference tensor passed to ``err``.
    fact : sequence
        Factors passed to ``tk_rec`` together with each core.
    c1t, c2t, c3t, c4t, c5t : array-like
        The five cores, one per column.

    Returns
    -------
    None
        The figure is shown with ``fig.show()``.
    """
    cores = (c1t, c2t, c3t, c4t, c5t)
    errors = [err(X, thresholded) for thresholded in cores]
    recons = [tk_rec(thresholded, fact) for thresholded in cores]

    fig, axes = plt.subplots(3, 5, figsize=(9, 7))
    for col, (recon, rel_err) in enumerate(zip(recons, errors)):
        for row in range(3):
            panel = axes[row, col]
            panel.imshow(recon[:, :, row * 5 + 4], cmap='gray',
                         vmin=0, vmax=255)
            panel.axis('off')
        axes[0, col].set_title("Top %d0%% Kept\nError: %.4f" %
                               (col + 1, rel_err))
    fig.tight_layout()
    fig.show()
84   
85   ####_____________________________________________________####
86   
87   
88   ####________________ T4 helper functions ________________####
89   
def l_img(trn, tst):
    """Read train and test images into stacked colour and grayscale arrays.

    Every file is read with ``plt.imread``. Both the raw image and its
    ``toG`` grayscale version are transposed, collected into a list,
    converted with ``np.array`` and transposed again. The image index
    therefore ends up on the last axis: images of shape (H, W, C) give a
    colour stack of shape (H, W, C, N) and a grayscale stack of (H, W, N).

    Parameters
    ----------
    trn, tst : iterable
        File paths (or anything ``plt.imread`` accepts) of the training
        and test images. Presumably all images share one shape; verify
        against caller.

    Returns
    -------
    tuple of numpy.ndarray
        ``(trn_tens, tst_tens, trn_tens_G, tst_tens_G)``.
    """
    def _load(paths):
        # Collect transposed colour and grayscale versions of each image.
        colour, gray = [], []
        for path in paths:
            picture = plt.imread(path)
            colour.append(picture.T)
            gray.append(toG(picture).T)
        return colour, gray

    trn_colour, trn_gray = _load(trn)
    tst_colour, tst_gray = _load(tst)

    # Stack only after every file has been read, then undo the transpose.
    stacks = (trn_colour, tst_colour, trn_gray, tst_gray)
    return tuple(np.array(stack).T for stack in stacks)
107  
def toG(img):
    """Collapse the last axis of ``img`` with weights 0.299, 0.587, 0.114.

    Parameters
    ----------
    img : array-like
        Array whose last axis has length 3 (presumably R, G, B channels --
        verify against caller).

    Returns
    -------
    numpy.ndarray or numpy.float64
        The weighted sum over the last axis; one dimension fewer than
        ``img``.
    """
    channel_weights = np.array([0.299, 0.587, 0.114])
    pixels = np.asarray(img)
    return pixels.dot(channel_weights)
110  
def tBagger(trn_feat, tst_feat, trn_labs, tst_labs,
            bg_params, rs=42):
    """Grid-search a bagged decision-tree classifier and print the results.

    A ``BaggingClassifier`` wrapping a depth-5 ``DecisionTreeClassifier``
    is tuned with ``GridSearchCV`` (accuracy scoring, 4-fold stratified CV
    repeated 6 times). The best parameters, the best CV score and a
    classification report on the test set are printed.

    Parameters
    ----------
    trn_feat, tst_feat : numpy.ndarray
        2-D feature arrays. Column 0 is dropped before fitting and
        predicting (assumes it is not a feature -- TODO confirm).
    trn_labs, tst_labs : array-like
        Class labels for the training and test rows.
    bg_params : dict or list of dict
        Parameter grid handed to ``GridSearchCV``. Keys that address the
        wrapped tree must use the prefix of the installed scikit-learn
        (``estimator__`` from 1.2 on, ``base_estimator__`` before that).
    rs : int, optional
        Random state shared by the CV splitter, the tree and the bagger.

    Returns
    -------
    None
        Results are only printed.
    """
    rskf = RepeatedStratifiedKFold(
        n_splits=4, n_repeats=6, random_state=rs)
    dt_clf = DecisionTreeClassifier(
        max_depth=5, min_samples_leaf=2, random_state=rs)
    # Pass the tree positionally: the keyword was renamed from
    # ``base_estimator`` to ``estimator`` (old name removed in
    # scikit-learn 1.4), but it is the first positional parameter in both.
    bg_clf = BaggingClassifier(
        dt_clf, n_jobs=-1,
        warm_start=True, random_state=rs)

    # Training and Testing Bagging
    clf_BG_GS = GridSearchCV(bg_clf, bg_params, cv=rskf,
                             scoring='accuracy', n_jobs=-1)

    clf_BG_GS.fit(trn_feat[:, 1:], trn_labs)
    print('Best Parameters: ', clf_BG_GS.best_params_)
    print('Best Score: %.4f' % clf_BG_GS.best_score_,
          '(4-fold Stratified CV Repeated 6 times.)\n')
    tst_pred = clf_BG_GS.predict(tst_feat[:, 1:])
    print('Test Classification Report\n',
          classification_report(tst_labs, tst_pred))
132  ####_____________________________________________________####
133  
134