개발자의 스터디 노트
CNN을 사용한 성씨 분류(3) - 평가, 추론, 분석 하기 본문
1. 테스트 셋에서 평가하기
- MLP의 정확도는 약 50% 정도였습니다. CNN 모델이 어느 정도의 성능 향상이 있었는지 테스트 셋으로 평가해봅시다.
classifier.load_state_dict(torch.load(train_state['model_filename']))
classifier = classifier.to(args.device)
dataset.class_weights = dataset.class_weights.to(args.device)
loss_func = nn.CrossEntropyLoss(dataset.class_weights)
dataset.set_split('test')
batch_generator = generate_batches(dataset,
batch_size=args.batch_size,
device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()
for batch_index, batch_dict in enumerate(batch_generator):
# 출력을 계산합니다
y_pred = classifier(batch_dict['x_surname'])
# 손실을 계산합니다
loss = loss_func(y_pred, batch_dict['y_nationality'])
loss_t = loss.item()
running_loss += (loss_t - running_loss) / (batch_index + 1)
# 정확도를 계산합니다
acc_t = compute_accuracy(y_pred, batch_dict['y_nationality'])
running_acc += (acc_t - running_acc) / (batch_index + 1)
train_state['test_loss'] = running_loss
train_state['test_acc'] = running_acc
print("테스트 손실: {};".format(train_state['test_loss']))
print("테스트 정확도: {}".format(train_state['test_acc']))
2. 새로운 성씨를 분류하고 최상위 예측 구하기
- 새로운 데이터 텐서에 배치 차원을 추가할 때 view() 메서드 대신 unsqueeze() 함수를 사용해 size=1인 배치 차원을 추가했습니다.
def predict_nationality(surname, classifier, vectorizer):
"""새로운 성씨로 국적 예측하기
매개변수:
surname (str): 분류할 성씨
classifier (SurnameClassifer): 분류기 객체
vectorizer (SurnameVectorizer): SurnameVectorizer 객체
반환값:
가장 가능성이 높은 국적과 확률로 구성된 딕셔너리
"""
vectorized_surname = vectorizer.vectorize(surname)
vectorized_surname = torch.tensor(vectorized_surname).unsqueeze(0)
result = classifier(vectorized_surname, apply_softmax=True)
probability_values, indices = result.max(dim=1)
index = indices.item()
predicted_nationality = vectorizer.nationality_vocab.lookup_index(index)
probability_value = probability_values.item()
return {'nationality': predicted_nationality, 'probability': probability_value}
new_surname = input("분류하려는 성씨를 입력하세요: ")
def predict_topk_nationality(surname, classifier, vectorizer, k=5):
"""새로운 성씨에 대한 최상위 K개 국적을 예측합니다
매개변수:
surname (str): 분류하려는 성씨
classifier (SurnameClassifer): 분류기 객체
vectorizer (SurnameVectorizer): SurnameVectorizer 객체
k (int): the number of top nationalities to return
반환값:
딕셔너리 리스트, 각 딕셔너리는 국적과 확률로 구성됩니다.
"""
vectorized_surname = vectorizer.vectorize(surname)
vectorized_surname = torch.tensor(vectorized_surname).unsqueeze(dim=0)
prediction_vector = classifier(vectorized_surname, apply_softmax=True)
probability_values, indices = torch.topk(prediction_vector, k=k)
# 반환되는 크기는 (1,k)입니다
probability_values = probability_values[0].detach().numpy()
indices = indices[0].detach().numpy()
results = []
for kth_index in range(k):
nationality = vectorizer.nationality_vocab.lookup_index(indices[kth_index])
probability_value = probability_values[kth_index]
results.append({'nationality': nationality,
'probability': probability_value})
return results
new_surname = input("분류하려는 성씨를 입력하세요: ")
k = int(input("얼마나 많은 예측을 보고 싶나요? "))
if k > len(vectorizer.nationality_vocab):
print("앗! 전체 국적 개수보다 큰 값을 입력했습니다. 모든 국적에 대한 예측을 반환합니다. :)")
k = len(vectorizer.nationality_vocab)
predictions = predict_topk_nationality(new_surname, classifier, vectorizer, k=k)
print("최상위 {}개 예측:".format(k))
print("===================")
for prediction in predictions:
print("{} -> {} (p={:0.2f})".format(new_surname,
prediction['nationality'],
prediction['probability']))
'파이썬 > 파이토치 자연어처리' 카테고리의 다른 글
임베딩 (0) | 2022.03.03 |
---|---|
CNN의 추가 개념 (0) | 2022.02.23 |
CNN을 사용한 성씨 분류(2) - 학습 하기 (0) | 2022.02.23 |
CNN을 사용한 성씨 분류(1) - 학습 준비하기 (0) | 2022.02.23 |
CNN 하이퍼 파라미터 (0) | 2022.02.22 |