무슨 생각을 해 그냥 하는거지

[TensorFlow] tensorflow.linalg.band_part 본문

COMPUTER/Deep Learning

[TensorFlow] tensorflow.linalg.band_part

빛나는콩 2022. 2. 14. 13:34

공식 문서: https://www.tensorflow.org/api_docs/python/tf/linalg/band_part

 

tf.linalg.band_part  |  TensorFlow Core v2.8.0

Copy a tensor setting everything outside a central band in each innermost matrix to zero.

www.tensorflow.org

tensorflow.linalg.band_part(input, num_lower, num_upper, name=None)

  • input: A Tensor. Rank k tensor
  • num_lower: A Tensor. 보존할 subdiagonals의 수. 음수면 모든 lower triangle을 보존함
  • num_upper: A Tensor. 보존할 superdiagonals의 수. 음수면 모든 upper triangle을 보존함.
  • 위 함수는 input과 같은 크기, 같은 타입의 Tensor를 리턴한다.

 

The band part is computed as follows: Assume input has k dimensions [I, J, K, ..., M, N], then the output is a tensor with the same shape where

band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n].

The indicator function

in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper).

 

 

예제를 보고 이해해보자.

 

예시 #1

num_lower가 1, num_upper가 -1이다.

가장 중심 대각선(처음 input에서 0인 부분)의 아래 대각선(subdiagonals)들 중 하나만 보존(num_lower==1)한다는 의미이므로 나머지 두 개의 subdiagonals는 0으로 바꾼다.

num_upper가 음수이므로 superdiagonals는 모두 보존한다.

 

예시 #2

num_lower가 2이므로 맨 마지막 subdiagonal만 0으로 바꾼다.

num_upper가 1이므로 중심 대각선 위쪽의 하나의 대각선만 살리고 나머지는 0으로 바꾼다.

 

특수한 경우는 아래와 같다.

  • num_lower == 0이므로 모든 subdiagonals를 지우고, num_upper == -1이므로 모든 superdiagonals는 보존한다. → 상삼각행렬(Upper triangular matrix)
  • num_lower == -1이므로 모든 subdiagonals를 보존하고, num_upper == 0이므로 모든 superdiagonals를 지운다. → 하삼각행렬 (Lower triangular matrix)
  • num_lower == 0, num_upper == 0이므로 모든 subdiagonals와 superdiagonals를 지운다 → 대각행렬

 

 

그렇다면 PyTorch에서 TensorFlow의 band_part와 같은 역할을 하는 함수는 무엇일까?

완전히 같은 역할을 수행하지는 않지만 torch.tril, torch.triu를 사용할 수 있다.

 

torch.tril은 하삼각행렬(tri 뒤의 l이 lower을 의미)

torch.triu는 상삼각행렬(tri뒤의 u가 upper을 의미)

을 만들어주는 함수이다.

 

둘 다 diagonal 이라는 인자를 갖고 있는데, 중심이라고 여겨질 대각선을 의미한다.

diagonal==0 (default)인 경우, 기존 행렬의 중심 대각선을 중심으로 상삼각행렬 또는 하삼각행렬을 만든다.

diagonal == 0

 

diagonal == 1 인 경우, 기존 행렬의 중심 대각선 다음의 대각선을 중심으로 상삼각행렬 또는 하삼각행렬을 만든다.

torch.triu 를 사용하는 경우 중심 대각선이 하나 위의 것으로 선택되고 기존 중심 대각선이 0으로 바뀐다.

diagonal == 1

diagnoal == -1인 경우, torch.triu를 사용한다면 중심 대각선이 하나 아래의 것으로 선택된다.

diagonal == -1

 

tensorflow에서의 band_part는 중심 대각행렬을 중심으로 보존할 대각선들을 선택하였지만 (양방향 선택 가능)

pytorch에서의 triu, tril은 중심 대각선을 선택하여 보존할 부분을 한 방향으로만 선택할 수 있다.

 

사실 양 방향으로 선택할 일이 잘 없으니 band_part의 경우 특수한 상황(상삼각행렬, 하삼각행렬이 되는 경우)만 신경써주면 될 듯하다.