sklearn 라이브러리의 train_test_split()
함수를 사용하여 데이터세트를 분할하는 방법에 대해 알아본다.
※ 본 내용에서 사용하는 diamonds.csv 파일은 별도로 다운로드 받아야 한다.
※ diamonds.csv 다운받기 [클릭]
개요
머신러닝의 지도학습(supervised learning)을 사용하여 예측값을 생산하고자 하는 경우 데이터세트를 2개 이상으로 분할한다. 해당 작업을 하기 위해 기본적으로 사용하는 함수가 sklearn 라이브러리의 train_test_split()
함수이다. 해당 함수는 입력된 데이터세트를 지정한 비율(또는 개수)로 분할해주며 기본적으로 row index 기준 단순임의추출(simple random sampling)을 실시한다.
해당 함수의 주요 인자별 설명은 다음과 같다.
train_size: 학습 데이터세트 비율을 뜻하며 첫 번째(왼쪽) 객체에 할당되는 행의 개수 또는 비율을 설정할 수 있다. 보통 0에서 1사이의 숫자로 비율을 입력하며 필요시 정수를 입력하기도 한다. 만약 이 인자에 값을 할당 하지 않을 경우 “test_size” 인자에 할당된 값에 기반하여 결정된다.
test_size: 평가 데이터세트 비율을 뜻하며 두 번째(오른쪽) 객체에 할당되는 행의 개수 또는 비율을 설정할 수 있다. 보통 0에서 1사이의 숫자로 비율을 입력하며 필요시 정수를 입력하기도 한다. 만약 이 인자에 값을 할당 하지 않을 경우 0.25로 동작한다.
random_state:
train_test_split()
함수는 기본적으로 단순 임의 표본추출을 실시한다. 만약 이런 임의 확률과정의 결과를 고정하거나 재현을 하고자 할 때 여기에 정수를 입력할 수 있다. 이 인자에 입력되는 숫자를 포함하여 다른 인자에 할당되는 값이 같고 입력되는 데이터가 같은 경우는 언제나 똑같은 결과를 재현할 수 있다.
※ NumPy 버전이 다른 경우는 제외하며 특히 1.17버전 전후로 그 결과가 차이날 수 있다.stratify: 층화표본추출을 실시하고자 할 때 사용하며 기준이 되는 범주형 변수를 하나 할당 가능하다.
참고로 sklearn 버전이 낮은 경우 train_test_split()
함수의 자동완성이 똑바로 지원되지 않으니 되도록이면 최신버전을 사용하는 것을 추천한다.
실습
다음과 같이 데이터세트를 준비한다.
1 | df = pd.read_csv("diamonds.csv") |
carat | cut | color | clarity | depth | table | price | x | y | z | |
---|---|---|---|---|---|---|---|---|---|---|
0 | 0.23 | Ideal | E | SI2 | 61.5 | 55.0 | 326 | 3.95 | 3.98 | 2.43 |
1 | 0.21 | Premium | E | SI1 | 59.8 | 61.0 | 326 | 3.89 | 3.84 | 2.31 |
train_test_split()
의 기본 코드는 다음과 같다. 분리하고자 하는 데이터프레임을 넣고 “train_size” 또는 “test_size” 인자에 분할 개수 또는 비율을 할당하고 필요에 따라 결과 고정(또는 재현)을 위해 “random_state” 인자에 숫자를 할당한다. 이렇게 하면 결과는 두 개의 데이터프레임 객체가 반환되며 첫 번째(왼쪽) 객체에 담기는 것이 학습 데이터세트이다.
1 | df_train, df_test = train_test_split(df, train_size = 0.8, |
carat | cut | color | clarity | depth | table | price | x | y | z | |
---|---|---|---|---|---|---|---|---|---|---|
13361 | 0.24 | Very Good | F | VS2 | 62.9 | 58.0 | 419 | 3.94 | 4.01 | 2.50 |
18592 | 1.02 | Very Good | F | VS1 | 62.4 | 58.0 | 7587 | 6.42 | 6.47 | 4.02 |
데이터프레임이 아닌 NumPy Array 객체를 입력할 경우 반환되는 객체 또한 Array이다.
1 | arr_train, arr_test = train_test_split(df.values, train_size = 0.8, |
sklearn 라이브러리의 공식문서를 보면 2개의 객체를 train_test_split()
함수에 입력하는데 이 경우 4개의 객체가 반환된다. 예를 들어 “price” 변수를 종속변수로 취급한다고 했을 때 다음과 같이 종속변수와 나머지 변수를 분리해보자.
1 | df_X = df.copy() |
carat | cut | color | clarity | depth | table | x | y | z | |
---|---|---|---|---|---|---|---|---|---|
0 | 0.23 | Ideal | E | SI2 | 61.5 | 55.0 | 3.95 | 3.98 | 2.43 |
1 | 0.21 | Premium | E | SI1 | 59.8 | 61.0 | 3.89 | 3.84 | 2.31 |
1 | ser_y.head(2) |
다음과 같이 총 4개의 객체가 반환되며 첫 2개는 “df_X” 객체의 분할, 마지막 2개는 “ser_y” 객체의 분할 결과이다.
1 | df_tr_X, df_te_X, ser_tr_y, ser_te_y = train_test_split(df_X, |
층화 표본추출을 위해서는 “stratify” 인자에 기준이 되는 범주형 변수(시리즈 객체)를 할당하면 되며 층화 표본추출의 결과 확인은 해당 변수의 원소 비율을 확인하면 된다.
1 | df_train, df_test = train_test_split(df, |
단, “stratify” 인자에는 단일 변수를 기준으로 층화 표본추출을 지원하며 다음과 같이 2개 변수를 할당하더라도 에러는 발생하지 않지만 제대로 동작하지 않는 것을 알 수 있다.
1 | df_train2, df_test2 = train_test_split(df, train_size = 0.8, random_state = 123, |
cut | color | x_train | x_test | |
---|---|---|---|---|
0 | Fair | D | 3.714286 | 0.942857 |
1 | Fair | E | 5.114286 | 1.285714 |
2 | Fair | F | 7.142857 | 1.771429 |
3 | Fair | G | 7.171429 | 1.800000 |
4 | Fair | H | 6.914286 | 1.742857 |
5 | Fair | I | 4.000000 | 1.000000 |
6 | Fair | J | 2.714286 | 0.685714 |
7 | Good | D | 15.142857 | 3.771429 |
8 | Good | E | 21.314286 | 5.342857 |
9 | Good | F | 20.771429 | 5.200000 |
10 | Good | G | 19.914286 | 4.971429 |
11 | Good | H | 16.057143 | 4.000000 |
12 | Good | I | 11.942857 | 2.971429 |
13 | Good | J | 7.028571 | 1.742857 |
14 | Ideal | D | 64.771429 | 16.200000 |
15 | Ideal | E | 89.228571 | 22.285714 |
16 | Ideal | F | 87.457143 | 21.857143 |
17 | Ideal | G | 111.628571 | 27.914286 |
18 | Ideal | H | 71.200000 | 17.800000 |
19 | Ideal | I | 47.857143 | 11.942857 |
20 | Ideal | J | 20.485714 | 5.114286 |
21 | Premium | D | 36.628571 | 9.171429 |
22 | Premium | E | 53.428571 | 13.342857 |
23 | Premium | F | 53.285714 | 13.314286 |
24 | Premium | G | 66.828571 | 16.714286 |
25 | Premium | H | 53.942857 | 13.485714 |
26 | Premium | I | 32.628571 | 8.171429 |
27 | Premium | J | 18.457143 | 4.628571 |
28 | Very Good | D | 34.600000 | 8.628571 |
29 | Very Good | E | 54.857143 | 13.714286 |
30 | Very Good | F | 49.457143 | 12.371429 |
31 | Very Good | G | 52.542857 | 13.142857 |
32 | Very Good | H | 41.685714 | 10.428571 |
33 | Very Good | I | 27.514286 | 6.885714 |
34 | Very Good | J | 15.485714 | 3.885714 |