[Notes] TensorFlow - Lesson 11: Generative Adversarial Network (GAN)

Preface

This lesson modifies the CNN model introduced earlier. The main references are [1], Prof. Hung-yi Lee's lecture video (richly illustrated), and [3], which uses fewer pictures; the two cover largely the same material, so read [3] if you prefer formulas and [1] if you prefer illustrated explanations.

GAN

Idea

Suppose $P_{data}(x)$ is the distribution of the input data; we cannot know the actual data distribution, so we assume an arbitrary parametric function $P_g(x^i;\theta)$ to approximate $P_{data}(x)$, and we must find the parameters $\theta$ that maximize the likelihood. The figure from [1] is used to explain the steps.

1. Multiply together the probabilities $P_g(x^i;\theta)$ of the samples to get the probability (likelihood) of generating the data, and take the log.
2. The log turns the product into a sum, which is easier to work with.
3. Rewrite it as an expectation; this is the same as maximizing the expectation of $\log P_g(x;\theta)$ over $x$ drawn from $P_{data}$.
4. Rewrite the expectation as a continuous integral, then subtract $\int_x P_{data}(x)\log P_{data}(x)\,dx$; this does not affect the result, because $P_{data}$ is a known, fixed distribution and can be treated as a constant.
5. The result is a KL divergence; steps 3 and 4 exist precisely to shape the expression into a KL divergence. Minimizing the KL divergence (the derived term is negative, so it is negated and minimized) is therefore the same as maximizing the original expression.

As mentioned in the previous chapter on VAE, KL divergence measures how far apart two distributions are, so one could in fact jump straight to step 5 and simply say we are measuring the gap between the two distributions.

http://img2.58codes.com/2024/20110564pP72oOZszw.png
Source: [1].
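For readers who cannot see the figure, a rough LaTeX transcription of the same derivation (my own rewrite, following the five steps above):

$$\theta^* = \arg\max_\theta \prod_{i=1}^{m} P_g(x^i;\theta) = \arg\max_\theta \sum_{i=1}^{m} \log P_g(x^i;\theta) \approx \arg\max_\theta \mathbb{E}_{x\sim P_{data}}\big[\log P_g(x;\theta)\big]$$

$$= \arg\max_\theta \left[\int_x P_{data}(x)\log P_g(x;\theta)\,dx - \int_x P_{data}(x)\log P_{data}(x)\,dx\right] = \arg\min_\theta KL\big(P_{data}\,\|\,P_g(\cdot;\theta)\big)$$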

GAN

This idea is workable, but as discussed in the VAE chapter, solving for $\theta$ directly in high dimensions is still very hard; GAN instead measures quality through the final computed output. The loss is defined as follows.

http://img2.58codes.com/2024/20110564eCOtMLPO16.png
Source: [2].

Intuitively, the Discriminator wants the score of the real data to be as high as possible and the score of the generated data to be as low as possible; conversely, the Generator wants the score of the generated data to be as high as possible (the real-data score does not involve it).
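Written out, the value function from [2] is the standard minimax objective:

$$\min_G \max_D V(D,G) = \mathbb{E}_{x\sim P_{data}(x)}\big[\log D(x)\big] + \mathbb{E}_{z\sim P_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

The Discriminator maximizes both terms; the Generator minimizes only the second, since the first does not contain G.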

Derivation

Next we prove mathematically why the formula above can serve as the loss.

Max Discriminator

First, maximize the Discriminator; this step alone makes it clear why the expression can be used as a loss.

1. When maximizing D with G fixed, the objective becomes the first equation.
2. Rewrite the expectations as continuous integrals.
3. Rearrange the formula.
http://img2.58codes.com/2024/20110564OeL27SLreN.png
Source: [3].
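In symbols, with G fixed and the expectations written as integrals over x, the quantity to maximize is:

$$V(G,D) = \int_x \Big[P_{data}(x)\log D(x) + P_g(x)\log\big(1 - D(x)\big)\Big]\,dx$$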

Suppose $P_{data}(x) = a$ and $P_G(x) = b$; the expression to maximize is then as follows.
Source: [3].

1. Substitute a and b.
2. Take the partial derivative with respect to D (using the derivative of log) and set it to zero to find the maximum.
3. Rearrange the formula.
4. Substitute the original distributions back in place of a and b.

http://img2.58codes.com/2024/20110564iZssAZzdKE.png
Source: [3].
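In symbols, with $a = P_{data}(x)$ and $b = P_G(x)$, the integrand and its maximizer are:

$$f(D) = a\log D + b\log(1-D), \qquad \frac{df}{dD} = \frac{a}{D} - \frac{b}{1-D} = 0 \;\Rightarrow\; D^*(x) = \frac{a}{a+b} = \frac{P_{data}(x)}{P_{data}(x)+P_G(x)}$$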

1. Substitute the maximizing D back into the original expression; on the right-hand side the 1 - D means the numerator, after subtracting Pdata, leaves Pg.
2. Divide numerator and denominator by 2; since this is a constant it does not affect the result.
3. Pull the 1/2 in both numerators out to the front; the result becomes two KL terms, and these two KL terms together are the JS divergence (also called symmetric KL), in short another formula for measuring the difference between two distributions.
http://img2.58codes.com/2024/20110564oAvpDAdAlt.png
Source: [3].

The leading $-2\log 2$ is a constant and can be ignored, so maximizing D amounts to computing the JS divergence, which shows the loss is meaningful.
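In symbols, substituting $D^*$ back into $V$:

$$\max_D V(G,D) = \int_x \left[P_{data}(x)\log\frac{P_{data}(x)}{P_{data}(x)+P_G(x)} + P_G(x)\log\frac{P_G(x)}{P_{data}(x)+P_G(x)}\right]dx$$

$$= -2\log 2 + KL\Big(P_{data}\,\Big\|\,\tfrac{P_{data}+P_G}{2}\Big) + KL\Big(P_G\,\Big\|\,\tfrac{P_{data}+P_G}{2}\Big) = -2\log 2 + 2\,JSD(P_{data}\,\|\,P_G)$$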

Min Generator

For minimizing G, we only need the term $\log(1 - D(G(z)))$, since the other term does not depend on G and can be ignored. Note, however, that [2] suggests not minimizing $\log(1 - D(G(z)))$ but instead maximizing $\log D(G(z))$, as shown below.

http://img2.58codes.com/2024/20110564m5JxiNN8o8.png
Source: [2].
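In symbols, the Generator objective is swapped from

$$\min_G \mathbb{E}_{z}\big[\log\big(1 - D(G(z))\big)\big] \quad\longrightarrow\quad \max_G \mathbb{E}_{z}\big[\log D(G(z))\big]$$

which is why generator_loss in the code below is the negative mean of $\log D(G(z))$.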

[1] also explains this: it is mainly about the gradients during training. The minimization form has an almost flat gradient early in training, while the maximization form does not. See the figure below.
http://img2.58codes.com/2024/20110564yFEIbba9Yn.png
Source: [1].

Loss formula

We do not know the true distributions $P_{data}$ and $P_g$, so we can only generate images and plug them in; that is, we feed real images and Generator images into the Discriminator and average the log of its outputs.

http://img2.58codes.com/2024/201105647WTtExMm2t.png
Source: [1].
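Written as sample averages over a minibatch of $m$ real images $x^i$ and $m$ noise samples $z^i$, the estimate used in practice is:

$$\tilde V = \frac{1}{m}\sum_{i=1}^{m}\log D(x^i) + \frac{1}{m}\sum_{i=1}^{m}\log\big(1 - D(G(z^i))\big)$$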

Code

The code below is modified from the CNN code explained earlier. Training takes a long time, so only 50 samples are used for testing.

Global parameters

learning_rate: learning rate.
batch_size: number of samples per training batch.
train_times: number of training iterations.
train_step: validation interval.
D_param: all weights of the discriminator layers, collected so they can be updated separately.
G_param: all weights of the generator layers, collected so they can be updated separately.
discriminator_conv: sizes of the discriminator convolution layers.
discriminator_output_size: discriminator output size.
generator_input_size: generator input size.
generator_conv: sizes of the generator convolution layers.
generator_output_size: generator output size.

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

learning_rate = 0.0001
batch_size = 10
train_times = 100000
train_step = 1
D_param = []
G_param = []
# [filter height, filter width, input channels, output channels]
discriminator_conv1_size = [3, 3, 1, 11]
discriminator_conv2_size = [3, 3, 11, 13]
discriminator_hide3_size = [7 * 7 * 13, 1024]
discriminator_output_size = 1
generator_input_size = 20
generator_conv1_size = [3, 3, 1, 13]
generator_conv2_size = [3, 3, 13, 11]
generator_hide3_size = [generator_input_size * 11, 1024]
generator_output_size = 28 * 28

Batch normalization functions

For fully connected layers

def layer_batch_norm(x, n_out, is_train):
    beta = tf.get_variable("beta", [n_out], initializer=tf.ones_initializer())
    gamma = tf.get_variable("gamma", [n_out], initializer=tf.ones_initializer())

    batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')

    ema = tf.train.ExponentialMovingAverage(decay=0.9)
    ema_apply_op = ema.apply([batch_mean, batch_var])
    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(is_train, mean_var_with_update, lambda: (ema_mean, ema_var))

    x_r = tf.reshape(x, [-1, 1, 1, n_out])
    normed = tf.nn.batch_norm_with_global_normalization(x_r, mean, var, beta, gamma, 1e-3, True)
    return tf.reshape(normed, [-1, n_out])

For convolution layers

def conv_batch_norm(x, n_out, train):
    beta = tf.get_variable("beta", [n_out], initializer=tf.constant_initializer(value=0.0, dtype=tf.float32))
    gamma = tf.get_variable("gamma", [n_out], initializer=tf.constant_initializer(value=1.0, dtype=tf.float32))

    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

    ema = tf.train.ExponentialMovingAverage(decay=0.9)
    ema_apply_op = ema.apply([batch_mean, batch_var])
    ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(train, mean_var_with_update, lambda: (ema_mean, ema_var))

    normed = tf.nn.batch_norm_with_global_normalization(x, mean, var, beta, gamma, 1e-3, True)

    mean_hist = tf.summary.histogram("meanHistogram", mean)
    var_hist = tf.summary.histogram("varHistogram", var)
    return normed

Convolution and fully connected layers

Convolution layer
An extra type argument records which network the weights belong to. The activation function is ReLU.

def conv2d(input, weight_shape, type='D'):
    size = weight_shape[0] * weight_shape[1] * weight_shape[2]
    weights_init = tf.random_normal_initializer(stddev=np.sqrt(2. / size))
    biases_init = tf.zeros_initializer()
    weights = tf.get_variable(name="weights", shape=weight_shape, initializer=weights_init)
    biases = tf.get_variable(name="biases", shape=weight_shape[3], initializer=biases_init)

    conv_out = tf.nn.conv2d(input, weights, strides=[1, 1, 1, 1], padding='SAME')
    conv_add = tf.nn.bias_add(conv_out, biases)
    conv_batch = conv_batch_norm(conv_add, weight_shape[3], tf.constant(True, dtype=tf.bool))
    output = tf.nn.relu(conv_batch)

    if type == 'D':
        D_param.append(weights)
        D_param.append(biases)
    elif type == 'G':
        G_param.append(weights)
        G_param.append(biases)

    return output
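The discriminator below also calls max_pool, which comes from the earlier CNN article and is not repeated in this post; a minimal sketch of what it presumably looks like (plain 2x2 max pooling, my assumption rather than the original code):

def max_pool(x, k=2):
    # assumed helper: 2x2 max pooling with stride 2, halving height and width
    return tf.nn.max_pool(x, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')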

Fully connected layer
An extra type argument records which network the weights belong to, and activation selects among different activation functions.

def layer(x, weights_shape, activation='relu', type='D'):
    init = tf.random_normal_initializer(stddev=np.sqrt(2. / weights_shape[0]))
    weights = tf.get_variable(name="weights", shape=weights_shape, initializer=init)
    biases = tf.get_variable(name="biases", shape=weights_shape[1], initializer=init)
    mat_add = tf.matmul(x, weights) + biases
    #mat_add = layer_batch_norm(mat_add, weights_shape[1], tf.constant(True, dtype=tf.bool))

    if activation == 'relu':
        output = tf.nn.relu(mat_add)
    elif activation == 'sigmoid':
        output = tf.nn.sigmoid(mat_add)
    elif activation == 'softplus':
        output = tf.nn.softplus(mat_add)
    else:
        output = mat_add

    if type == 'D':
        D_param.append(weights)
        D_param.append(biases)
    elif type == 'G':
        G_param.append(weights)
        G_param.append(biases)

    return output

Discriminator

This is much the same as the earlier CNN network. I use a sigmoid activation on the output so that the log in the loss can be computed (some implementations use one of TensorFlow's cross-entropy functions instead, in which case this is unnecessary).

def discriminator(x):
    x = tf.reshape(x, shape=[-1, 28, 28, 1])
    with tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE):
            conv1_out = conv2d(x, discriminator_conv1_size)
            pool1_out = max_pool(conv1_out)
        with tf.variable_scope("conv2", reuse=tf.AUTO_REUSE):
            conv2_out = conv2d(pool1_out, discriminator_conv2_size)
            pool2_out = max_pool(conv2_out)
        with tf.variable_scope("hide3", reuse=tf.AUTO_REUSE):
            pool2_flat = tf.reshape(pool2_out, [-1, discriminator_hide3_size[0]])
            hide3_out = layer(pool2_flat, discriminator_hide3_size, activation='softplus')
            #hide3_drop = tf.nn.dropout(hide3_out, keep_drop)
        with tf.variable_scope("output"):
            output = layer(hide3_out, [discriminator_hide3_size[1], discriminator_output_size], activation='sigmoid')
    return output

Generator

This is also similar to the earlier CNN; the reshape on the first line matches the generator input size defined above.

def generator(x):
    x = tf.reshape(x, shape=[-1, generator_input_size, 1, 1])
    with tf.variable_scope("generator", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("conv1", reuse=tf.AUTO_REUSE):
            conv1_out = conv2d(x, generator_conv1_size, type='G')
        with tf.variable_scope("conv2", reuse=tf.AUTO_REUSE):
            conv2_out = conv2d(conv1_out, generator_conv2_size, type='G')
        with tf.variable_scope("hide3", reuse=tf.AUTO_REUSE):
            conv2_flat = tf.reshape(conv2_out, [-1, generator_hide3_size[0]])
            hide3_out = layer(conv2_flat, generator_hide3_size, activation='softplus', type='G')
        with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
            output = layer(hide3_out, [generator_hide3_size[1], generator_output_size], activation='sigmoid', type='G')
    return output

Loss functions

The loss is split into two functions, one for training the Discriminator and one for training the Generator; they simply plug in the formulas derived above.

def discriminator_loss(D_x, D_G):
    loss = -tf.reduce_mean(tf.log(D_x + 1e-12) + tf.log(1. - D_G + 1e-12))
    loss_his = tf.summary.scalar("discriminator_loss", loss)
    return loss

def generator_loss(D_G):
    loss = -tf.reduce_mean(tf.log(D_G + 1e-12))
    loss_his = tf.summary.scalar("generator_loss", loss)
    return loss
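As noted in the Discriminator section, a common alternative (not what this article's code does) is to keep the discriminator output as logits and use tf.nn.sigmoid_cross_entropy_with_logits, which avoids the manual 1e-12 clipping; a hypothetical sketch:

def discriminator_loss_logits(D_x_logits, D_G_logits):
    # real images should be classified as 1, generated images as 0
    real = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_x_logits), logits=D_x_logits)
    fake = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(D_G_logits), logits=D_G_logits)
    return tf.reduce_mean(real) + tf.reduce_mean(fake)

def generator_loss_logits(D_G_logits):
    # non-saturating loss: generated images should be classified as 1 by the discriminator
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_G_logits), logits=D_G_logits))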

Validation function

Validation here mainly saves generated images for inspection.

def image_summary(label, image_data):
    reshap_data = tf.reshape(image_data, [-1, 28, 28, 1])
    tf.summary.image(label, reshap_data, batch_size)

def accuracy(G_z):
    image_summary("G_z_image", G_z)

Training functions

def discriminator_train(loss, index):
    return tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=1e-12).minimize(loss, global_step=index, var_list=D_param)

def generator_train(loss, index):
    return tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.9, beta2=0.999, epsilon=1e-12).minimize(loss, global_step=index, var_list=G_param)

Training

To make the Generator produce results quickly, it is trained several times per Discriminator step. Training sometimes works well and sometimes goes badly, which gives a real sense of how unstable GANs are.

if __name__ == '__main__':
    # init
    mnist = input_data.read_data_sets("MNIST/", one_hot=True)

    input_x = tf.placeholder(tf.float32, shape=[None, 784], name="input_x")
    input_z = tf.placeholder(tf.float32, shape=[None, generator_input_size], name="input_z")

    # predict
    D_x_op = discriminator(input_x)
    G_z_op = generator(input_z)
    D_G_op = discriminator(G_z_op)

    # loss
    discriminator_loss_op = discriminator_loss(D_x_op, D_G_op)
    generator_loss_op = generator_loss(D_G_op)

    # train
    discriminator_index = tf.Variable(0, name="discriminator_train_time")
    discriminator_train_op = discriminator_train(discriminator_loss_op, discriminator_index)
    generator_index = tf.Variable(0, name="generator_train_time")
    generator_train_op = generator_train(generator_loss_op, generator_index)

    # accuracy
    accuracy(G_z_op)

    # graph
    summary_op = tf.summary.merge_all()
    session = tf.Session()
    summary_writer = tf.summary.FileWriter("log/", graph=session.graph)
    init_value = tf.global_variables_initializer()
    session.run(init_value)
    saver = tf.train.Saver()

    sample_z = np.random.uniform(-1., 1., (mnist.train.num_examples, generator_input_size))

    # pre-train the discriminator until its loss is small
    D_avg_loss = 1.
    while D_avg_loss > 0.001:
        total_batch = 1
        for i in range(total_batch):
            minibatch_x = mnist.train.images[i * batch_size: (i + 1) * batch_size]
            data = sample_z[i * batch_size: (i + 1) * batch_size]
            session.run(discriminator_train_op, feed_dict={input_x: minibatch_x, input_z: data})
            D_avg_loss = session.run(discriminator_loss_op, feed_dict={input_x: minibatch_x, input_z: data})

    for time in range(train_times):
        D_avg_loss = 0.
        G_avg_loss = 1.1
        total_batch = 1
        for i in range(total_batch):
            minibatch_x = mnist.train.images[i * batch_size: (i + 1) * batch_size]
            data = sample_z[i * batch_size: (i + 1) * batch_size]

            # one discriminator step
            session.run(discriminator_train_op, feed_dict={input_x: minibatch_x, input_z: data})
            D_avg_loss = session.run(discriminator_loss_op, feed_dict={input_x: minibatch_x, input_z: data})

            # several generator steps, increasing over time
            for k in range(7 + 5 * int(time / 500)):
                session.run(generator_train_op, feed_dict={input_x: minibatch_x, input_z: data})
                G_avg_loss = session.run(generator_loss_op, feed_dict={input_x: minibatch_x, input_z: data})

            # keep training the generator until its loss drops below 1 or stops improving
            last_loss = 99.
            over_time = 0
            while G_avg_loss > 1. and over_time < 10:
                if last_loss < G_avg_loss:
                    over_time += 1
                last_loss = G_avg_loss
                session.run(generator_train_op, feed_dict={input_x: minibatch_x, input_z: data})
                G_avg_loss = session.run(generator_loss_op, feed_dict={input_x: minibatch_x, input_z: data})

            if ((total_batch * time) + i + 1) % train_step == 0:
                data = sample_z[0:batch_size]
                image_summary("G_z_image", session.run(G_z_op, feed_dict={input_z: data}))
                summary_str = session.run(summary_op, feed_dict={input_x: mnist.validation.images[:batch_size], input_z: data})
                summary_writer.add_summary(summary_str, session.run(generator_index))
                print("train times:", ((total_batch * time) + i + 1),
                      " D_avg_loss:", session.run(discriminator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}),
                      " G_avg_loss:", session.run(generator_loss_op, feed_dict={input_x: minibatch_x, input_z: data}))

    session.close()

Results

Results appear after a short training run (roughly 1 to 10 minutes). Real training would draw new random noise each time; here the noise is fixed purely for testing, so results show up faster.
http://img2.58codes.com/2024/201105644rS99YtDfb.png
http://img2.58codes.com/2024/20110564oQmXt1qKez.png
http://img2.58codes.com/2024/20110564KZWyfaOrq0.png
http://img2.58codes.com/2024/20110564pwurREHipK.png

Conclusion

I first tried training directly with the full dataset and fresh random noise, but training took too long, there was no single value to judge the current state of training, and the hyperparameters were hard to tune; compared with the earlier networks, training this one is a real challenge. In recent years many GAN variants have appeared, and Wasserstein GAN (WGAN) is one breakthrough: it gives a measure of whether the current result is good or bad. I may cover WGAN in the future. If there are mistakes in this article, corrections and discussion are welcome.

References

[1] 李宏毅 (Hung-yi Lee) (2018). GAN Lecture 4 (2018): Basic Theory.
[2] Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, Yoshua Bengio, "Generative Adversarial Nets", arXiv:1406.2661, Jun. 2014.
[3] Sherlock (2018). GAN的数学推导. https://zhuanlan.zhihu.com/p/27536143

