Smile Engineering Blog

ジェイエスピーからTipsや技術特集、プロジェクト物語を発信します

PyTorchモデルをバイナリストリーム化してモデル解剖もしてみよう

PyTorchのモデルをバイナリストリームで保存できる!?

最近PyTorchのドキュメントを眺めていたら、PyTorchのモデルがio.BytesIOで保存できることを知りました。
(以前からこんな記載あったかな? あまり記憶にない……)

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.save(x, buffer)

pytorch.org

あまり使う機会はないかもしれませんが(ところが筆者の案件では実際使用する機会があったのだけど)、とりあえずまずは触って試してみましょう!

まずはモデル定義

今回はあくまで実験用のモデルなので、学習とかは考えず適当に組んでみます。
ほぼ何も考えず、簡単な2層のモデルを作ってみました。

# モデル定義
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, 3, 1)
        self.fc1 = nn.Linear(4, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.fc1(x))
        return x


# モデル初期化
print("1. モデル初期化")
model = TheModelClass()
print(model, "\n")

PyTorchのモデル定義は nn.Moduleクラスを継承して、その中で学習モデルを定義していきます。
この辺りはTensorFlowよりも、Chainerの定義方法に近いですね。
(最近はChainerよりPyTorchを使用している方の方が多いのかな?)

出力結果

TheModelClass(
  (conv1): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=4, out_features=2, bias=True)
) 

今回はこのモデルをバイナリストリーム化してみようと思います。

例に従ってバイナリストリーム化します!

それではPyTorchのドキュメントの通り、バイナリストリーム化してみます。

out_buf = io.BytesIO()
torch.save(model, out_buf)
print(out_buf.getvalue(), "\n")

一番最初に、 io.BytesIO を out_buf という名前でオブジェクト化します。
その中に、torch.save() で出力していくという方法ですね。

io.BytesIO は getvalueメソッドを使用して、バイナリストリーム内のすべてを出力することができます。

出力結果

b'PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02c__main__\nTheModelClass\nq\x00)\x81q\x01}q\x02(X\x08\x00\x00\x00trainingq\x03\x88X\x0b\x00\x00\x00_parametersq\x04ccollections\nOrderedDict\nq\x05)Rq\x06X\x08\x00\x00\x00_buffersq\x07h\x05)Rq\x08X\x1b\x00\x00\x00_non_persistent_buffers_setq\tc__builtin__\nset\nq\n]q\x0b\x85q\x0cRq\rX\x0f\x00\x00\x00_backward_hooksq\x0eh\x05)Rq\x0fX\x0e\x00\x00\x00_forward_hooksq\x10h\x05)Rq\x11X\x12\x00\x00\x00_forward_pre_hooksq\x12h\x05)Rq\x13X\x11\x00\x00\x00_state_dict_hooksq\x14h\x05)Rq\x15X\x1a\x00\x00\x00_load_state_dict_pre_hooksq\x16h\x05)Rq\x17X\x08\x00\x00\x00_modulesq\x18h\x05)Rq\x19(X\x05\x00\x00\x00conv1q\x1actorch.nn.modules.conv\nConv2d\nq\x1b)\x81q\x1c}q\x1d(h\x03\x88h\x04h\x05)Rq\x1e(X\x06\x00\x00\x00weightq\x1fctorch._utils\n_rebuild_parameter\nq ctorch._utils\n_rebuild_tensor_v2\nq!((X\x07\x00\x00\x00storageq"ctorch\nFloatStorage\nq#X\r\x00\x00\x002953073151776q$X\x03\x00\x00\x00cpuq%K$tq&QK\x00(K\x04K\x01K\x03K\x03tq\'(K\tK\tK\x03K\x01tq(\x89h\x05)Rq)tq*Rq+\x88h\x05)Rq,\x87q-Rq.X\x04\x00\x00\x00biasq/h h!((h"h#X\r\x00\x00\x002953073156480q0h%K\x04tq1QK\x00K\x04\x85q2K\x01\x85q3\x89h\x05)Rq4tq5Rq6\x88h\x05)Rq7\x87q8Rq9uh\x07h\x05)Rq:h\th\n]q;\x85q<Rq=h\x0eh\x05)Rq>h\x10h\x05)Rq?h\x12h\x05)Rq@h\x14h\x05)RqAh\x16h\x05)RqBh\x18h\x05)RqCX\x0b\x00\x00\x00in_channelsqDK\x01X\x0c\x00\x00\x00out_channelsqEK\x04X\x0b\x00\x00\x00kernel_sizeqFK\x03K\x03\x86qGX\x06\x00\x00\x00strideqHK\x01K\x01\x86qIX\x07\x00\x00\x00paddingqJK\x00K\x00\x86qKX\x08\x00\x00\x00dilationqLK\x01K\x01\x86qMX\n\x00\x00\x00transposedqN\x89X\x0e\x00\x00\x00output_paddingqOK\x00K\x00\x86qPX\x06\x00\x00\x00groupsqQK\x01X\x0c\x00\x00\x00padding_modeqRX\x05\x00\x00\x00zerosqSX \x00\x00\x00_reversed_padding_repeated_twiceqT(K\x00K\x00K\x00K\x00tqUubX\x03\x00\x00\x00fc1qVctorch.nn.modules.linear\nLinear\nqW)\x81qX}qY(h\x03\x88h\x04h\x05)RqZ(h\x1fh h!((h"h#X\r\x00\x00\x002953073151104q[h%K\x08tq\\QK\x00K\x02K\x04\x86q]K\x04K\x01\x86q^\x89h\x05)Rq_tq`Rqa\x88h\x05)Rqb\x87qcRqdh/h h!((h"h#X\r\x00\x00\x002953073152160qeh%K\x02tqfQK\x00K\x02\x85qgK\x01\x85qh\x89h\x05)RqitqjRqk\x88h\x05)Rql\x87qmRqnuh\x07h\x05)Rqoh\th\n]qp\x85qqRqrh\x0eh\x05)Rqsh\x10h\x05)Rqth\x12h\x05)Rquh\x14h\x05)Rqvh\x16h\x05)Rqwh\x18h\x05)RqxX\x0b\x00\x00\x00in_featuresqyK\x04X\x0c\x00\x00\x00out_featuresqzK\x02ubuub.PK\x07\x08BM\xc3\x18~\x05\x00\x00~\x05\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1a\x00:\x00archive/data/2953073151104FB6\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ\xc0\xf3\xe5\xbe\xe0\xa7J\xbe\xb4\x8b"\xbe\xe8\\w>\xc0C>=@K><(\x15\xcc\xbe\x90\xb0\xe2=PK\x07\x08\xfam@\x01 \x00\x00\x00 \x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1a\x00\x18\x00archive/data/2953073151776FB\x14\x00ZZZZZZZZZZZZZZZZZZZZ!\x07\x9d>\x10\xd4c\xbe\x8a\xa1\x9e\xbe7\x16\x87\xbe8)o\xbd`}$\xbe\xdeL\x92\xbe\x00\x80\x18\xb8\x84\xda\xa0\xbdd\xa9\xa8=\xe6\x1bh\xbe\xc3\xc0b\xbe\x95\xa6\xa9> ,\xe6\xbd%\x0c\x8a>\xd3\xcc0\xbe6 \r>\x85\xfe\x92>\x8c\x94\x90\xbep)\x85\xbe\xd6o \xbe\x13\x8cV\xbe\xb6\xb4\x94\xbe\xb0\xb4\x17>`vh<_\xcf\x80>\xd3d\x80\xbe\xfcc\xed=\'\x1d\x91\xbe\xc81\x15=(\xb0\x15=j\xac\x7f>\xeb\xa2\x96>\xf5\xf3\xa5>\xcf\x7f\x93>\x1a\x02V>PK\x07\x08@XF\x95\x90\x00\x00\x00\x90\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1a\x00(\x00archive/data/2953073152160FB$\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ\xf4^\x9b>T\xe1B\xbePK\x07\x08:&\x80\x06\x08\x00\x00\x00\x08\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1a\x000\x00archive/data/2953073156480FB,\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ\x92\xafb>\xecJ\x89\xbd*\xe00>\x1e\x07\x88\xbePK\x07\x08\xdes\x13\x12\x10\x00\x00\x00\x10\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0f\x003\x00archive/versionFB/\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ3\nPK\x07\x08\xd1\x9egU\x02\x00\x00\x00\x02\x00\x00\x00PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00BM\xc3\x18~\x05\x00\x00~\x05\x00\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00archive/data.pklPK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\xfam@\x01 \x00\x00\x00 \x00\x00\x00\x1a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xce\x05\x00\x00archive/data/2953073151104PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00@XF\x95\x90\x00\x00\x00\x90\x00\x00\x00\x1a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00p\x06\x00\x00archive/data/2953073151776PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00:&\x80\x06\x08\x00\x00\x00\x08\x00\x00\x00\x1a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`\x07\x00\x00archive/data/2953073152160PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\xdes\x13\x12\x10\x00\x00\x00\x10\x00\x00\x00\x1a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xd8\x07\x00\x00archive/data/2953073156480PK\x01\x02\x00\x00\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\xd1\x9egU\x02\x00\x00\x00\x02\x00\x00\x00\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00`\x08\x00\x00archive/versionPK\x06\x06,\x00\x00\x00\x00\x00\x00\x00\x1e\x03-\x00\x00\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\x06\x00\x00\x00\x00\x00\x00\x00\x9b\x01\x00\x00\x00\x00\x00\x00\xd2\x08\x00\x00\x00\x00\x00\x00PK\x06\x07\x00\x00\x00\x00m\n\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00PK\x05\x06\x00\x00\x00\x00\x06\x00\x06\x00\x9b\x01\x00\x00\xd2\x08\x00\x00\x00\x00'

こんな具合に、無事出力することができました!
バイナリとなったため、ネットワーク越しに転送とか、いろいろできそうですね。

バイナリストリームを元に戻してみよう!

せっかくバイナリストリーム化したので、次は元に戻してみたいと思います。
が、PyTorchのドキュメントには、元の戻し方はどこにも記載がありません!(ぇ

普通に考えると、こんな実装となるでしょうか?

in_model = torch.load(io.BytesIO(out_buf.getvalue()))
print(in_model, "\n")

io.BytesIO の getvalueメソッドで、オブジェクト内のデータをまるっと取り出します。
それを、torch.load() で復元するという方法です。

出力結果

TheModelClass(
  (conv1): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=4, out_features=2, bias=True)
) 

無事に、元の形には戻せましたね!

ところでPyTorchのモデルの中身ってどうなってるの???

無事元に戻せたには戻せたのですが、ここで気になること。
これ、本当にパラメータは入っているのか???

というわけで、復元したモデルの中身を見てみたいと思います。
PyTorchのモデルの中身を参照するには、state_dict() を使用します。
モデル保存時に使用するので、使ったことがある方は多いかと思います。

print(in_model.state_dict(), "\n")

出力結果

OrderedDict([('conv1.weight', tensor([[[[ 3.0670e-01, -2.2249e-01, -3.0983e-01],
          [-2.6384e-01, -5.8389e-02, -1.6063e-01],
          [-2.8574e-01, -3.6359e-05, -7.8542e-02]]],


        [[[ 8.2354e-02, -2.2667e-01, -2.2144e-01],
          [ 3.3135e-01, -1.1239e-01,  2.6962e-01],
          [-1.7266e-01,  1.3782e-01,  2.8710e-01]]],


        [[[-2.8238e-01, -2.6008e-01, -1.5668e-01],
          [-2.0952e-01, -2.9044e-01,  1.4815e-01],
          [ 1.4188e-02,  2.5158e-01, -2.5077e-01]]],


        [[[ 1.1591e-01, -2.8343e-01,  3.6424e-02],
          [ 3.6545e-02,  2.4968e-01,  2.9421e-01],
          [ 3.2413e-01,  2.8808e-01,  2.0899e-01]]]])),

    ('conv1.bias', tensor([ 0.2214, -0.0670,  0.1727, -0.2657])), 

    ('fc1.weight', tensor([[-0.4491, -0.1979, -0.1587,  0.2416],
        [ 0.0465,  0.0116, -0.3986,  0.1107]])), 

    ('fc1.bias', tensor([ 0.3035, -0.1903]))]) 

上記出力結果は、筆者が適当に改行を追加した個所もありますが、

PyTorchのモデルはOrderedDict型のリスト

となっています。

このモデルの例の場合は、下記の順序のリストになっていることがわかりますね。

  1. conv1.weight :conv1の重み
  2. conv1.bias :conv1のバイアス
  3. fc1.weight :fc1の重み
  4. fc1.bias :fc1のバイアス

重みとバイアスがセットになって、それが順番に並んでいるという具合です。
ちなみに、モデル宣言時に bias=False とすると、バイアスが配列からいなくなるようです。

(これが扱いやすいかどうかは・・・・う~ん・・・

ソースコード全文

最後に、本日紹介したソースコード全文を載せます。

import io
import torch
import torch.nn as nn
import torch.nn.functional as F


# モデル定義
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, 3, 1)
        self.fc1 = nn.Linear(4, 2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.fc1(x))
        return x


# モデル初期化
print("1. モデル初期化")
model = TheModelClass()
print(model, "\n")

# モデルのバイナリストリーム化
print("2. モデルのバイナリストリーム化")
out_buf = io.BytesIO()
torch.save(model, out_buf)
print(out_buf.getvalue(), "\n")

# モデルの復元
print("3. モデルの復元")
in_model = torch.load(io.BytesIO(out_buf.getvalue()))
print(in_model, "\n")

# モデル Value 参照
print("4. モデル Value 参照")
print(in_model.state_dict(), "\n")

いかがでしたでしょうか。
PyTorchのバイナリストリーム化、及びモデルの中身について紹介しました。
使う機会があるかは・・・レアケースかと思いますが、雑学レベルで。。。