개발자의 스터디 노트
CNN을 사용한 성씨 분류(2) - 학습 하기 본문
모델 훈련은 앞서 포스팅했던 과정을 그대로 답습합니다.
1. 학습에 필요한 메서드 생성
def make_train_state(args):
return {'stop_early': False,
'early_stopping_step': 0,
'early_stopping_best_val': 1e8,
'learning_rate': args.learning_rate,
'epoch_index': 0,
'train_loss': [],
'train_acc': [],
'val_loss': [],
'val_acc': [],
'test_loss': -1,
'test_acc': -1,
'model_filename': args.model_state_file}
def update_train_state(args, model, train_state):
""" 훈련 상태를 업데이트합니다.
Components:
- 조기 종료: 과대 적합 방지
- 모델 체크포인트: 더 나은 모델을 저장합니다
:param args: 메인 매개변수
:param model: 훈련할 모델
:param train_state: 훈련 상태를 담은 딕셔너리
:returns:
새로운 훈련 상태
"""
# 적어도 한 번 모델을 저장합니다
if train_state['epoch_index'] == 0:
torch.save(model.state_dict(), train_state['model_filename'])
train_state['stop_early'] = False
# 성능이 향상되면 모델을 저장합니다
elif train_state['epoch_index'] >= 1:
loss_tm1, loss_t = train_state['val_loss'][-2:]
# 손실이 나빠지면
if loss_t >= train_state['early_stopping_best_val']:
# 조기 종료 단계 업데이트
train_state['early_stopping_step'] += 1
# 손실이 감소하면
else:
# 최상의 모델 저장
if loss_t < train_state['early_stopping_best_val']:
torch.save(model.state_dict(), train_state['model_filename'])
# 조기 종료 단계 재설정
train_state['early_stopping_step'] = 0
# 조기 종료 여부 확인
train_state['stop_early'] = \
train_state['early_stopping_step'] >= args.early_stopping_criteria
return train_state
def compute_accuracy(y_pred, y_target):
y_pred_indices = y_pred.max(dim=1)[1]
n_correct = torch.eq(y_pred_indices, y_target).sum().item()
return n_correct / len(y_pred_indices) * 100
2. 학습에 필요한 파라미터 정보 설정
args = Namespace(
# 날짜와 경로 정보
surname_csv="data/surnames_with_splits.csv",
vectorizer_file="vectorizer.json",
model_state_file="model.pth",
save_dir="model_storage/cnn",
# 모델 하이퍼파라미터
hidden_dim=100,
num_channels=256,
# 훈련 하이퍼파라미터
seed=1337,
learning_rate=0.001,
batch_size=128,
num_epochs=100,
early_stopping_criteria=5,
dropout_p=0.1,
# 실행 옵션
cuda=True,
reload_from_files=False,
expand_filepaths_to_save_dir=True,
catch_keyboard_interrupt=True
)
if args.expand_filepaths_to_save_dir:
args.vectorizer_file = os.path.join(args.save_dir,
args.vectorizer_file)
args.model_state_file = os.path.join(args.save_dir,
args.model_state_file)
print("파일 경로: ")
print("\t{}".format(args.vectorizer_file))
print("\t{}".format(args.model_state_file))
# CUDA 체크
if not torch.cuda.is_available():
args.cuda = False
args.device = torch.device("cuda" if args.cuda else "cpu")
print("CUDA 사용여부: {}".format(args.cuda))
def set_seed_everywhere(seed, cuda):
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
torch.cuda.manual_seed_all(seed)
def handle_dirs(dirpath):
if not os.path.exists(dirpath):
os.makedirs(dirpath)
# 재현성을 위해 시드 설정
set_seed_everywhere(args.seed, args.cuda)
# 디렉토리 처리
handle_dirs(args.save_dir)
3. 훈련
if args.reload_from_files:
# 체크포인트에서 훈련을 다시 시작
dataset = SurnameDataset.load_dataset_and_load_vectorizer(args.surname_csv,
args.vectorizer_file)
else:
# 데이터셋과 Vectorizer 만들기
dataset = SurnameDataset.load_dataset_and_make_vectorizer(args.surname_csv)
dataset.save_vectorizer(args.vectorizer_file)
vectorizer = dataset.get_vectorizer()
classifier = SurnameClassifier(initial_num_channels=len(vectorizer.surname_vocab),
num_classes=len(vectorizer.nationality_vocab),
num_channels=args.num_channels)
classifer = classifier.to(args.device)
dataset.class_weights = dataset.class_weights.to(args.device)
loss_func = nn.CrossEntropyLoss(weight=dataset.class_weights)
optimizer = optim.Adam(classifier.parameters(), lr=args.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
mode='min', factor=0.5,
patience=1)
train_state = make_train_state(args)
epoch_bar = tqdm.notebook.tqdm(desc='training routine',
total=args.num_epochs,
position=0)
dataset.set_split('train')
train_bar = tqdm.notebook.tqdm(desc='split=train',
total=dataset.get_num_batches(args.batch_size),
position=1,
leave=True)
dataset.set_split('val')
val_bar = tqdm.notebook.tqdm(desc='split=val',
total=dataset.get_num_batches(args.batch_size),
position=1,
leave=True)
try:
for epoch_index in range(args.num_epochs):
train_state['epoch_index'] = epoch_index
# 훈련 세트에 대한 순회
# 훈련 세트와 배치 제너레이터 준비, 손실과 정확도를 0으로 설정
dataset.set_split('train')
batch_generator = generate_batches(dataset,
batch_size=args.batch_size,
device=args.device)
running_loss = 0.0
running_acc = 0.0
classifier.train()
for batch_index, batch_dict in enumerate(batch_generator):
# 훈련 과정은 5단계로 이루어집니다
# --------------------------------------
# 단계 1. 그레이디언트를 0으로 초기화합니다
optimizer.zero_grad()
# 단계 2. 출력을 계산합니다
y_pred = classifier(batch_dict['x_surname'])
# 단계 3. 손실을 계산합니다
loss = loss_func(y_pred, batch_dict['y_nationality'])
loss_t = loss.item()
running_loss += (loss_t - running_loss) / (batch_index + 1)
# 단계 4. 손실을 사용해 그레이디언트를 계산합니다
loss.backward()
# 단계 5. 옵티마이저로 가중치를 업데이트합니다
optimizer.step()
# -----------------------------------------
# 정확도를 계산합니다
acc_t = compute_accuracy(y_pred, batch_dict['y_nationality'])
running_acc += (acc_t - running_acc) / (batch_index + 1)
# 진행 바 업데이트
train_bar.set_postfix(loss=running_loss, acc=running_acc,
epoch=epoch_index)
train_bar.update()
train_state['train_loss'].append(running_loss)
train_state['train_acc'].append(running_acc)
# 검증 세트에 대한 순회
# 검증 세트와 배치 제너레이터 준비, 손실과 정확도를 0으로 설정
dataset.set_split('val')
batch_generator = generate_batches(dataset,
batch_size=args.batch_size,
device=args.device)
running_loss = 0.
running_acc = 0.
classifier.eval()
for batch_index, batch_dict in enumerate(batch_generator):
# 단계 1. 출력을 계산합니다
y_pred = classifier(batch_dict['x_surname'])
# 단계 2. 손실을 계산합니다
loss = loss_func(y_pred, batch_dict['y_nationality'])
loss_t = loss.item()
running_loss += (loss_t - running_loss) / (batch_index + 1)
# 단계 3. 정확도를 계산합니다
acc_t = compute_accuracy(y_pred, batch_dict['y_nationality'])
running_acc += (acc_t - running_acc) / (batch_index + 1)
val_bar.set_postfix(loss=running_loss, acc=running_acc,
epoch=epoch_index)
val_bar.update()
train_state['val_loss'].append(running_loss)
train_state['val_acc'].append(running_acc)
train_state = update_train_state(args=args, model=classifier,
train_state=train_state)
scheduler.step(train_state['val_loss'][-1])
if train_state['stop_early']:
break
train_bar.n = 0
val_bar.n = 0
epoch_bar.update()
except KeyboardInterrupt:
print("Exiting loop")
'파이썬 > 파이토치 자연어처리' 카테고리의 다른 글
CNN의 추가 개념 (0) | 2022.02.23 |
---|---|
CNN을 사용한 성씨 분류(3) - 평가, 추론, 분석 하기 (0) | 2022.02.23 |
CNN을 사용한 성씨 분류(1) - 학습 준비하기 (0) | 2022.02.23 |
CNN 하이퍼 파라미터 (0) | 2022.02.22 |
MLP로 성씨 분류하기 (4) : MLP 규제 적용하기 (0) | 2022.02.19 |