일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- LSTM
- kde
- 네이버 부스트캠프
- band_part
- RNN
- Til
- tril
- forward
- Linux
- error
- ai tech
- GRU
- kernel density estimation
- triu
- 크롬 원격 데스크톱
- nn.Sequential
- tensorflow
- ubuntu
- Chrome Remote Desktop
- pytorch
- Today
- Total
무슨 생각을 해 그냥 하는거지
[TensorFlow] tensorflow.linalg.band_part 본문
공식 문서: https://www.tensorflow.org/api_docs/python/tf/linalg/band_part
tf.linalg.band_part | TensorFlow Core v2.8.0
Copy a tensor setting everything outside a central band in each innermost matrix to zero.
www.tensorflow.org
tensorflow.linalg.band_part(input, num_lower, num_upper, name=None)
- input: A Tensor. Rank k tensor
- num_lower: A Tensor. 보존할 subdiagonals의 수. 음수면 모든 lower triangle을 보존함
- num_upper: A Tensor. 보존할 superdiagonals의 수. 음수면 모든 upper triangle을 보존함.
- 위 함수는 input과 같은 크기, 같은 타입의 Tensor를 리턴한다.
The band part is computed as follows: Assume input has k dimensions [I, J, K, ..., M, N], then the output is a tensor with the same shape where
band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n].
The indicator function
in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) && (num_upper < 0 || (n-m) <= num_upper).
예제를 보고 이해해보자.
num_lower가 1, num_upper가 -1이다.
가장 중심 대각선(처음 input에서 0인 부분)의 아래 대각선(subdiagonals)들 중 하나만 보존(num_lower==1)한다는 의미이므로 나머지 두 개의 subdiagonals는 0으로 바꾼다.
num_upper가 음수이므로 superdiagonals는 모두 보존한다.
num_lower가 2이므로 맨 마지막 subdiagonal만 0으로 바꾼다.
num_upper가 1이므로 중심 대각선 위쪽의 하나의 대각선만 살리고 나머지는 0으로 바꾼다.
특수한 경우는 아래와 같다.
- num_lower == 0이므로 모든 subdiagonals를 지우고, num_upper == -1이므로 모든 superdiagonals는 보존한다. → 상삼각행렬(Upper triangular matrix)
- num_lower == -1이므로 모든 subdiagonals를 보존하고, num_upper == 0이므로 모든 superdiagonals를 지운다. → 하삼각행렬 (Lower triangular matrix)
- num_lower == 0, num_upper == 0이므로 모든 subdiagonals와 superdiagonals를 지운다 → 대각행렬
그렇다면 PyTorch에서 TensorFlow의 band_part와 같은 역할을 하는 함수는 무엇일까?
완전히 같은 역할을 수행하지는 않지만 torch.tril, torch.triu를 사용할 수 있다.
torch.tril은 하삼각행렬(tri 뒤의 l이 lower을 의미)
torch.triu는 상삼각행렬(tri뒤의 u가 upper을 의미)
을 만들어주는 함수이다.
둘 다 diagonal 이라는 인자를 갖고 있는데, 중심이라고 여겨질 대각선을 의미한다.
diagonal==0 (default)인 경우, 기존 행렬의 중심 대각선을 중심으로 상삼각행렬 또는 하삼각행렬을 만든다.
diagonal == 1 인 경우, 기존 행렬의 중심 대각선 다음의 대각선을 중심으로 상삼각행렬 또는 하삼각행렬을 만든다.
torch.triu 를 사용하는 경우 중심 대각선이 하나 위의 것으로 선택되고 기존 중심 대각선이 0으로 바뀐다.
diagnoal == -1인 경우, torch.triu를 사용한다면 중심 대각선이 하나 아래의 것으로 선택된다.
tensorflow에서의 band_part는 중심 대각행렬을 중심으로 보존할 대각선들을 선택하였지만 (양방향 선택 가능)
pytorch에서의 triu, tril은 중심 대각선을 선택하여 보존할 부분을 한 방향으로만 선택할 수 있다.
사실 양 방향으로 선택할 일이 잘 없으니 band_part의 경우 특수한 상황(상삼각행렬, 하삼각행렬이 되는 경우)만 신경써주면 될 듯하다.
'COMPUTER > Deep Learning' 카테고리의 다른 글
[Pytorch] nn.Sequential은 multiple input을 다룰 수 없다 (0) | 2022.02.16 |
---|