무슨 생각을 해 그냥 하는거지

[Sketch Recognition] Sketch-R2CNN: An RNN-Rasterization-CNN Architecture for Vector Sketch Recognition 본문

COMPUTER/논문 리뷰

[Sketch Recognition] Sketch-R2CNN: An RNN-Rasterization-CNN Architecture for Vector Sketch Recognition

빛나는콩 2022. 1. 14. 10:02

※ 논문 리뷰X. 혼자 읽고 정리하는 포스트입니다 ※

 

 

서론

기존 sketch recognition 연구들은 벡터 이미지를 binary 이미지로 변환해서 CNN 모델에 사용하였음.

벡터 이미지는 sequential한 정보가 있기 때문에 이런 방법은 좋지 않음.
(그래서 RNN 계열 모델들이 나오고 있음)

 

기존에 있던 SketchRNN의 경우 RNN과 CNN이 각각 다른 branch (병렬적으로 사용됨)

두 모델을 병렬적으로 사용할 경우 모델 학습에 서로 영향을 거의 주지 않음

 

본 논문에서는 RNN과 CNN을 end-to-end로 학습함.

input vector sketch를 neural network에서 pixel 이미지로 변환한다는 것이 포인트

 

 

Ilustration of Sketch-R2CNN

RNN + NLR(Neural Line Rasterization) + CNN 구조

  • RNN에는 vector sketch만 들어감 → RNN이 각 point에서 feature representation을 추출.
  • NLR은 per-point features와 vector sketch를 multi-channel point feature maps로 변환.
  • CNN은 그 point feature maps를 입력으로 받아서 예측값을(카테고리) 반환.

 

(NLR은 완전히 다른 space에서 동작하는 CNN과 RNN을 연결해주는 역할)

 

 

Method

Input Representation

ordered point sequence

  • n은 모든 stroke에 있는 point 개수.
  • x, y는 point의 2D 좌표.
  • s는 i 번째 점이 해당 stroke의 마지막 점인지 나타내는 값 (0이면 i+1 번째 점도 i 번째 점과 같은 stroke에 존재함을 의미하며, 1인 경우 i 번째 점이 해당 stroke의 마지막 점이고 i+1 번째 점은 다른 stroke의 점임을 의미함.)

 

Network Achitecture

[RNN]

time step i에서의 RNN operation

  • h는 hidden states
  • c는 optinal cell states
  • f는 p에 대한 d-dimensional point feature output
  • G_r은 internal states를 순환적으로 업데이트하는 nonlinear mapping
    • G_r == bidirectional LSTM unit with two layers
    • implementation details) hidden states와 cell states 모두 size 512, dropout probabilty=0.5
  • G_f는 hidden state를 원하는 output(multi-channel point feature)으로 바꾸는 nonlinear function
    • G_f == fc layer + sigmoid function

 

위 vector sketch encoding scheme은 SketchRNN의 encoder network를 따름

 

point의 좌푯값은 절대 좌푯값을 사용하는 것이 아니라 전 point p_i-1로부터의 offset을 계산하여 사용.

 

 

RNN은 temporal 정보를, CNN은 spatial 정보를 잘 반영한다.

(CNN lower layer에서는 가까운 픽셀들, higher layer에서는 먼 픽셀들끼리 상호작용한다.)

 

 

[NLR]

NLR module은 RNN으로부터 생성되는 per-point feature를 multi-channel image에 "그린다(draw)"고 볼 수 있음.

NLR의 output은 h x w x d 사이즈의 d-channel point feature maps로, 각 채널은 point features의 하나의 component에 상응함. (h, w는 만들어지는 maps의 height, width이며 d는 hyper param)

 

직관적으로 생각하면 forward pass 때는 I라는 캔버스에 유효한(s_i=0인) 라인을 그리는 것.

전통적인 line rasterization처럼 I_k 픽셀인지 아닌지는 line segment인 p_(i)p_(i+1)에 의해 결정된다.

(pixel의 센터와 line segment 사이의 거리가 미리 정해놓은 threshold보다 작은지 계산 - 본 논문에서는 threshold=1로 실험함)

 

threshold보다 거리가 작으면 I_k를 stroke pixel로 판단하고, linear interpolation으로 feature value를 계산함.

threshold보다 거리가 크면 값은 0.

 

p^k는 I_k의 센터를 line segment에 projection한 점

p^k, p_i, p_i+1은 모두 절대 좌푯값

 

 

backward pass 때 (loss를 계산할 때) 위의 rasterization process를 미분해야 한다.

Eq 2의 linear interpolation이 단순해서 p_(i)p_(i+1)의 gradient를 계산하면 Eq 3과 같음

L을 loss function,

delta를 "CNN을 거치면서 I_k까지 back-propagation한", L에 대한 gradient라 하자.

 

연쇄법칙에 의해 아래와 같이 계산할 수 있다.

 

recurrent relation은 RNN(i.e. G_r)에서 계산되기 때문에, NLR에서는 계산하지 않아도 된다.

 

 

Experiments & Result

윗 섹션에서 HOG-SVM ~ SN v2까지의 결과는 Sketch-a-Net(2017) 논문의 결과를 가져온 것이다. 아래 섹션이 본 논문에서 실험한 내용 (metric: top-1 accuracy)
실제 상황에서는 유저들이 stroke by stroke로 그리기 때문에, stroke의 일부(25%, 50%, 75%)만 사용해서 테스트하였다. 최소한 하나의 stroke가 있도록 하였다.

추가) training 할 때 stroke의 temporal 정보가 얼마나 영향을 미치는지 ablation study를 진행하였는데,

temporal 정보를 없애도록 데이터를 랜덤하게 섞었을 때 정확도가 83.2%로 기존 84.8%보다 1.6%의 성능 감소를 보였다.

 

pen state가 얼마나 기여하는지 확인하기 위해 빼고 실험해보았지만 0.2% 정도의 미미한 성능 감소를 보였다.

temporal 정보와 pen state 모두 제거했을 때는 1.8%의 성능 감소를 보였다.

 

 

오른쪽 grid image를 보면 stroke 부분만 valid point features를 갖는 것을 볼 수 있다. (stroke가 아닌 부분은 0, not color-coded)

본 논문에서 제시한 NLR이 기존 이미지와 별다른 차이 없이 rasterized 이미지를 생성하는 것을 알 수 있다.

왼쪽 이미지에서 초록색은 Sketch-R2CNN이 맞춘 정답, 빨간색은 ResNet101이 예측한 틀린 정답이다.

 

 

Limitations

앞서 언급했던 것과 반대로 Figure 5는 Resnet101이 정답을 맞춘 경우(초록색), Sketch-R2CNN이 정답을 틀린 경우이다.

RNN이 CNN을 guide하기엔 특징적인 feature를 잘 못 찾은 것이라고 해석할 수도 있지만... 사람이 보기에도 어려운 분류이다.

 

 

 

Conclusion

본 논문은 새로운 single-branch 구조(RNN-Rasterization-CNN)를 제안했다. (NLR 모델이 미분가능하다는 게 keypoint인듯)

NLR 모듈은 sketch retrieval, sketch synthesis, sketch simplification 같은 스케치 데이터를 활용하는 task에 좋을 거다