개요
LivePortrait는 최근 face reenactment 분야에서 성능이 좋게 나온 모델 중 하나입니다.
- arxiv: https://arxiv.org/abs/2407.03168
- github(학습코드 x): https://github.com/KwaiVGI/LivePortrait
특징:
- Face vid2vid 을 조금 더 optimize했다
- Retargeting하는 방법론 제시
- Stitching하는 방법론 제시
1번같은 경우 딱히 논문에 인사이트가 있지는 않고, 모델 구조를 face vid2vid 대비 optimization를 좀 했다고 보시면 됩니다. 그리고 양질의 데이터가 더 큰 역할을 했을 것이라고 추측됩니다.
2번은 참고할만한데 cross identity reenactment (source image와 driving video가 서로 다른 사람인 경우) 를 진행할 때 source 얼굴과 driving 얼굴의 모양이 달라서 단순하게 driving에서 얻은 얼굴 motion을 source에 입히면 눈 & 입이 잘 닫히지 않는 문제가 생길 수 있습니다. 이러한 문제를 푸는 방법은 제가 생각했을 때 2가지로 나눌 수 있습니다
- motion에 identity (얼굴 모양)과 expression (표정) 정보가 잘 분리되어 encoding되지 않아서 생기는 문제이기 때문에 처음부터 identity와 expression을 정확하게 분리하도록 motion을 생성하도록 학습을 한다
- source와 driving의 다른 얼굴 모양에서 발생하는 문제를 자연스럽게 보정해주는 모델을 학습한다
이 논문에서는 후자의 방법을 사용했다고 보시면 됩니다.
3번같은 경우 reenactment를할 때 어깨 부분이 이상하게 움직이지 않도록 constraint를 줘서 자연스러운 reeanctment가 가능하도록 하는 모델의 feature이라고 생각하면 됩니다.
방법
Face vid2vid 퀵 리뷰
(시간날 때 따로 디테일한 리뷰를 적어보겠습니다...)
Face reenactment를 하는 방법은 정말 많은데 face vid2vid는 feature deformation을 해서 얼굴을 생성하는 방법 중 하나입니다. 이것을 하기 위해 다음과 같이 진행합니다
- source image에서 appearance를 추출합니다 (network F).
- source image에서 canonical keypoint, head pose, expression deformation을 추출합니다 (networks L, H, Δ). 여기서 canonical keypoint는 이 사람의 얼굴 모양을 encoding하는 keypoint라고 생각하면 되고 (대충 설명하자면 평균적인 표정을 짓고있을 때 얼굴의 모양을 encoding하는 keypoint), canonical keypoint에 표정에 의해 생기는 deformation과 얼굴 방향 & 위치에 의해 생기는 rotation & translation을 가해주면 source image의 얼굴에 해당하는 keypoint를 얻을 수 있습니다. 말보다는 수식이 편할수도...
\[ x_s = x_{c,s} R_s + \delta_s + t_s\]
\(x_s\): source image의 keypoint, \(x_{c,s}\): source image의 identity에 대응하는 canonical keypoint, \(\delta_s:\) expression deformation, \(t_s\): translation - driving image에서도 head pose & keypoint deformation을 예측합니다. 같은 identity이기 때문에 driving keypoint는 다음과 같이 적을 수 있습니다:
\[ x_d = x_{c,s}R_d + \delta_d + t_d\] - source & driving keypoint를 사용해서 source의 appearance feature \( f_s\)를 driving 얼굴을 생성하기에 적합한 feature로 warping합니다 (network W)
- warping된 feature을 사용해서 driving image를 예측합니다 (network G)
LivePortrait 에서 사용한 개선 방법
데이터
- Voxceleb, MEAD, RAVDESS, AAHQ, 대량의 자체 (비공개) 비디오 데이터.
- KVQ 사용해서 low quality video 제거.
아키텍처
- 위에서 언급한 3개의 모델 (L, H, Δ)을 1개의 모델로 통합 (ConvNeXt-V2-Tiny)
- Generator에 SPADE를 사용해서 warped feature 를 넣어준다. 궁금하면 접은글에 복붙한 코드 참고 (매우 단순합니다).
# https://github.com/KwaiVGI/LivePortrait/blob/main/src/modules/spade_generator.py
class SPADEDecoder(nn.Module):
def __init__(self, upscale=1, max_features=256, block_expansion=64, out_channels=64, num_down_blocks=2):
for i in range(num_down_blocks):
input_channels = min(max_features, block_expansion * (2 ** (i + 1)))
self.upscale = upscale
super().__init__()
norm_G = 'spadespectralinstance'
label_num_channels = input_channels # 256
self.fc = nn.Conv2d(input_channels, 2 * input_channels, 3, padding=1)
self.G_middle_0 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_1 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_2 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_3 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_4 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.G_middle_5 = SPADEResnetBlock(2 * input_channels, 2 * input_channels, norm_G, label_num_channels)
self.up_0 = SPADEResnetBlock(2 * input_channels, input_channels, norm_G, label_num_channels)
self.up_1 = SPADEResnetBlock(input_channels, out_channels, norm_G, label_num_channels)
self.up = nn.Upsample(scale_factor=2)
if self.upscale is None or self.upscale <= 1:
self.conv_img = nn.Conv2d(out_channels, 3, 3, padding=1)
else:
self.conv_img = nn.Sequential(
nn.Conv2d(out_channels, 3 * (2 * 2), kernel_size=3, padding=1),
nn.PixelShuffle(upscale_factor=2)
)
def forward(self, feature):
seg = feature # Bx256x64x64
x = self.fc(feature) # Bx512x64x64
x = self.G_middle_0(x, seg)
x = self.G_middle_1(x, seg)
x = self.G_middle_2(x, seg)
x = self.G_middle_3(x, seg)
x = self.G_middle_4(x, seg)
x = self.G_middle_5(x, seg)
x = self.up(x) # Bx512x64x64 -> Bx512x128x128
x = self.up_0(x, seg) # Bx512x128x128 -> Bx256x128x128
x = self.up(x) # Bx256x128x128 -> Bx256x256x256
x = self.up_1(x, seg) # Bx256x256x256 -> Bx64x256x256
x = self.conv_img(F.leaky_relu(x, 2e-1)) # Bx64x256x256 -> Bx3xHxW
x = torch.sigmoid(x) # Bx3xHxW
return x
keypoint 관련
- scale parameter 추가:
\( x_s = x_{c,s} R_s + \delta_s + t_s\) 를 \( x_s = s_{s}(x_{c,s} R_s + \delta_s) + t_s \)로 교체
원래 얼굴 scale을 keypoint deformation이 담당했어야했는데 비효율적이기 때문에 scale 추가한 것으로 이해하면 됩니다. - landmark guidance: 원래는 모든 keypoint를 unsupervised 방법으로 학습했지만 이렇게하면 좀 비효율적인 측면도 있고 keypoint들의 움직임이 해석가능하지 않다보니 일부 keypoint는 landmark detector로 label을 얻어서 pseudo-label로 사용했습니다.
Loss
\( L_E\): equivariance loss (face-vid2vid 참고)
\( L_L\): keypoint prior loss (face-vid2vid 참고)
\( L_H\): head pose loss (face-vid2vid 참고)
\( L_\Delta\): deformation prior loss (face-vid2vid 참고)
\( L_{P,cascade}\): perceptual loss. cascade는 얼굴 전체에 대해서, 입에만 대해서 이렇게 2개 loss 사용했다는 뜻.
\( L_{G,cascade}\): GAN loss. cascade 의미는 위와 동일
\( L_{faceid}\): face identity loss (ArcFace 사용)
\( L_{guide}\): landmark guidance loss; WingLoss 사용.
Stitching & Retargeting
위에서 설명한 방법으로 모델을 학습하고 이후 모든 모델들을 고정시킵니다. 그리고 다음과 같은 2가지 기능을 추가하는 모델을 학습합니다.
- source image를 driving video로 reenact했을 때 head pose 등 바뀔 때 어깨가 움직이지 않도록 보정해주는 stitching 모델
- 눈 & 입 크기가 달라서 생기는 차이를 매꿔주는 retargeting 모델
Stitching module
이 단계에서는 앞에 학습한 모델을 다 고정시킨 상태에서 cross identity 상황에서 학습을 진행합니다 (이전에는 same identity).
keypoint decomposition은 기본적으로 이전과 같은데:
\[ x_s = s_{s}(x_{c,s} R_s + \delta_s) + t_s \]
\[ x_d = s_{d}(x_{c,s} R_d + \delta_d) + t_d \]
차이점은 이제 \( x_{c,s}, x_{c,d}\)는 근본적으로 다른 사람의 identity를 encoding할 수 있습니다 (일반화가 더 잘되도록 cross identity를 사용해서 학습했다고 합니다).
이 때 어깨가 움직이는 것을 방지하기 위해 source & driving keypoint 값을 받아서 보정해주는 모델을 (\( \Delta_{st}\): stitching module) 학습합니다:
\[ \Delta_{st} = S(x_s, x_d)\]
\[ x'_{d,st} = x_d + \Delta_{st}\]
이 모듈을 학습하기 위해 다음과 같은 loss를 추가적으로 사용하는데
여기서 \( 1-M^{st}\)는 어깨 부분에 해당하는 mask입니다. 따라서 \( L_{st,const}\)는 self reconstructed image \( I_{p,recon} = D(W(f_s; x_s, x_s)) \) 와 stitching해서 생성한 이미지 (I_{p,st} = D(W(f_s; x_s, x'_{d,st})))가 어깨 라인에 차이가 없도록 constraint를 주는 objective라고 생각하면 됩니다.
\( || \Delta_{st}||_1\)는 L1 regularization.
Retargeting module
source 와 driving 하는 사람의 눈 크기가 다르면 한쪽에서 눈을 감으면 다른 쪽에서 눈을 덜 감을 수도 있습니다 (motion range 차이). 비슷하게 입에도 문제가 생길 수 있겠죠? 이것을 보정해주기 위해 driving keypoint를 다음과 같이 driving keypoint를 수정해줍니다:
\[x'_{d,eyes} = x_s + \Delta_{eyes} = R_{eyes}(x_s; c_{s,eyes}, c_{d,eyes})\]
\[x'_{d,lip} = x_s + \Delta_{lip} = R_{lip}(x_s; c_{s,lip}, c_{d,lip})\]
그리고 다음과 같이 loss function을 사용합니다. 여기서 \( 1-M^{eyes/lip}\)는 눈/입이 아닌 부분을 의미하기 때문에 눈 / 입이 아닌 부분은 retargeting network가 건들지 않도록 constraint를 건 것이라 생각하면 됩니다.
\( c \)는 눈이 얼마나 감겼는지, 입이 얼마나 열렸는지를 의미하고 landmark detector로 예측한 값을 사용합니다. 따라서 c에 대한 loss는 driving의 눈 뜨인 정도 \( c_{d,eyes}\)를 얻으려면 source keypoint를 얼마나 바꿔야 얻을 수 있는지 맞춰주는 loss라고 보면 됩니다. Lip도 마찬가지.
실험
생략
Comments
- 디테일한 표정 컨트롤은 어려울 듯.
- orthographic camera만 모델링
- Retargeting 효과는 있겠지만 이 부분은 좀 무식한 방법이라고 생각이 드는데, \( c_{eye/lip}\)으로 모델링되는 blendshape은 고작 3개이지만 실제로 얼굴 근육은 훨씬 많고 다른 부위에 대해서는 retargeting이 안됨. (FACS 참고: https://imotions.com/blog/learning/research-fundamentals/facial-action-coding-system/)물론 다른 부분들에 대해서 retargeting 효과가 일반적인 application에서 별로 중요하지는 않을 것이라는 생각은 들지만.