[split] “train_test_split”메소드의 매개 변수 “stratify”(scikit Learn)

train_test_split패키지 scikit Learn에서 사용하려고 하는데 parameter에 문제가 있습니다 stratify. 다음은 코드입니다.

from sklearn import cross_validation, datasets

X = iris.data[:,:2]
y = iris.target

cross_validation.train_test_split(X,y,stratify=y)

그러나 다음과 같은 문제가 계속 발생합니다.

raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}

누군가가 무슨 일이 일어나고 있는지 알고 있습니까? 아래는 함수 문서입니다.

[…]

stratify : 배열 유사 또는 없음 (기본값은 없음)

None이 아닌 경우 데이터는 레이블 배열로 사용하여 계층화 된 방식으로 분할됩니다.

버전 0.17의 새로운 기능 : 계층화 분할

[…]



답변

Scikit-Learn은 단지 “stratify”라는 주장을 인식하지 못한다고 말하고 있습니다. 당신이 그것을 잘못 사용하고 있다는 것이 아닙니다. 이는 인용 한 문서에 표시된대로 매개 변수가 버전 0.17에 추가 되었기 때문입니다.

따라서 Scikit-Learn을 업데이트하기 만하면됩니다.


답변

stratify매개 변수는 생성 된 샘플의 값 비율이 매개 변수에 제공된 값의 비율과 같도록 분할 stratify합니다.

예를 들어, 변수 y값을 바이너리 범주 변수 01, 1과 0의 25 % 사람의 75 %가 stratify=y반드시 당신의 임의 분할의 25 %를 가지고 할 것 0‘s와 75 %의 1의를.


답변

Google을 통해 여기에 오는 나의 미래를 위해 :

train_test_split이제에 있습니다 model_selection.

from sklearn.model_selection import train_test_split

# given:
# features: xs
# ground truth: ys

x_train, x_test, y_train, y_test = train_test_split(xs, ys,
                                                    test_size=0.33,
                                                    random_state=0,
                                                    stratify=ys)

그것을 사용하는 방법입니다. random_state재현성을 위해 설정하는 것이 바람직합니다.


답변

이 컨텍스트에서 계층화는 train_test_split 메서드가 입력 데이터 세트와 동일한 비율의 클래스 레이블을 가진 훈련 및 테스트 하위 집합을 반환 함을 의미합니다.


답변

이 코드를 실행 해보십시오. “그냥 작동”합니다.

from sklearn import cross_validation, datasets

iris = datasets.load_iris()

X = iris.data[:,:2]
y = iris.target

x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)

y_test

array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
       1, 2, 1, 1, 0, 2, 1])


답변