R) DL - kmnist 데이터 살펴보기

R) DL - kmnist 데이터 살펴보기

CRAN RStudio mirror downloads
R의 torchvision 패키지에서 기본적으로 사용하는 kmnist 데이터를 살펴본다.

개요

kmnist 데이터는 숫자로 이루어진 mnist 데이터 세트가 아닌 일본어 관련 데이터이다. 데이터의 공식 문서는 Deep Learning for Classical Japanese Literature 논문을 살펴보면 된다. 해당 논문에는 각 일어의 분류 및 그 예시를 보여준다. (거 악필이 너무 심한거 아니오!!)
torchvision 설치

설치

torchvision 패키지는 아직 cran에 정식 등록되지 않아 install.packages() 함수로 설치를 시도하면 다음과 같이 경고가 뜨면서 설치가 되지 않는다.

1
2
3
install.packages("torchvision")
## Warning in install.packages :
## package ‘torchvision’ is not available (for R version 4.0.2)

다음과 같이 github를 통해 제공하는 패키지를 설치해야 하는데 다음은 remotes 패키지를 활용한 것이고 devtools 패키지도 사용할 수 있다.

1
2
remotes::install_github("mlverse/torchvision") # 1
devtools::install_github("mlverse/torchvision") # 2

torchvision 설치
메세지 아래 부분을 보면 kmnist 관련 문서가 언급되는 것을 볼 수 있다.

그리고 설치하다가 “Torch failed to start, restart your R session to try again.” 이라는 문구가 뜨면서 불완전 설치가 되는 것 처럼 보이는데 R을 재시작 하고 설치해도 똑같으니 당장은 무시할 수 밖에 없다.


로딩

torchvision 을 불러와서 kmnist 데이터 세트를 다운받도록 하자.

1
2
3
4
5
library("torchvision")
ds_mnist = kmnist_dataset(root = ".",
download = TRUE,
train = TRUE,
transform = transform_to_tensor)

kmnist_dataset() 함수는 kmnist 데이터 세트를 불러오는데 필요한 함수로 downloadTRUE를 설정하면 새로 다운받지만 기존에 한 번 다운 받았으면 다시 받지 않는다. trainTRUE를 입력한 경우 학습용 데이터 세트 60000개를 불러온다. 아무튼 ds_mnist 객체에 저장된 데이터를 살펴보도록 하자.


확인

데이터 세트 개요

우선 객체 구조를 보기위해 str() 함수를 써보자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
str(ds_mnist)
## Classes 'kminst_dataset', 'mnist', 'dataset', 'R6' <kminst_dataset>
## Inherits from: <mnist>
## Public:
## .getitem: function (index)
## .length: function ()
## add: function (other)
## check_exists: function ()
## classes: o ki su tsu na ha ma ya re wo
## clone: function (deep = FALSE)
## data: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ...
## download: function ()
## get_item: function (index)
## initialize: function (root, train = TRUE, transform = NULL, target_transform = NULL,
## processed_folder: active binding
## raw_folder: active binding
## resources: list
## root_path: .
## target_transform: NULL
## targets: 9 8 1 2 5 3 5 9 2 2 6 2 1 6 8 7 2 8 10 6 8 4 8 6 7 7 3 8 ...
## test_file: test.rds
## train: TRUE
## training_file: training.rds
## transform: function (img)

여기서 주목해야 할 항목은 classes(line 9), data(line 11), targets(line 20) 이렇게 세 항목이다.

데이터 세트의 클래스와 개수 확인은 다음과 같이 할 수 있다.

1
2
3
4
5
class(ds_mnist)
## [1] "kminst_dataset" "mnist" "dataset" "R6"

length(ds_mnist)
## [1] 60000

그리고 데이터 targets 에 대응하는 문자가 어떻게 되는지 확인하기 위해서는 다음과 같은 작업이 필수이다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
ds_mnist$classes
## [1] "o" "ki" "su" "tsu" "na" "ha" "ma" "ya" "re" "wo"

df_class = data.frame(code = 1:10,
label = ds_mnist$classes)
df_class
## code label
## 1 1 o
## 2 2 ki
## 3 3 su
## 4 4 tsu
## 5 5 na
## 6 6 ha
## 7 7 ma
## 8 8 ya
## 9 9 re
## 10 10 wo

위에서 만든 데이터프레임과 타겟을 합쳐서 보려면 다음과 같이 한다.

1
2
3
4
5
6
7
8
9
10
df_targets = data.frame(obs = 1:length(ds_mnist),
target = ds_mnist$targets)
head(df_targets)
obs target
## 1 1 9
## 2 2 8
## 3 3 1
## 4 4 2
## 5 5 5
## 6 6 3

두 데이터 세트를 병합해보자.

1
2
3
4
5
6
7
8
9
10
11
library("dplyr")
df_tc_join = left_join(x = df_targets, y = df_class,
by = c("target" = "code"))
head(df_tc_join)
## obs target label
## 1 1 9 re
## 2 2 8 ya
## 3 3 1 o
## 4 4 2 ki
## 5 5 5 na
## 6 6 3 su

개별 데이터 탐색

이제 개별 데이터를 살펴보도록 하자. 첫 번째 데이터의 구조는 다음과 같다.

1
2
3
4
str(ds_mnist[1])
## List of 2
## $ :Float [1:1, 1:28, 1:28]
## $ :Long [1:]

아무튼 첫 번째 list는 28x28 크기의 데이터고 두 번째 list는 크기가 1인 것 같다. 뜯어보면 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
ds_mnist[1][[1]]
## torch_tensor
## (1,.,.) =
## Columns 1 to 9 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.2000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0431 0.8314
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.5804 1.0000
## 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.0000 0.8039 1.0000

ds_mnist[1][[2]]
## torch_tensor
## 9
## [ CPULongType{} ]

첫 번째 list에 있는 28x28 데이터는 너무 많아서 일부만 가져왔다. 두 번째 list를 보니 첫 번째 list에 있는 값의 target을 의미하는 것을 알 수 있다. 아무래도 글자가 어떻게 생겼는지 확인을 해야하지 않을까? 60000개 데이터 중 첫 번째 데이터를 ggplot 패키지로 살펴보기 위해 전처리 절차를 밟도록 하자.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
data_sub = ds_mnist[1] # 첫 번째 데이터 떼어내기
arr = as_array(data_sub[[1]])
dim(arr)
## [1] 1 28 28

arr_col = arr[1, , ]

library("reshape2")
df_arr_col = as.data.frame(cbind(obs = nrow(arr_col):1, # ★
arr_col))
df_arr_col_melt = melt(data = df_arr_col, id.vars = "obs")
df_arr_col_melt[, "variable"] = as.numeric(gsub("V", "", df_arr_col_melt$variable))
head(df_arr_col_melt)
## obs variable value
## 1 28 2 0
## 2 27 2 0
## 3 26 2 0
## 4 25 2 0
## 5 24 2 0
## 6 23 2 0

as_array() 함수는 torch tensor를 R에서 사용할 수 있는 array 객체(matrix와 유사)로 바꿔준다. 기본 객체로 바꿔준 후 reshape2 패키지의 melt() 함수로 그래프를 그리기 적당한 long form 형태로 바꿔주고 적당히 데이터를 처리한 결과를 df_arr_col_melt 객체의 출력값에서 확인할 수 있다. 그리고 코드의 9번 줄(★)의 코드를 보면 “1:nrow(arr_col)”이 아니라 “nrow(arr_col):1” 으로 되어있는데 이는 그래프 표현의 문제라 반대로 하였다.

ggplot2 패키지로 시각화 하는 코드는 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
library("ggplot2")
ggplot(data = df_arr_col_melt,
aes(x = variable,
y = obs,
fill = value)) +
geom_tile() +
scale_x_continuous(expand = c(0, 0)) +
scale_y_continuous(expand = c(0, 0)) +
labs(title = df_class[df_class$code == as_array(data_sub[[2]]), "label"],
x = NULL, y = NULL) +
theme_minimal() +
theme(legend.position = "none",
plot.title = element_text(size = 20, hjust = 0.5))

torchvision kmnist train 00001


test 데이터

참고로 테스트 세트는 다음과 같이 불러올 수 있다.

1
2
3
4
5
6
ds_mnist_test = kmnist_dataset(root = ".",
download = TRUE,
train = FALSE,
transform = transform_to_tensor)
length(ds_mnist_test)
## [1] 10000
Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×