GANを試してみました!
話のきっかけ
弊社内で研修を行っているのですが、私はテーマを『AI』と設定し(普段の業務柄……)、画像認識の基本の話をしようかな〜と思って準備を進めようとしていたところ、とある会議にてこんな話題が出たんです。
「今GANとか流行ってるもんね〜。研修でそれ面白いかも。」
……ちょっと待ってください。私、GANをやるとは一言も言ってません!!!
などと半分冗談交じりにで聞いて、半分その言葉に完全に乗せられてしまい……今考えると、絶対に生半可な気持ちでやるものではなかったと完全に後悔していますが、とりあえず手を付けてしまったわけで。。。
題材は、JSPも会場の一つとなってる 東海道らぐ のマスコットキャラ、東海りなちゃんを学習させてみました。この東海りなちゃんは、元々 カスタムキャスト というアプリで生成されたものなのですが、そんな画像をAIくんは学習できるのか、試してみよう!というのが今回のお話です。
ちなみに私自身は普段TensorFlowのソースコードを書くような仕事はしていますが、画像認識についてはほぼ初心者ですので、その点踏まえた上で読んでいただけたらと思います。
そもそもGANってなに???
『敵対的生成ネットワーク』とも呼ばれます。
……と、中途半端な気持ちで始めた私が説明してもあまり説得力がないと思いますので、ここでは割愛させていただきます。。。
仕組み
学習モデルとして、GeneratorとDiscriminatorを用意します。
- Generator : 画像を生成するモデル
- Discriminator : 画像が本物か偽物かを見分けるモデル
この二つのモデルを交互に学習させて、徐々に互いを強くしていこう!というのがGANの仕組みです。
どちらかが強すぎても、学習がうまくいかないです。
……はい、本当にうまく行きません(汗) ←一応これでもかなり苦労しました。
ソースコード
ソースコードはTensorFlow ver1.14を使用して書いています。
本当はv2.0.0のEagerモードを使用して書いたほうが、コードも見やすくなりそうですが。。。(今回は研修の題材でもあるので)
Generator
class Generator: def __init__(self, x_size, y_size, z_dim): self.x_size = x_size self.y_size = y_size self.z_dim = z_dim # all-1 self.W_fc1 = tf.Variable(tf.random_normal([z_dim, 1028], stddev=0.02)) self.b_fc1 = tf.Variable(tf.zeros([1028])) # all-2 self.W_fc2 = tf.Variable(tf.random_normal([1028, (9 * 9 * 64)], stddev=0.02)) self.b_fc2 = tf.Variable(tf.zeros([(9 * 9 * 64)])) # conv2d-1 self.W_conv1 = tf.Variable(tf.random_normal([18, 18, 32, 64], stddev=0.02)) # conv2d-2 self.W_conv2 = tf.Variable(tf.random_normal([x_size, y_size, 3, 32], stddev=0.02)) def var_list(self): return [self.W_fc1, self.b_fc1, self.W_fc2, self.b_fc2, self.W_conv1, self.W_conv2] def __call__(self, batch_size=1): """ :return: """ inputs = tf.random_uniform([batch_size, self.z_dim], minval=-1.0, maxval=1.0) h_fc1 = tf.matmul(inputs, self.W_fc1) + self.b_fc1 h_fc1 = tf.nn.leaky_relu(self.batch_norm(h_fc1, [0, 1])) h_fc2 = tf.matmul(h_fc1, self.W_fc2) + self.b_fc2 h_fc2 = tf.nn.leaky_relu(self.batch_norm(h_fc2, [0, 1])) h_fc2 = tf.nn.dropout(h_fc2, 0.5) h_conv_in = tf.reshape(h_fc2, [-1, 9, 9, 64]) # 畳み込み層1 h_conv1 = self.conv2d_transpose(h_conv_in, self.W_conv1, [18, 18, 32]) h_conv1 = tf.nn.leaky_relu(self.batch_norm(h_conv1, [0, 1])) # 畳み込み層2 img = self.conv2d_transpose(h_conv1, self.W_conv2, [self.x_size, self.y_size, 3]) y_conv = tf.reshape(tf.math.tanh(img), [-1, (self.x_size * self.y_size * 3)]) return y_conv @staticmethod def batch_norm(x, axes): mean, var = tf.nn.moments(x, axes) with tf.control_dependencies([mean, var]): return tf.nn.batch_normalization(x, mean, var, None, None, 1e-5) @staticmethod def conv2d_transpose(x, W, shape): """ 畳み込み層 :param x: :param W: :return: """ output_shape = [tf.shape(x)[0], shape[0], shape[1], shape[2]] return tf.nn.conv2d_transpose(x, W, output_shape=output_shape, strides=[1, 2, 2, 1], padding='SAME')
ポイントとなるメソッドは、以下の二つです。
- tf.nn.conv2d_transpose()
- tf.nn.batch_normalization()
Discriminator
class Discriminator: def __init__(self, x_size, y_size): """ コンストラクタ """ self.x_size = x_size self.y_size = y_size # conv2d-1 self.W_conv1 = tf.Variable(tf.truncated_normal([4, 4, 3, 32], stddev=0.1)) self.b_conv1 = tf.Variable(tf.constant(0.1, shape=[32])) # conv2d-2 self.W_conv2 = tf.Variable(tf.truncated_normal([9, 9, 32, 64], stddev=0.1)) self.b_conv2 = tf.Variable(tf.constant(0.1, shape=[64])) # all-1 self.W_fc1 = tf.Variable(tf.truncated_normal([(9 * 9 * 64), 1024], stddev=0.1)) self.b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024])) # all-2 self.W_fc2 = tf.Variable(tf.truncated_normal([1024, 2], stddev=0.1)) self.b_fc2 = tf.Variable(tf.constant(0.1, shape=[2])) def var_list(self): return [self.W_conv1, self.b_conv1, self.W_conv2, self.b_conv2, self.W_fc1, self.b_fc1, self.W_fc2, self.b_fc2] def __call__(self, images_placeholder, y_labels, keep_prob): """ モデル作成 :param images_placeholder: :param labels: :param keep_prob: :return: """ # 画像を行列に変換 x_image = tf.reshape(images_placeholder, [-1, self.x_size, self.y_size, 3]) # 畳み込み層1 h_conv1 = tf.nn.relu(self.conv2d(x_image, self.W_conv1) + self.b_conv1) # プーリング層1 h_pool1 = self.max_pool_2x2(h_conv1) # 畳み込み層2 h_conv2 = tf.nn.relu(self.conv2d(h_pool1, self.W_conv2) + self.b_conv2) # プーリング層2 h_pool2 = self.max_pool_2x2(h_conv2) # 全結合層1 h_pool2_flat = tf.reshape(h_pool2, [-1, 9*9*64]) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, self.W_fc1) + self.b_fc1) h_fc1 = tf.nn.dropout(h_fc1, keep_prob) # 特徴量→各ラベルの確立へ変換 y = tf.matmul(h_fc1, self.W_fc2) + self.b_fc2 loss = tf.losses.softmax_cross_entropy(onehot_labels=y_labels, logits=y) return loss @staticmethod def conv2d(x, W): """ 畳み込み層 :param x: :param W: :return: """ return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') @staticmethod def max_pool_2x2(x): """ プーリング層 :param x: :return: """ return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
Discriminatorの方はと言うと、実は普段MNISTなどでおなじみの画像認識のモデルほぼそのままです。今回このモデルもベースは私が書いたのではなく、一緒に研修に協力していただく後輩さんに書いてもらったコードほぼそのままとなっています。(多少パラメータなどは変えていますけど)
main
D_real_labels = np.ones(batch_size) D_fake_labels = np.zeros(batch_size) G_labels = np.ones(batch_size) # placeholder Z = tf.placeholder(tf.float32, shape=[None, Z_dim], name='Z') images = tf.placeholder(tf.float32, shape=[None, image_data_size], name='X') labels = tf.placeholder(tf.uint8, shape=[None], name='Y') Y = tf.one_hot(labels, depth=2, dtype=tf.float32) l_rate = 0.00001 learning_rate = tf.placeholder(tf.float32) # Model generator = Generator(image_x, image_y, Z_dim) discriminator = Discriminator(image_x, image_y) # Setup D_loss = d_train(discriminator, images, Y) D_solver = tf.train.AdamOptimizer(learning_rate=l_rate).minimize(D_loss, var_list=discriminator.var_list()) G = generator(batch_size) G_loss = g_train(G, discriminator, Y) G_solver = tf.train.AdamOptimizer(learning_rate=l_rate).minimize(G_loss, var_list=generator.var_list()) saver = tf.train.Saver() session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) with tf.Session(config = session_config) as sess: sess.run(tf.global_variables_initializer()) # 学習プロセス開始 for itr in range(epoch_num): if itr % 1000 == 0: try: dirname = "../output/%06d/" % itr if not os.path.isdir(dirname): os.makedirs(dirname) for i in range(9): gene_image = generator().eval() img_obj = gene_image.reshape(image_x, image_y, 3) * 255.0 file_name = dirname + "%02d.png" % i cv2.imwrite(file_name, img_obj) except: pass # 学習率の更新 l_rate *= 1.1 print("Learning rate Change!! -> %.6f" % l_rate) # Discriminator-Real rand_data = [] rand_idx = np.random.randint(0, len(train_data), batch_size) for i in rand_idx: # ノイズを追加してランダムに並び替え gauss = np.random.normal(0, 0.005, train_data[i].shape) train_image_gs = train_data[i] + gauss rand_data.append(train_data[i]) train_np_data = np.asarray(rand_data) sess.run([D_solver, D_loss], feed_dict={images: train_np_data, labels: D_real_labels, learning_rate: l_rate}) # Generator & Discriminator-Fake for i in range(2): _, gene_image, G_loss_curr = sess.run([G_solver, G, G_loss], feed_dict={labels: G_labels, learning_rate: l_rate}) _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={images: gene_image, labels: D_fake_labels, learning_rate: l_rate}) if itr % 100 == 0: print('Iter: {}'.format(itr)) print(" D loss: " + str(D_loss_curr)) print(" G_loss: " + str(G_loss_curr))
はい、上記モデルを呼び出すメイン処理はこんな具合です。TensorFlow1.x だと tf.session() などを使用するため、いろいろ入り組んでてわかりにくいかもしれませんね。
学習率は動的に変更しています……が、これすることに意味があったのか???
Generator と Descrimetor の学習のしかたが肝のようで、ここにたどり着くまで色々苦労したのですが、この方法だとなんだかうまく行ってます。
で、結局どんな画像ができたの??
一番最初にアップした東海りなの画像を40枚ほど作成し、それを学習させました。本来なら40枚ってかなり少ないので、もう少し増やすべきだっただろ!!というのもあるのですが、今回のこの実験の目的はあくまで研修のための実験で、『学習枚数が少ない状態でいかに学習させるか』が狙いだったため、本来ならそれを無視した形でもう少しまともな学習モデルを作れたのでは?と思ってます。
前置き長くなりましたが、出来上がった画像はこんな感じです!
おおっ、なんかそれっぽいぞ!!!(あくまで研修用なので36×36とめっちゃ小さな画像ですがw)
が、問題点!!
実は、同じような画像しかできない状態に陥ってますorz この状態のことを『Mode Collapse』と呼ばれているらしいですが、流石に今回は時間切れのような気がしてきました(汗)
次回(一ヶ月後かな〜?)に向けて
簡単にざっくりとしか書いていませんが、ここまで来るのに相当な実験を要しました。発散しまくるわ学習が進まないわで……安易に手を出すものではなかったですね。
もう少し手直しして、もっと大きな画像を作れるようにしていこうと思います。その内容は後日またここで紹介させていただきますので、お楽しみに!!