前言
这次使用了之前介绍的CNN模型下去修改。主要参考[1]李弘毅老师的影片(内容图文并茂),和[3]是属于比较少图片说明,但两者其实大同小异,如果喜欢看公式可直接看[3],喜欢图片解讲可看[1]。
GAN
想法
假设,是输入资料的分布,但我们无法得知实际的资料分布,而假设
是一个任意函数来近似
,所以必须找到最大theta参数。这里使用[1]图片来解释。
1.将机率相乘得到产生的机率,取log转为指数。
2.指数可变为相加方便处理。
3.转为期望,而上述其实与max为产生出x机率乘上
意思是一样的。
4.转为连续机率分布,再减上,这并不会影响结果,因为
是已知分布且固定的,可视为常数。
5.转为KL散度,第三和四步骤能说是为了转为KL分布计算。所以minKL散度(计算的KL为负数转正求min)即是max原式。
在上一章VAE讲到KL散度就是再度量两者分布,其实也能直接写出第五步骤解释要度量两者分布。
来源[1]。
GAN
这想法是可行的,在VAE讲过在高维度当中要求出theta还是很难的,然而GAN利用最后计算出的结果来衡量。定义loss公式如下。
来源[2]。
以直观角度来看,对于Discriminator就是让原先资料辨识结果越高越好并且生成资料辨识越低越好,反知对于Generator就是让生成资料辨识越高越好(原先资料辨识不影响)。
推导
接着使用数学证明上述公式为何能当作loss。
Max Discriminator
首先将Discriminator最大化,而做这一步就能很明显知道为何这式子可当loss。
1.当max D时,固定住G则会变为第一式。
2.转为连续机率。
3.整理公式。
来源[3]。
假设、
则要最大化的公式如下。
来源[3]。
1.将a和b带入。
2.对D求偏微分(log偏微分公式带入),偏微分即是求出最大化。
3.整理公式。
4.将a和b带回原先的分布。
来源[3]。
1.将最大化D带入原式,右边1-D所以分子扣掉Pdata剩下Pg。
2.将分子分母除以2,因是常数并不影响。
3.将两边分子的1/2提到最前面,则结果会变为2个KL,而这两个KL其实就是JS散度(也能说是对称性KL),简单来说就是一个计算分布差异的公式。
来源[3]。
而前面的-2log2是常数可以忽略,所以由此得知max D就是使用JS散度计算,这样就知道loss是有意义的。
Min Generator
而Min G,则只需要对,因dG与左边无关可忽略。而这里要注意的是[2]提到不要最小化
,而是最大化
,如下图。
来源[2]。
在[1]也有说明,其实主要是训练时梯度下降的关西,最小化一开始会下降很慢,而最大化一开始则不会。如下图。
来源[1]。
loss公式
对于和
我们是不知道真实分布,所以我们只能产生出图片带入计算。也就是真实图片与Generator图片带入Discriminator计算log的平均。
来源[1]。
程式码
使用之前CNN讲解的程式下去修改。训练需要较久时间,所以只使用50笔资料测试。
全域参数
learning_rate: 学习率。
batch_size: 批次训练数量。
train_times: 训练次数。
train_step: 验证步伐。
D_param: discriminator网路层所有相关权重,为了更新用。
G_param: generator网路层所有相关权重,为了更新用。
discriminator_conv: discriminator捲基层数量。
discriminator_output_size: discriminator输出数量。
generator_input_size: generator输入数量。
generator_conv: generator捲基层数量。
generator_output_size: generator输出数量。
learning_rate = 0.0001batch_size = 10train_times = 100000train_step = 1D_param = []G_param = []# [filter size, filter height, filter weight, filter depth]discriminator_conv1_size = [3, 3, 1, 11]discriminator_conv2_size = [3, 3, 11, 13]discriminator_hide3_size = [7 * 7 * 13, 1024]discriminator_output_size = 1generator_input_size = 20generator_conv1_size = [3, 3, 1, 13]generator_conv2_size = [3, 3, 13, 11]generator_hide3_size = [generator_input_size * 11, 1024]generator_output_size = 28 * 28
批次规一化函数
全链结层使用
def layer_batch_norm(x, n_out, is_train): beta = tf.get_variable("beta", [n_out], initializer=tf.ones_initializer()) gamma = tf.get_variable("gamma", [n_out], initializer=tf.ones_initializer()) batch_mean, batch_var = tf.nn.moments(x, [0], name='moments') ema = tf.train.ExponentialMovingAverage(decay=0.9) ema_apply_op = ema.apply([batch_mean, batch_var]) ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var) def mean_var_with_update(): with tf.control_dependencies([ema_apply_op]): return tf.identity(batch_mean), tf.identity(batch_var) mean, var = tf.cond(is_train, mean_var_with_update, lambda:(ema_mean, ema_var)) x_r = tf.reshape(x, [-1, 1, 1, n_out]) normed = tf.nn.batch_norm_with_global_normalization(x_r, mean, var, beta, gamma, 1e-3, True) return tf.reshape(normed, [-1, n_out])
捲积层使用
def conv_batch_norm(x, n_out, train): beta = tf.get_variable("beta", [n_out], initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)) gamma = tf.get_variable("gamma", [n_out], initializer=tf.constant_initializer(value=1.0, dtype=tf.float32)) batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments') ema = tf.train.ExponentialMovingAverage(decay=0.9) ema_apply_op = ema.apply([batch_mean, batch_var]) ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var) def mean_var_with_update(): with tf.control_dependencies([ema_apply_op]): return tf.identity(batch_mean), tf.identity(batch_var) mean, var = tf.cond(train, mean_var_with_update, lambda:(ema_mean, ema_var)) normed = tf.nn.batch_norm_with_global_normalization(x, mean, var, beta, gamma, 1e-3, True) mean_hist = tf.summary.histogram("meanHistogram", mean) var_hist = tf.summary.histogram("varHistogram", var) return normed
捲积层和全链结层
捲积层
这里多了type
去判断是哪个神经网路的参数。活化函数使用relu
。
def conv2d(input, weight_shape, type='D'): size = weight_shape[0] * weight_shape[1] * weight_shape[2] weights_init = tf.random_normal_initializer(stddev=np.sqrt(2. / size)) biases_init = tf.zeros_initializer() weights = tf.get_variable(name="weights", shape=weight_shape, initializer=weights_init) biases = tf.get_variable(name="biases", shape=weight_shape[3], initializer=biases_init) conv_out = tf.nn.conv2d(input, weights, strides=[1, 1, 1, 1], padding='SAME') conv_add = tf.nn.bias_add(conv_out, biases) conv_batch = conv_batch_norm(conv_add, weight_shape[3], tf.constant(True, dtype=tf.bool)) output = tf.nn.relu(conv_batch) if type == 'D': D_param.append(weights) D_param.append(biases) elif type == 'G': G_param.append(weights) G_param.append(biases) return output
全链结层
这里多了type
去判断是哪个神经网路的参数。和activation
能选择不同的活化函数。
def layer(x, weights_shape, activation='relu', type='D'): init = tf.random_normal_initializer(stddev=np.sqrt(2. / weights_shape[0])) weights = tf.get_variable(name="weights", shape=weights_shape, initializer=init) biases = tf.get_variable(name="biases", shape=weights_shape[1], initializer=init) mat_add = tf.matmul(x, weights) + biases #mat_add = layer_batch_norm(mat_add, weights_shape[1], tf.constant(True, dtype=tf.bool)) if activation == 'relu': output = tf.nn.relu(mat_add) elif activation == 'sigmoid': output = tf.nn.sigmoid(mat_add) elif activation == 'softplus': output = tf.nn.softplus(mat_add) else: output = mat_add if type == 'D': D_param.append(weights) D_param.append(biases) elif type == 'G': G_param.append(weights) G_param.append(biases) return output
Discriminator
这里与之前CNN神经网路雷同。而这里我选择使用sigmoid
活化函数(有些使用tensorflow的cross...函数去做可不用)为了让loss
的log
能运算。
def discriminator(x): x = tf.reshape(x, shape=[-1, 28, 28, 1]) with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE): with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE): conv1_out = conv2d(x, discriminator_conv1_size) pool1_out = max_pool(conv1_out) with tf.variable_scope("conv2", reuse=tf.AUTO_REUSE): conv2_out = conv2d(pool1_out, discriminator_conv2_size) pool2_out = max_pool(conv2_out) with tf.variable_scope("hide3", reuse=tf.AUTO_REUSE): pool2_flat = tf.reshape(pool2_out, [-1, discriminator_hide3_size[0]]) hide3_out = layer(pool2_flat, discriminator_hide3_size, activation='softplus') #hide3_drop = tf.nn.dropout(hide3_out,keep_drop) with tf.variable_scope("output"): output = layer(hide3_out, [discriminator_hide3_size[1], discriminator_output_size], activation='sigmoid') return output
Generator
这里与之前的CNN也雷同,这里第一行reshape
则是上述参数的generator输入。
def generator(x): x = tf.reshape(x, shape=[-1, generator_input_size, 1, 1]) with tf.variable_scope("generator", reuse=tf.AUTO_REUSE): with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE): conv1_out = conv2d(x, generator_conv1_size, type='G') with tf.variable_scope("conv2", reuse=tf.AUTO_REUSE): conv2_out = conv2d(conv1_out, generator_conv2_size, type='G') with tf.variable_scope("hide3", reuse=tf.AUTO_REUSE): conv2_flat = tf.reshape(conv2_out, [-1, generator_hide3_size[0]]) hide3_out = layer(conv2_flat, generator_hide3_size, activation='softplus', type='G') with tf.variable_scope("output", reuse=tf.AUTO_REUSE): output = layer(hide3_out, [generator_hide3_size[1], generator_output_size], activation='sigmoid', type='G') return output
损失函数
损失函数这里将它分为两个,一个用来训练Discriminator一个用来训练Generator,将上述推导公式带入即可。
def discriminator_loss(D_x, D_G): loss = -tf.reduce_mean(tf.log(D_x + 1e-12) + tf.log(1. - D_G + 1e-12)) loss_his = tf.summary.scalar("discriminator_loss", loss) return lossdef generator_loss(D_G): loss = -tf.reduce_mean(tf.log(D_G + 1e-12)) loss_his = tf.summary.scalar("generator_loss", loss) return loss
验证函数
这里验证主要存为图片观看。
def image_summary(label, image_data): reshap_data = tf.reshape(image_data, [-1, 28, 28, 1]) tf.summary.image(label, reshap_data, batch_size)def accuracy(G_z): image_summary("G_z_image", G_z)
训练函数
def discriminator_train(loss, index): return tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=1e-12).minimize(loss, global_step=index, var_list=D_param)def generator_train(loss, index): return tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=1e-12).minimize(loss, global_step=index, var_list=G_param)
训练
这里为了让Generator快速产生结果,而训练多次。然而训练过程有时会不错有时会很糟糕,能感受的到GAN是个不稳定模型。
if __name__ == '__main__': # init mnist = input_data.read_data_sets("MNIST/", one_hot=True) input_x = tf.placeholder(tf.float32, shape=[None, 784], name="input_x") input_z = tf.placeholder(tf.float32, shape=[None, generator_input_size], name="input_z") # predict D_x_op = discriminator(input_x) G_z_op = generator(input_z) D_G_op = discriminator(G_z_op) # loss discriminator_loss_op = discriminator_loss(D_x_op, D_G_op) generator_loss_op = generator_loss(D_G_op) # train discriminator_index = tf.Variable(0, name="discriminator_train_time") discriminator_train_op = discriminator_train(discriminator_loss_op, discriminator_index) generator_index = tf.Variable(0, name="generator_train_time") generator_train_op = generator_train(generator_loss_op, generator_index) # accuracy accuracy(G_z_op) # graph summary_op = tf.summary.merge_all() session = tf.Session() summary_writer = tf.summary.FileWriter("log/", graph=session.graph) init_value = tf.global_variables_initializer() session.run(init_value) saver = tf.train.Saver() sample_z = np.random.uniform(-1., 1., (mnist.train.num_examples, generator_input_size)) D_avg_loss = 1. while D_avg_loss > 0.001: total_batch = 1 for i in range(total_batch): minibatch_x = mnist.train.images[i * batch_size: (i + 1) * batch_size] data = sample_z[i * batch_size: (i + 1) * batch_size] session.run(discriminator_train_op, feed_dict={input_x: minibatch_x, input_z: data}) D_avg_loss = session.run(discriminator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}) for time in range(train_times): D_avg_loss = 0. G_avg_loss = 1.1 total_batch = 1 for i in range(total_batch): minibatch_x = mnist.train.images[i * batch_size: (i + 1) * batch_size] data = sample_z[i * batch_size: (i + 1) * batch_size] session.run(discriminator_train_op, feed_dict={input_x: minibatch_x, input_z: data}) D_avg_loss = session.run(discriminator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}) for k in range(7 + 5 * int(time / 500)): session.run(generator_train_op, feed_dict={input_x: minibatch_x, input_z: data}) G_avg_loss = session.run(generator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}) last_loss = 99. over_time = 0 while G_avg_loss > 1. and over_time < 10: if last_loss < G_avg_loss: over_time += 1 last_loss = G_avg_loss session.run(generator_train_op, feed_dict={input_x: minibatch_x, input_z: data}) G_avg_loss = session.run(generator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}) if ((total_batch * time) + i + 1) % train_step == 0: data = sample_z[0:batch_size] image_summary("G_z_image", session.run(G_z_op, feed_dict={input_z: data})) summary_str = session.run(summary_op, feed_dict={input_x: mnist.validation.images[:batch_size], input_z: data}) summary_writer.add_summary(summary_str, session.run(generator_index)) print("train times:", ((total_batch * time) + i + 1), " D_avg_loss:", session.run(discriminator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}), " G_avg_loss:", session.run(generator_loss_op, feed_dict={input_x: minibatch_x, input_z: data})) session.close()
结果
训练一下即可看到成果(1~10分)。实际训练会随机产生出新的乱数,这里单纯测试所以固定产生,加快看到结果。
结语
一开始尝试直接使用全数据和乱数训练,但训练时间太长,且也没有一个值判断目前的训练情况,而训练值也不好调整,训练起来与先前的网路相比算是一大挑战,但这几年许多人将GAN模型修改,而Wasserstein GAN(WGAN)是其中一个突破,能得知目前的结果是好还是坏,未来有机会还会介绍WGAN,若文章有误欢迎纠正讨论。
参考文献
[1] 李宏毅(2018) GAN Lecture 4 (2018): Basic Theory from: GAN Lecture 4 (2018): Basic Theory
[2]Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, "Generative Adversarial Nets" arXiv:1406.2661, Jun. 2014.
[3] Sherlock(2018). GAN的数学推导 from: https://zhuanlan.zhihu.com/p/27536143