ISP RGB 域之 CCM

Why?

我们肉眼的对光谱的 RGB 响应曲线和 sensor 的响应曲线是不同的;

CCM 一般是 3x3 矩阵形式,也有 3x4 形式的,3x4 形式主要是给 rgb 各自加一个 offset

\[ \begin{bmatrix} R_{out} \\ G_{out} \\ B_{out} \end{bmatrix} = \begin{bmatrix} CC_{00} & CC_{01} & CC_{02}\\ CC_{10} & CC_{11} & CC_{12}\\ CC_{20} & CC_{21} & CC_{22} \end{bmatrix} \ast \begin{bmatrix} R_{in} \\ G_{in} \\ B_{in} \end{bmatrix} \]

\[ \begin{bmatrix} R_{out} \\ G_{out} \\ B_{out} \end{bmatrix} = \begin{bmatrix} CC_{00} & CC_{01} & CC_{02} & OFFSET_r\\ CC_{10} & CC_{11} & CC_{12} & OFFSET_g\\ CC_{20} & CC_{21} & CC_{22} & OFFSET_b \end{bmatrix} \ast \begin{bmatrix} R_{in} \\ G_{in} \\ B_{in} \end{bmatrix} \]

上面的人眼 rgb 响应和 sensor rgb 响应曲线都是非线性的,所以指望通过一个 CCM 矩阵就得到匹配度很好的映射关系是不现实的。现实中,往往会标定很多个 CCM,ISP 在运行的时候根据照度,光源等等因素,选择两个最近的 CCM 插值得到最终的 CCM;

CCM 模块在 apply awb gain 后面,因此 3x3 个值存在约束条件:

\[ CC_{00} + CC_{01} + CC_{02} = 1 CC_{10} + CC_{11} + CC_{12} = 1 CC_{20} + CC_{21} + CC_{22} = 1 \]

保证灰点也就是 r=g=b 的点,经过 CCM 以后仍然 r=g=b;

标定 CCM 方法

  • 用 camera 拍一张某个色温下的 24 色卡 raw 文件:注意 shading 影响,拍这个色卡占整个 sensor 中间一小部分就可以。

  • raw 文件预处理。主要包括减 blc,根据第 4 行的 patch,获取 awb gain 值,乘上去;这样就拿到了这个色温下 24 个 patch 的 rgb 值。

  • 理想 rgb 值。

这是色卡厂家提供的 24 个 patch 的标准 rgb 空间下的理论值;拿到这个值以后,需要进行反 gamma 处理,因为厂家提供的是 srgb 的值,是带了 2.2gamma 的,ISP 的 CCM 模块一般是在 gamma 前面,因此要对理论值进行反 gamma 处理;

标定算法

已知 100 个 raw rgb 值,已知对应的理论 rgb 值;求一个 3x3 线性变换矩阵;这个矩阵要使得映射后的 rgb 值尽可能的接近理论值;

借鉴深度学习的梯度下降方法,可以快速得到 CCM;并且可以自定义 100 个 patch 的重要程度,使得某些 patch 的误差非常小。

定义损失和梯度函数,测量 rgb 值得差异,采用 L2 距离;

\[ L_i = {\textstyle \sum_{n=1}^{18}} (CCM \ast RGB_{origin} - RGB_{standard})^2 \]

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import torch

ccm = torch.tensor([[1655, -442, -189], [-248, 1466, -194], [-48, -770, 1842]], dtype=torch.float32)
rgb_data = torch.randint(0, 255, (3, 100))
rgb_data = rgb_data.float()

error_manual = torch.randn((3, 100)) * 16
rgb_target = ccm.mm(rgb_data)/1024.0
rgb_target_error = rgb_target + error_manual
ccm_calc1 = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
ccm_calc2 = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
ccm_calc3 = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
ccm_calc5 = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
ccm_calc6 = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
ccm_calc7 = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)

def squared_loss(rgb_tmp, rgb_ideal):
return torch.sum((rgb_tmp-rgb_ideal)**2)

def sgd(params, lr, batch_size):
for param in params:
param.data -= lr * param.grad/batch_size;

def net(ccm_calc1, ccm_calc2, ccm_calc3, ccm_calc5, ccm_calc6, ccm_calc7, rgb_data):
rgb_tmp = torch.zeros_like(rgb_data)
rgb_tmp[0, :] = ((1024.0 - ccm_calc1 - ccm_calc2) * rgb_data[0, :] + ccm_calc1 * rgb_data[1, :] + ccm_calc2 * rgb_data[2, :]) / 1024.0
rgb_tmp[1, :] = (ccm_calc3 * rgb_data[0, :] + (1024.0 - ccm_calc3 - ccm_calc5) * rgb_data[1, :] + ccm_calc5 * rgb_data[2, :]) / 1024.0
rgb_tmp[2, :] = (ccm_calc6 * rgb_data[0, :] + ccm_calc7 * rgb_data[1, :] + (1024.0 - ccm_calc6 - ccm_calc7) * rgb_data[2, :]) / 1024.0
return rgb_tmp

lr = 3
num_epochs = 100
for epoch in range(num_epochs):
l = squared_loss(net(ccm_calc1, ccm_calc2, ccm_calc3, ccm_calc5, ccm_calc6, ccm_calc7, rgb_data), rgb_target_error)
l.backward()
sgd([ccm_calc1, ccm_calc2, ccm_calc3, ccm_calc5, ccm_calc6, ccm_calc7], lr, 100)
ccm_calc1.grad.data.zero_()
ccm_calc2.grad.data.zero_()
ccm_calc3.grad.data.zero_()
ccm_calc5.grad.data.zero_()
ccm_calc6.grad.data.zero_()
ccm_calc7.grad.data.zero_()
print('epoch %d, loss %f'%(epoch, l))

res = torch.tensor([[1024.0 - ccm_calc1 - ccm_calc2, ccm_calc1, ccm_calc2],
[ccm_calc3, 1024.0-ccm_calc3-ccm_calc5, ccm_calc5],
[ccm_calc6, ccm_calc7, 1024.0-ccm_calc6-ccm_calc7]], dtype=torch.float32)
print(res);

rgb_apply_ccm = res.mm(rgb_data)/1024.0

fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111, projection='3d')
fig2 = plt.figure(2)
ax2 = fig2.add_subplot(111, projection='3d')

x2 = rgb_data[0]
y2 = rgb_data[1]
z2 = rgb_data[2]

ax1.scatter(x2, y2, z2, marker='*', c='b', label='origin RGB')

ax1.set_xlim(-80, 360)
ax1.set_ylim(-80, 360)
ax1.set_zlim(-80, 360)
ax1.set_xlabel('R')
ax1.set_ylabel('G')
ax1.set_zlabel('B')

x3 = rgb_target[0]
y3 = rgb_target[1]
z3 = rgb_target[2]

ax1.scatter(x3, y3, z3, marker='o', c='c', label='target rgb')

for i in range(len(x3)):
ax1.plot([x2[i], x3[i]], [y2[i], y3[i]], [z2[i], z3[i]], 'k-.')
ax1.legend()

ax2.set_xlim(-80, 360)
ax2.set_ylim(-80, 360)
ax2.set_zlim(-80, 360)
ax2.set_xlabel('R')
ax2.set_ylabel('G')
ax2.set_zlabel('B')
ax2.scatter(x3, y3, z3, marker='o', c='c', label='target rgb')

x4 = rgb_apply_ccm[0]
y4 = rgb_apply_ccm[1]
z4 = rgb_apply_ccm[2]
ax2.scatter(x4, y4, z4, marker='^', c='b', label='apply ccm rgb')
ax2.legend()

plt.show()

可视化看一下映射后的点与理论点的距离:

参考文献

https://zhuanlan.zhihu.com/p/108626480