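ax.flatten(): transforms the n x m (2-D) array of Axes returned by plt.subplots into a 1-D array of length n*m, so each subplot can be reached with a single index ax[i].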
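    import matplotlib.pyplot as plt

    # `image`: array of flattened 28x28 images (e.g. MNIST rows of length 784), assumed already loaded
    fig, ax = plt.subplots(nrows=2, ncols=2, sharex='all', sharey='all')
    ax = ax.flatten()  # 2x2 array of Axes -> 1-D array of 4 Axes
    for i in range(4):
        img = image[i].reshape(28, 28)
        ax[i].imshow(img, cmap='Greys', interpolation='nearest')  # ax[i] is available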
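To see what flatten() actually changes, the sketch below (assuming a headless "Agg" backend and a 2x3 grid chosen only for illustration) prints the shapes before and after. Note that plt.subplots squeezes its result by default, so a 1x1 or 1xN grid does not return a 2-D array in the first place.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(nrows=2, ncols=3)
print(type(ax).__name__, ax.shape)   # ndarray (2, 3)

flat = ax.flatten()                  # new 1-D array holding the same Axes objects
print(flat.shape)                    # (6,)
print(flat[4] is ax[1, 1])           # True: row-major order, 4 -> (4 // 3, 4 % 3)

# With the default squeeze=True, the return type depends on the grid:
_, one = plt.subplots()              # a single Axes object, not an array
_, row = plt.subplots(1, 3)          # 1-D array, shape (3,)
_, grid = plt.subplots(1, 3, squeeze=False)  # always 2-D, shape (1, 3)
```

Passing squeeze=False keeps the result 2-D for every grid size, which makes a following ax.flatten() call safe regardless of nrows and ncols.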
Without flatten(), ax stays a 2x2 array, so each subplot needs a (row, col) index:
    fig, ax = plt.subplots(nrows=2, ncols=2, sharex='all', sharey='all')
    for i in range(4):
        img = image[i].reshape(28, 28)
        # ax[i] is unavailable: ax[0] is a whole row of Axes and ax[2] raises IndexError,
        # so convert i to (row, col) by hand; each image goes into its own subplot
        ax[i // 2, i % 2].imshow(img, cmap='Greys', interpolation='nearest')
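The manual i // 2, i % 2 mapping works because flatten() uses row-major (C) order. The sketch below checks that mapping and shows ax.flat, a NumPy iterator over the same row-major order that avoids building a new array. Random data stands in for the MNIST-style `image` array, which is an assumption for illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 2, 2
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex='all', sharey='all')

# On the 2-D array, ax[0] is a whole row of Axes, not a single Axes
print(ax[0].shape)  # (2,)

# flatten() is row-major: flat index k corresponds to (k // ncols, k % ncols)
flat = ax.flatten()
for k in range(nrows * ncols):
    r, c = divmod(k, ncols)
    assert flat[k] is ax[r, c]

# ax.flat walks the same order lazily; zip pairs each subplot with one image
rng = np.random.default_rng(0)
images = rng.random((nrows * ncols, 28 * 28))  # hypothetical stand-in for `image`
for a, img in zip(ax.flat, images):
    a.imshow(img.reshape(28, 28), cmap='Greys', interpolation='nearest')
```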
In Convolutional Neural Networks we often use a Flatten layer to convert matrices (the stacked 2-D feature maps) into vectors. The flattened vectors are then fed to Fully Connected Layers. A typical pipeline:
Convolution -> Pooling -> Convolution -> Pooling -> ... -> Flatten -> Fully Connected Layers -> Softmax -> Probabilities
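A minimal NumPy sketch of that pipeline, assuming a single 28x28 grayscale input, random (untrained) weights, ReLU after each convolution and 2x2 max pooling. The layer sizes (4 and 8 filters, 10 classes) and the helper names conv2d_relu, max_pool2x2 and softmax are hypothetical choices for illustration, not any library's API. It shows that Flatten is just a reshape: the (8, 5, 5) feature maps become a 200-element vector.

```python
import numpy as np

rng = np.random.default_rng(0)

def conv2d_relu(x, w, b):
    """Valid convolution (cross-correlation, as DL libraries implement it) + ReLU.
    x: (C_in, H, W), w: (C_out, C_in, k, k), b: (C_out,) -> (C_out, H-k+1, W-k+1)"""
    c_out, _, k, _ = w.shape
    _, h, wd = x.shape
    out = np.empty((c_out, h - k + 1, wd - k + 1))
    for i in range(h - k + 1):
        for j in range(wd - k + 1):
            patch = x[:, i:i + k, j:j + k]  # (C_in, k, k) window
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return np.maximum(out, 0.0)

def max_pool2x2(x):
    """2x2 max pooling with stride 2; an odd trailing row/column is dropped."""
    c, h, w = x.shape
    h2, w2 = h // 2, w // 2
    x = x[:, :h2 * 2, :w2 * 2]
    return x.reshape(c, h2, 2, w2, 2).max(axis=(2, 4))

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

# Hypothetical input: one 28x28 grayscale image with a channel axis -> (1, 28, 28)
x = rng.random((1, 28, 28))

# Random, untrained parameters (illustrative sizes)
w1, b1 = rng.normal(0, 0.1, (4, 1, 3, 3)), np.zeros(4)
w2, b2 = rng.normal(0, 0.1, (8, 4, 3, 3)), np.zeros(8)

h = max_pool2x2(conv2d_relu(x, w1, b1))   # Convolution -> Pooling: (4, 13, 13)
h = max_pool2x2(conv2d_relu(h, w2, b2))   # Convolution -> Pooling: (8, 5, 5)

flat = h.reshape(-1)                      # Flatten: (8, 5, 5) -> (200,)

w_fc = rng.normal(0, 0.1, (10, flat.size))
b_fc = np.zeros(10)
logits = w_fc @ flat + b_fc               # Fully connected layer: (200,) -> (10,)
probs = softmax(logits)                   # Softmax -> probabilities over 10 classes

print(h.shape, flat.shape, probs.shape, round(float(probs.sum()), 6))
```

In a real framework the same Flatten step keeps the batch axis, turning (N, C, H, W) into (N, C*H*W) before the fully connected layer.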