任务描述

这个神奇的任务来自于工地上的一个奇葩需求:之前的一些基于 GRU 的 keras 模型需要成等价的 pytorch 模块,要求

  1. 模型权重和结构需要保持一致
  2. 模型需要嵌入一个更大的模型结构中,作为一个子模块,以进行后续训练

其中样例 GRU 用 keras functional API 构建,模型结构为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
tf.keras.utils.set_random_seed(4396)
X = tf.keras.layers.Input(shape=(60, 78),
name='factor')
x = X
x = tf.keras.layers.GRU(64,
return_sequences=True,
bias_initializer='glorot_uniform')(x)

x = tf.keras.layers.GRU(64,
return_sequences=False,
bias_initializer='glorot_uniform')(x)

x = tf.keras.layers.Dense(48, activation='relu')(x)
x = tf.keras.layers.Dense(32, activation='relu')(x)
y = tf.keras.layers.Dense(1)(x)

model = tf.keras.models.Model(inputs=X,
outputs=y)

上面的 set_random_seed 为了让测试时能够保证多次运行时,所构建的模型权重保持一致;bias_initializer 值被设为 glorot_uniform 是因为 GRU 的 bias 默认初始化为 zeros,测试时可能对比不出东西来,所以作此修改。

调研

直接用 ChatGPT 调研一下这个怎么做,ChatGPT 给出的方案是:

  1. 使用 ONNX 转换模型
  2. 在 PyTorch 中加载 ONNX 模型
  3. 在 PyTorch 中重新构建模型(可选)

同时 ChatGPT 给出了代码

keras 模型转换成 ONNX

先来科(zhao)普(chao)一下什么是 ONNX:

ONNX(Open Neural Network Exchange)是一种针对机器学习所设计的开放式的文件格式,用于存储训练好的模型。它使得不同的人工智能框架可以采用相同格式存储模型数据并交互。

看起来就是为了让从不同框架中训练出来的模型能够以同一种形式规范部署的一个东西

对于 keras 模型,我们可以利用 tf2onnx 这个库来转换成 ONNX 格式:

1
pip install tf2onnx

然后我们可以用如下代码转换即可:

1
2
3
4
5
import tf2onnx

onnx_model_path = '/path/to/model.onnx'
model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, output_path=onnx_model_path)

并且我们还可以利用转换后的 onnx 模型在 numpy.ndarray 上进行推理:

1
2
3
4
5
6
7
8
9
10
11
12
13
import onnx

onnx_model = onnx.load(onnx_model_path)

import onnxruntime
import numpy as np
a = np.random.randn(100, 60, 78).astype(np.float32)

resnet_session = onnxruntime.InferenceSession(onnx_model_path)
inputs = {resnet_session.get_inputs()[0].name: a}
outputs_onnx = resnet_session.run(None, inputs)[0]
outputs_keras = model(a, training=False)
print(np.sum(np.abs(outputs_onnx - outputs_keras)))

运行之后发现 keras 模型和 onnx 模型推理结果的残差在 5e-6 这个数量级。

但是马上有一个问题:这个网络需要嵌到 pytorch 某个模型里面作为一个子模块,主要是要用来做反向传播,所以还得继续看看是个怎么个回事儿。

ONNX 模型转换成 pytorch

那么自然会想到将 ONNX 模型进一步转换为 pytorch 模型,主要能找到几个库:

前两个库都配置了 CI,而且近期有维护;后面那个库就几年都没有维护了,看起来比较寄

这里主要使用过了 onnx2pytorchonnx-pytorch,原理应该都是通过解析 onnx 模型部件,然后动态生成对应模块的代码。但是一通操作之后发现:

  • onnx2pytorch 能将 onnx 模型生成对应的 pytorch nn.Module,但是没法用生成之后的模型来做前向传播(到这里都还没涉及到反向传播呢)
  • onnx-pytorch 把可能涉及到的模块都做了动态生成 pytorch 模型构建的代码,但是没有适配 GRU 的代码(怎么可能穷举得完)
  • 这几个模块都会 import torch,然后又需要配置 torch 版本,包括和 CUDA 的适配;如果没配好,分分钟报错:
1
2
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist

直接换 CPU 版本也是报错一堆,主要是这个库不太能支持 GRU 的运算,很烦

其实 onnx2torch 的 star 数是最多的,但是当时没去试试看,说白了还是调研没做够

加上工地上业务一般都偏固定,网络基本上就是那个小而美的结构,不会涉及到太复杂的转换,所以先直接弄一个手艺活的转换代码

手艺活开始

一个比较暴力的做法是直接在 pytorch 那边依葫芦画瓢构造代码,然后把 keras 的模型权重往里面怼进去(更进一步地,直接拿着上面的 keras 代码折磨 ChatGPT 去):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch import nn

# 定义与 Keras 模型相同的 PyTorch 模型架构
class GRUModel(nn.Module):
def __init__(self, input_size: int, output_size: int):
super(GRUModel, self).__init__()
self.gru = nn.GRU(input_size, 1, num_layers=2, batch_first=True)
self.dense_1 = nn.Linear(1, 48)
self.relu_1 = nn.ReLU()
self.dense_2 = nn.Linear(48, 32)
self.relu_2 = nn.ReLU()
self.dense_3 = nn.Linear(32, output_size)

def forward(self, x):
# 定义前向传播
x_, h_ = self.gru(x)
x_ = x_[:, -1, :]
x_ = self.relu_1(self.dense_1(x_))
x_ = self.relu_2(self.dense_2(x_))
y_ = self.dense_3(x_)
return y_

这里注意一下 pytorch 里面 GRU 的比较不同的几点:

  • 默认情况下,keras 的 GRU 输入形状含义为 (batch_size, sequence_length, num_features),但是 pytorch 的 GRU 输入形状含义为 (sequence_length, batch_size, num_features),需要在构造函数里面声明 batch_first=True 以应用和 keras 相同含义的输入形状;
  • num_layers=2 也即是两层 GRU,也就是把两个一层的 GRU 垛起来。这里留个坑,之后还会深入到 torch 源码里面探究这个参数
  • pytorch 的 GRU 模块同时返回模块输出和隐藏层结果,最后需要手动把输出的最后一个时间步的结果给取出来

然后就可以怼权重了,这里需要注意两点(经过试错以及打印两边权重的 shape 得出):

  1. 在将权重拷贝进 pytorch 模型的时候,需要在 with torch.no_grad() 下进行,这是因为默认权重的 require_grad 属性都是 True,如果直接拷贝进去的话 pytorch 直接不知道怎么算梯度了;
  2. keras 权重和 pytorch 权重大概是个转置的关系:如线性层的 $W$ 矩阵($y=Wx+b$)在 keras 中的形状为 (40, 16),在 pytorch 中的形状为 (16, 40)

克隆模块权重

先通过 keras model.get_weights() 来查看模型权重都长啥样:

1
2
3
4
# 创建模型实例
pytorch_model = GRUModel(78, 1)

weights = model.get_weights()

结果为

1
[(78, 120), (40, 120), (2, 120), (40, 120), (40, 120), (2, 120), (40, 16), (16,), (16, 1), (1,)]

应该是一字排开的:前面 6 个是 layers.GRU 的权重,后面几个是 layers.Dense 的权重

克隆主要包括 nn.Linearnn.GRU

nn.Linear 根据公开文档以及自己瞎猜,可以猜到模块权重对应的属性为 weightbias

nn.GRU 走到 pytorch 的 nn.GRU 的构造函数里面,可以翻到其基类 nn.RNNBase 构造函数里面的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
for layer in range(num_layers):
for direction in range(num_directions):
real_hidden_size = proj_size if proj_size > 0 else hidden_size
layer_input_size = (
input_size if layer == 0 else real_hidden_size * num_directions
)

w_ih = Parameter(
torch.empty((gate_size, layer_input_size), **factory_kwargs)
)
w_hh = Parameter(
torch.empty((gate_size, real_hidden_size), **factory_kwargs)
)
b_ih = Parameter(torch.empty(gate_size, **factory_kwargs))
# Second bias vector included for CuDNN compatibility. Only one
# bias vector is needed in standard definition.
b_hh = Parameter(torch.empty(gate_size, **factory_kwargs))
layer_params: Tuple[Tensor, ...] = ()
if self.proj_size == 0:
if bias:
layer_params = (w_ih, w_hh, b_ih, b_hh)
else:
layer_params = (w_ih, w_hh)
else:
w_hr = Parameter(
torch.empty((proj_size, hidden_size), **factory_kwargs)
)
if bias:
layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr)
else:
layer_params = (w_ih, w_hh, w_hr)

suffix = "_reverse" if direction == 1 else ""
param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
if bias:
param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
if self.proj_size > 0:
param_names += ["weight_hr_l{}{}"]
param_names = [x.format(layer, suffix) for x in param_names]

特别地,我们这是 2 层的、单向的、带有 bias 的 GRU,所以参数会有

1
2
weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0
weight_ih_l1, weight_hh_l1, bias_ih_l1, bias_hh_l1

通过上面的一通分析,自然就会写出如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# 创建模型实例
pytorch_model = GRUModel(78, 1)

weights = model.get_weights()

# Try to clone weights
print([x.shape for x in weights])

with torch.no_grad():
# Copy GRU
pytorch_model.gru.weight_ih_l0.copy_(torch.Tensor(weights[0].transpose()))
pytorch_model.gru.weight_hh_l0.copy_(torch.Tensor(weights[1].transpose()))
pytorch_model.gru.bias_ih_l0.copy_(torch.Tensor(weights[2][0]))
pytorch_model.gru.bias_hh_l0.copy_(torch.Tensor(weights[2][1]))

pytorch_model.gru.weight_ih_l1.copy_(torch.Tensor(weights[3].transpose()))
pytorch_model.gru.weight_hh_l1.copy_(torch.Tensor(weights[4].transpose()))
pytorch_model.gru.bias_ih_l1.copy_(torch.Tensor(weights[5][0]))
pytorch_model.gru.bias_hh_l1.copy_(torch.Tensor(weights[5][1]))

# Copy linear
pytorch_model.dense_1.weight.copy_(torch.Tensor(weights[6].transpose()))
pytorch_model.dense_1.bias.copy_(torch.Tensor(weights[7]))

pytorch_model.dense_2.weight.copy_(torch.Tensor(weights[8].transpose()))
pytorch_model.dense_2.bias.copy_(torch.Tensor(weights[9]))

pytorch_model.dense_3.weight.copy_(torch.Tensor(weights[10].transpose()))
pytorch_model.dense_3.bias.copy_(torch.Tensor(weights[11]))

# Check results
np.random.seed(7777)
X = np.random.randn(1, 60, 78).astype(np.float32)

y_keras = model(X, training=False).numpy()

X_torch = torch.Tensor(X)
with torch.no_grad():
y_torch = pytorch_model(X_torch).numpy()
y_diff = y_keras - y_torch

print(y_keras)
print()
print(y_torch)
print()
print(y_diff)

然后就开始踩坑了:发现前向传播的结果对不上。线性层的参数克隆应该不太可能出错,所以重点排查 GRU 的参数克隆环节是不是出问题了

keras 与 pytorch 在 GRU 实现上的差异

我们先把模型简化一下,仅实现 1 层 GRU 的参数克隆:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf

tf.keras.utils.set_random_seed(777)

# 加载 Keras 模型
X = tf.keras.layers.Input(shape=(6, 5),
name='factor')
x = X
y = tf.keras.layers.GRU(1,
return_sequences=False,
bias_initializer='glorot_uniform')(x)

model = tf.keras.models.Model(inputs=X,
outputs=y)

# 定义与 Keras 模型相同的 PyTorch 模型架构
class GRUModel(nn.Module):
def __init__(self, input_size: int, output_size: int):
super(GRUModel, self).__init__()
self.gru = nn.GRU(input_size, 1, num_layers=1, batch_first=True)

def forward(self, x):
# 定义前向传播
x_, h_ = self.gru(x)
y_ = x_[:, -1, :]
return y_

# 创建模型实例
pytorch_model = GRUModel(5, 1)
weights = model.get_weights()

# Try to clone weights
print([x.shape for x in weights])

with torch.no_grad():
# Copy GRU
pytorch_model.gru.weight_ih_l0.copy_(twist(torch.Tensor(weights[0].transpose())))
pytorch_model.gru.weight_hh_l0.copy_(twist(torch.Tensor(weights[1].transpose())))
pytorch_model.gru.bias_ih_l0.copy_(twist(torch.Tensor(weights[2][0])))
pytorch_model.gru.bias_hh_l0.copy_(twist(torch.Tensor(weights[2][1])))


pytorch_model.eval()


# Check results
# np.random.seed(7777)
X = np.random.randn(1, 6, 5).astype(np.float32) #

y_keras = model(X, training=False).numpy()

X_torch = torch.Tensor(X)
with torch.no_grad():
y_torch = pytorch_model(X_torch).numpy()
y_diff = y_keras - y_torch

print(y_keras.reshape(-1))
print()
print(y_torch.reshape(-1))
print()
print(y_diff.reshape(-1))

会发现输出对不上,然后就想去看看 pytorch 代码层面上是怎么实现。然而,pytorch 关于 GRU 的前向推理都固化到了 so 文件里面了,那总不可能去逆 so 文件(

然后就去找网上有没有现成的,找到个 kaggle notebookgithub issue,发现原因果然是 keras 和 pytorch 的矩阵排布不一致导致的问题:

  • 为了增大矩阵计算时的并行度,在计算

    的时候,我们并不是逐个去计算每条等式,而是将这些等式拼起来计算

  • keras 的 GRU 拼法是 (x_upd, x_reset, x_new)

  • pytorch 的 GRU 拼法是 (x_reset, x_upd, x_new)

所以我们只需要再根据这个排布的异同,转换一下权重矩阵的排布就行了:

(下面这个转换是那个 github issue 里面所提供的将 pytorch 权重转换到 keras 权重的解决方案,事实上这个函数也可以用来将 keras 权重转换到 pytorch 权重,因为置换 (1, 0, 2) 的逆置换仍旧是它本身)

1
2
3
4
5
6
def regroup_params_gru(weight_or_bias_gru, axis=0):
assert len(weight_or_bias_gru.shape) == 2 and weight_or_bias_gru.shape[axis] % 3 == 0
# change params from tf order(z,r,h) into tf order(r,z,n) aka. (r,z,h)

[z, r, h] = np.split(weight_or_bias_gru, 3, axis=axis)
return np.concatenate((r, z, h), axis=axis)

之前没有看到这个 github issue,所以自己手搓了一个

然后还有一点值得注意的就是 pytorch 里面 bias_ih_l0bias_hh_l0 是分别存的,但是 keras 里面是将两者并起来,构成一个 $2 \times n$ 的矩阵。

可以自行比较一下将上述改动应用后,上面单层 GRU 的输出结果是否一致。

最终实现 1

总之把上面讨论到的这些应用到代码里面就行了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf

# 加载 Keras 模型
model = tf.keras.models.load_model('/path/to/raw/model')

# 定义与 Keras 模型相同的 PyTorch 模型架构
class GRUModel(nn.Module):
def __init__(self, input_size: int, output_size: int):
super(GRUModel, self).__init__()
self.gru = nn.GRU(input_size, 40, num_layers=2, batch_first=True)
self.dense_1 = nn.Linear(40, 16)
self.relu_1 = nn.ReLU()
self.dense_2 = nn.Linear(16, output_size)

def forward(self, x):
# 定义前向传播
x_, h_ = self.gru(x)
x_ = x_[:, -1, :]
x_ = self.relu_1(self.dense_1(x_))
y_ = self.dense_2(x_)
return y_


# 创建模型实例
pytorch_model = GRUModel(78, 1)

# https://www.kaggle.com/code/hengck23/example-from-pytorch-to-keras-gru-by-hand/notebook
def twist(x: torch.Tensor):
num = x.shape[0]
num_ = num // 3
y = torch.cat([x[num_:2*num_, ...], x[:num_, ...], x[2*num_:, ...]], dim=0)
return y

weights = model.get_weights()

# Try to clone weights
print([x.shape for x in weights])

with torch.no_grad():
# Copy GRU
pytorch_model.gru.weight_ih_l0.copy_(twist(torch.Tensor(weights[0].transpose())))
pytorch_model.gru.weight_hh_l0.copy_(twist(torch.Tensor(weights[1].transpose())))
pytorch_model.gru.bias_ih_l0.copy_(twist(torch.Tensor(weights[2][0])))
pytorch_model.gru.bias_hh_l0.copy_(twist(torch.Tensor(weights[2][1])))

pytorch_model.gru.weight_ih_l1.copy_(twist(torch.Tensor(weights[3].transpose())))
pytorch_model.gru.weight_hh_l1.copy_(twist(torch.Tensor(weights[4].transpose())))
pytorch_model.gru.bias_ih_l1.copy_(twist(torch.Tensor(weights[5][0])))
pytorch_model.gru.bias_hh_l1.copy_(twist(torch.Tensor(weights[5][1])))

# Copy linear
pytorch_model.dense_1.weight.copy_(torch.Tensor(weights[6].transpose()))
pytorch_model.dense_1.bias.copy_(torch.Tensor(weights[7]))

pytorch_model.dense_2.weight.copy_(torch.Tensor(weights[8].transpose()))
pytorch_model.dense_2.bias.copy_(torch.Tensor(weights[9]))


# Check results
np.random.seed(7777)
X = np.random.randn(1, 60, 78).astype(np.float32)

y_keras = model(X, training=False).numpy()
# print(y_keras)

X_torch = torch.Tensor(X)
with torch.no_grad():
y_torch = pytorch_model(X_torch).numpy()
y_diff = y_keras - y_torch

print(y_keras)
print()
print(y_torch)
print()
print(y_diff)

# OK: Save model
save_model_path = '/path/to/model_weights.pth'

torch.save(pytorch_model.state_dict(), save_model_path)

最终实现 2

然后又来了个让人头皮发麻的需求:人拍拍脑袋,想把 pytorch 的两层 GRU 拆成两个一层的 GRU

那彳亍,只需要重新处理下网络结构,以及权重拷贝就行:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import numpy as np
import torch
import torch.nn as nn
import tensorflow as tf

# 加载 Keras 模型
model = tf.keras.models.load_model('/path/to/model')

# 定义与 Keras 模型相同的 PyTorch 模型架构
class GRUModel(nn.Module):
def __init__(self, input_size: int, output_size: int):
super(GRUModel, self).__init__()
self.gru_1 = nn.GRU(input_size, 40, batch_first=True)
self.gru_2 = nn.GRU(40, 40, batch_first=True)
self.dense_1 = nn.Linear(40, 16)
self.relu_1 = nn.ReLU()
self.dense_2 = nn.Linear(16, output_size)

def forward(self, x):
# 定义前向传播
x_, h_ = self.gru_1(x)
x_, h_ = self.gru_2(x_)
x_ = x_[:, -1, :]
x_ = self.relu_1(self.dense_1(x_))
y_ = self.dense_2(x_)
return y_



# 创建模型实例
pytorch_model = GRUModel(78, 1)

# https://www.kaggle.com/code/hengck23/example-from-pytorch-to-keras-gru-by-hand/notebook
def twist(x: torch.Tensor):
num = x.shape[0]
num_ = num // 3
y = torch.cat([x[num_:2*num_, ...], x[:num_, ...], x[2*num_:, ...]], dim=0)
return y

weights = model.get_weights()

# Try to clone weights
print([x.shape for x in weights])

with torch.no_grad():
# Copy GRU
pytorch_model.gru_1.weight_ih_l0.copy_(twist(torch.Tensor(weights[0].transpose())))
pytorch_model.gru_1.weight_hh_l0.copy_(twist(torch.Tensor(weights[1].transpose())))
pytorch_model.gru_1.bias_ih_l0.copy_(twist(torch.Tensor(weights[2][0])))
pytorch_model.gru_1.bias_hh_l0.copy_(twist(torch.Tensor(weights[2][1])))

pytorch_model.gru_2.weight_ih_l0.copy_(twist(torch.Tensor(weights[3].transpose())))
pytorch_model.gru_2.weight_hh_l0.copy_(twist(torch.Tensor(weights[4].transpose())))
pytorch_model.gru_2.bias_ih_l0.copy_(twist(torch.Tensor(weights[5][0])))
pytorch_model.gru_2.bias_hh_l0.copy_(twist(torch.Tensor(weights[5][1])))

# Copy linear
pytorch_model.dense_1.weight.copy_(torch.Tensor(weights[6].transpose()))
pytorch_model.dense_1.bias.copy_(torch.Tensor(weights[7]))

pytorch_model.dense_2.weight.copy_(torch.Tensor(weights[8].transpose()))
pytorch_model.dense_2.bias.copy_(torch.Tensor(weights[9]))


# Check results
np.random.seed(7777)
X = np.random.randn(1, 60, 78).astype(np.float32)

y_keras = model(X, training=False).numpy()
# print(y_keras)

X_torch = torch.Tensor(X)
with torch.no_grad():
y_torch = pytorch_model(X_torch).numpy()
y_diff = y_keras - y_torch

print(y_keras)
print()
print(y_torch)
print()
print(y_diff)

# OK: Save model
save_model_path = '/path/to/model_weights.pth'

# 假设你的模型实例是 pytorch_model
torch.save(pytorch_model.state_dict(), save_model_path)

总结

  • 这种手艺活对比起来是真滴需要耐心
  • github 项目还是得搜索之后去用 star 数多的,起码有个 reputation 上的保证