Project/pytorch

얼굴 나이 인식기 개발 - 4 train code (Using EfficientNet with Pytorch)

ys_cs17 2020. 9. 17. 13:34
반응형

PREVIEW

이번 시간부터 본격적인 CNN 학습을 할 것이다.

우리는 classification을 위해 2019년 State-of-art를 달성한 EfficientNet을 사용하여 사람의 얼굴에 대해 나이 인식기를 제작할 것이다.

우선 EfficientNet에 대해 알아야 할 필요가 있다.

아직 EfficientNet에 대해 알지 못하는 사람은 아래 링크를 통해 참고하길 바란다.

https://ys-cs17.tistory.com/30

 

EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks 분석

해당 논문은 MNasNet의 저자인 Mingxing Tan과 Quoc V.Le가 쓴 논문이고, 2019년 CVPR에서 발표되었다. ResNet의 residual block의 등장 이후로 CNN은 2가지 성향으로 발전하고 있다. 이는 정확도만 높이는 방향과..

ys-cs17.tistory.com

전체 코드는 아래를 참고하길 바란다.

https://github.com/yunseokddi/pytorch_dev/tree/master/facial_age_classifier/EfficientNet_ver

 

yunseokddi/pytorch_dev

My projects with pytorch . Contribute to yunseokddi/pytorch_dev development by creating an account on GitHub.

github.com

 

EfficientNet (pytorch version)

pytorch를 정상적으로 다운로드하였으면 torchvision도 함께 설치가 된다.

 

 

 

 

이전에 작성하였던 코드는 resnet으로 구현을 완료하였지만, 성능이 매우 좋지 않아 efficientnet으로 진행하려고 한다.

하지만 우리가 사용할 efficientnet은 pytorch에서 기본적으로 제공이 되지 않는다.

따라서 직접 구현을 할 수 있지만, 우리는 연구적인 관점에서 코드를 작성하는 것이 아닌 프로젝트 관점에서 코드를 작성하는 것이기 때문에 다른 사람이 구현한 efficientnet을 가지고 오는 것이 좋다.

efficientnet 구현은 조만간 포스팅을 진행할 것이다.

 

 

pip install efficientnet_pytorch

해당 명령어를 통해 efficientnet을 쉽게 불러올 수 있다. (python3는 pip3로 다운)

원본 Github: https://github.com/lukemelas/EfficientNet-PyTorch

 

1. model_viewer.py

이전 포스팅에서 말했던 바와 같이 딥러닝을 구축하든, 어떠한 개발을 하든 한 번에 완성하는 것보다 중간중간 디버깅하는 습관을 들이는 것이 좋다.

다른 코드에서는 IDE를 사용해 디버깅을 할 수 있지만 딥러닝에서는 코드뿐만이 아니라 모델, 데이터, 가중치, 입력 값 등으로 인해 다양한 오류가 존재한다. 따라서 코드를 한 번에 완성하고 컴파일을 하다 보면 어디서 오류가 났는지 알기 힘든 경우가 있다.

우리는 저번 시간에 데이터에 관해 디버깅을 하였고, 이번 코드에서는 모델에 관하여 디버깅을 진행해보자.

 

from torchsummary import summary
from efficientnet_pytorch import EfficientNet


model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=15)

summary(model, input_size=(3,128,128), device='cpu')

pytorch에서는 model을 시각화하기 위해 torchsummary를 제공한다. 혹시 pip에 다운에 되지 않았으면

pip install torchsummary로 쉽게 다운할 수 있다.

 

또한 위에서 다운로드한 efficientnet_pytorch도 함께 불러오자.

그리고 model을 선언하자.

efficientnet_pytorch는 from_pretrained 메서드를 통해 b0~b8을 호출할 수 있고, 2번째 parameter로 학습시키는 데이터에 대한 num_classes를 주면 된다.

 

본인의 gpu상황에 따라 b0을 쓰든 b7을 쓰던 상관이 없지만, gpu memory가 8g 아래인 사람들은 b3을 추천한다.

여기서 더 높아지면 memory 초과가 일어나 제대로 동작을 하지 않을 수도 있다.

 

그다음은 summary 함수이다. 이 함수를 통해 model의 정보를 알 수 있다. input size는 RGB channel이 3이고, 가로, 세로가 128인 image를 넣을 것이기 때문에 (3,128,128)로 두었고, summary에 사용할 model은 아직 이 함수가 cuda에 대한 오류가 있어 device를 cpu로 설정하였다.

자 이제 이 코드를 compile 해보자.

Loaded pretrained weights for efficientnet-b3
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
         ZeroPad2d-1          [-1, 3, 130, 130]               0
Conv2dStaticSamePadding-2           [-1, 40, 64, 64]           1,080
       BatchNorm2d-3           [-1, 40, 64, 64]              80
MemoryEfficientSwish-4           [-1, 40, 64, 64]               0
         ZeroPad2d-5           [-1, 40, 66, 66]               0
Conv2dStaticSamePadding-6           [-1, 40, 64, 64]             360
       BatchNorm2d-7           [-1, 40, 64, 64]              80
MemoryEfficientSwish-8           [-1, 40, 64, 64]               0
          Identity-9             [-1, 40, 1, 1]               0
Conv2dStaticSamePadding-10             [-1, 10, 1, 1]             410
MemoryEfficientSwish-11             [-1, 10, 1, 1]               0
         Identity-12             [-1, 10, 1, 1]               0
Conv2dStaticSamePadding-13             [-1, 40, 1, 1]             440
         Identity-14           [-1, 40, 64, 64]               0
Conv2dStaticSamePadding-15           [-1, 24, 64, 64]             960
      BatchNorm2d-16           [-1, 24, 64, 64]              48
      MBConvBlock-17           [-1, 24, 64, 64]               0
        ZeroPad2d-18           [-1, 24, 66, 66]               0
Conv2dStaticSamePadding-19           [-1, 24, 64, 64]             216
      BatchNorm2d-20           [-1, 24, 64, 64]              48
MemoryEfficientSwish-21           [-1, 24, 64, 64]               0
         Identity-22             [-1, 24, 1, 1]               0
Conv2dStaticSamePadding-23              [-1, 6, 1, 1]             150
MemoryEfficientSwish-24              [-1, 6, 1, 1]               0
         Identity-25              [-1, 6, 1, 1]               0
Conv2dStaticSamePadding-26             [-1, 24, 1, 1]             168
         Identity-27           [-1, 24, 64, 64]               0
Conv2dStaticSamePadding-28           [-1, 24, 64, 64]             576
      BatchNorm2d-29           [-1, 24, 64, 64]              48
      MBConvBlock-30           [-1, 24, 64, 64]               0
         Identity-31           [-1, 24, 64, 64]               0
Conv2dStaticSamePadding-32          [-1, 144, 64, 64]           3,456
      BatchNorm2d-33          [-1, 144, 64, 64]             288
MemoryEfficientSwish-34          [-1, 144, 64, 64]               0
        ZeroPad2d-35          [-1, 144, 66, 66]               0
Conv2dStaticSamePadding-36          [-1, 144, 32, 32]           1,296
      BatchNorm2d-37          [-1, 144, 32, 32]             288
MemoryEfficientSwish-38          [-1, 144, 32, 32]               0
         Identity-39            [-1, 144, 1, 1]               0
Conv2dStaticSamePadding-40              [-1, 6, 1, 1]             870
MemoryEfficientSwish-41              [-1, 6, 1, 1]               0
         Identity-42              [-1, 6, 1, 1]               0
Conv2dStaticSamePadding-43            [-1, 144, 1, 1]           1,008
         Identity-44          [-1, 144, 32, 32]               0
Conv2dStaticSamePadding-45           [-1, 32, 32, 32]           4,608
      BatchNorm2d-46           [-1, 32, 32, 32]              64
      MBConvBlock-47           [-1, 32, 32, 32]               0
         Identity-48           [-1, 32, 32, 32]               0
Conv2dStaticSamePadding-49          [-1, 192, 32, 32]           6,144
      BatchNorm2d-50          [-1, 192, 32, 32]             384
MemoryEfficientSwish-51          [-1, 192, 32, 32]               0
        ZeroPad2d-52          [-1, 192, 34, 34]               0
Conv2dStaticSamePadding-53          [-1, 192, 32, 32]           1,728
      BatchNorm2d-54          [-1, 192, 32, 32]             384
MemoryEfficientSwish-55          [-1, 192, 32, 32]               0
         Identity-56            [-1, 192, 1, 1]               0
Conv2dStaticSamePadding-57              [-1, 8, 1, 1]           1,544
MemoryEfficientSwish-58              [-1, 8, 1, 1]               0
         Identity-59              [-1, 8, 1, 1]               0
Conv2dStaticSamePadding-60            [-1, 192, 1, 1]           1,728
         Identity-61          [-1, 192, 32, 32]               0
Conv2dStaticSamePadding-62           [-1, 32, 32, 32]           6,144
      BatchNorm2d-63           [-1, 32, 32, 32]              64
      MBConvBlock-64           [-1, 32, 32, 32]               0
         Identity-65           [-1, 32, 32, 32]               0
Conv2dStaticSamePadding-66          [-1, 192, 32, 32]           6,144
      BatchNorm2d-67          [-1, 192, 32, 32]             384
MemoryEfficientSwish-68          [-1, 192, 32, 32]               0
        ZeroPad2d-69          [-1, 192, 34, 34]               0
Conv2dStaticSamePadding-70          [-1, 192, 32, 32]           1,728
      BatchNorm2d-71          [-1, 192, 32, 32]             384
MemoryEfficientSwish-72          [-1, 192, 32, 32]               0
         Identity-73            [-1, 192, 1, 1]               0
Conv2dStaticSamePadding-74              [-1, 8, 1, 1]           1,544
MemoryEfficientSwish-75              [-1, 8, 1, 1]               0
         Identity-76              [-1, 8, 1, 1]               0
Conv2dStaticSamePadding-77            [-1, 192, 1, 1]           1,728
         Identity-78          [-1, 192, 32, 32]               0
Conv2dStaticSamePadding-79           [-1, 32, 32, 32]           6,144
      BatchNorm2d-80           [-1, 32, 32, 32]              64
      MBConvBlock-81           [-1, 32, 32, 32]               0
         Identity-82           [-1, 32, 32, 32]               0
Conv2dStaticSamePadding-83          [-1, 192, 32, 32]           6,144
      BatchNorm2d-84          [-1, 192, 32, 32]             384
MemoryEfficientSwish-85          [-1, 192, 32, 32]               0
        ZeroPad2d-86          [-1, 192, 36, 36]               0
Conv2dStaticSamePadding-87          [-1, 192, 16, 16]           4,800
      BatchNorm2d-88          [-1, 192, 16, 16]             384
MemoryEfficientSwish-89          [-1, 192, 16, 16]               0
         Identity-90            [-1, 192, 1, 1]               0
Conv2dStaticSamePadding-91              [-1, 8, 1, 1]           1,544
MemoryEfficientSwish-92              [-1, 8, 1, 1]               0
         Identity-93              [-1, 8, 1, 1]               0
Conv2dStaticSamePadding-94            [-1, 192, 1, 1]           1,728
         Identity-95          [-1, 192, 16, 16]               0
Conv2dStaticSamePadding-96           [-1, 48, 16, 16]           9,216
      BatchNorm2d-97           [-1, 48, 16, 16]              96
      MBConvBlock-98           [-1, 48, 16, 16]               0
         Identity-99           [-1, 48, 16, 16]               0
Conv2dStaticSamePadding-100          [-1, 288, 16, 16]          13,824
     BatchNorm2d-101          [-1, 288, 16, 16]             576
MemoryEfficientSwish-102          [-1, 288, 16, 16]               0
       ZeroPad2d-103          [-1, 288, 20, 20]               0
Conv2dStaticSamePadding-104          [-1, 288, 16, 16]           7,200
     BatchNorm2d-105          [-1, 288, 16, 16]             576
MemoryEfficientSwish-106          [-1, 288, 16, 16]               0
        Identity-107            [-1, 288, 1, 1]               0
Conv2dStaticSamePadding-108             [-1, 12, 1, 1]           3,468
MemoryEfficientSwish-109             [-1, 12, 1, 1]               0
        Identity-110             [-1, 12, 1, 1]               0
Conv2dStaticSamePadding-111            [-1, 288, 1, 1]           3,744
        Identity-112          [-1, 288, 16, 16]               0
Conv2dStaticSamePadding-113           [-1, 48, 16, 16]          13,824
     BatchNorm2d-114           [-1, 48, 16, 16]              96
     MBConvBlock-115           [-1, 48, 16, 16]               0
        Identity-116           [-1, 48, 16, 16]               0
Conv2dStaticSamePadding-117          [-1, 288, 16, 16]          13,824
     BatchNorm2d-118          [-1, 288, 16, 16]             576
MemoryEfficientSwish-119          [-1, 288, 16, 16]               0
       ZeroPad2d-120          [-1, 288, 20, 20]               0
Conv2dStaticSamePadding-121          [-1, 288, 16, 16]           7,200
     BatchNorm2d-122          [-1, 288, 16, 16]             576
MemoryEfficientSwish-123          [-1, 288, 16, 16]               0
        Identity-124            [-1, 288, 1, 1]               0
Conv2dStaticSamePadding-125             [-1, 12, 1, 1]           3,468
MemoryEfficientSwish-126             [-1, 12, 1, 1]               0
        Identity-127             [-1, 12, 1, 1]               0
Conv2dStaticSamePadding-128            [-1, 288, 1, 1]           3,744
        Identity-129          [-1, 288, 16, 16]               0
Conv2dStaticSamePadding-130           [-1, 48, 16, 16]          13,824
     BatchNorm2d-131           [-1, 48, 16, 16]              96
     MBConvBlock-132           [-1, 48, 16, 16]               0
        Identity-133           [-1, 48, 16, 16]               0
Conv2dStaticSamePadding-134          [-1, 288, 16, 16]          13,824
     BatchNorm2d-135          [-1, 288, 16, 16]             576
MemoryEfficientSwish-136          [-1, 288, 16, 16]               0
       ZeroPad2d-137          [-1, 288, 18, 18]               0
Conv2dStaticSamePadding-138            [-1, 288, 8, 8]           2,592
     BatchNorm2d-139            [-1, 288, 8, 8]             576
MemoryEfficientSwish-140            [-1, 288, 8, 8]               0
        Identity-141            [-1, 288, 1, 1]               0
Conv2dStaticSamePadding-142             [-1, 12, 1, 1]           3,468
MemoryEfficientSwish-143             [-1, 12, 1, 1]               0
        Identity-144             [-1, 12, 1, 1]               0
Conv2dStaticSamePadding-145            [-1, 288, 1, 1]           3,744
        Identity-146            [-1, 288, 8, 8]               0
Conv2dStaticSamePadding-147             [-1, 96, 8, 8]          27,648
     BatchNorm2d-148             [-1, 96, 8, 8]             192
     MBConvBlock-149             [-1, 96, 8, 8]               0
        Identity-150             [-1, 96, 8, 8]               0
Conv2dStaticSamePadding-151            [-1, 576, 8, 8]          55,296
     BatchNorm2d-152            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-153            [-1, 576, 8, 8]               0
       ZeroPad2d-154          [-1, 576, 10, 10]               0
Conv2dStaticSamePadding-155            [-1, 576, 8, 8]           5,184
     BatchNorm2d-156            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-157            [-1, 576, 8, 8]               0
        Identity-158            [-1, 576, 1, 1]               0
Conv2dStaticSamePadding-159             [-1, 24, 1, 1]          13,848
MemoryEfficientSwish-160             [-1, 24, 1, 1]               0
        Identity-161             [-1, 24, 1, 1]               0
Conv2dStaticSamePadding-162            [-1, 576, 1, 1]          14,400
        Identity-163            [-1, 576, 8, 8]               0
Conv2dStaticSamePadding-164             [-1, 96, 8, 8]          55,296
     BatchNorm2d-165             [-1, 96, 8, 8]             192
     MBConvBlock-166             [-1, 96, 8, 8]               0
        Identity-167             [-1, 96, 8, 8]               0
Conv2dStaticSamePadding-168            [-1, 576, 8, 8]          55,296
     BatchNorm2d-169            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-170            [-1, 576, 8, 8]               0
       ZeroPad2d-171          [-1, 576, 10, 10]               0
Conv2dStaticSamePadding-172            [-1, 576, 8, 8]           5,184
     BatchNorm2d-173            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-174            [-1, 576, 8, 8]               0
        Identity-175            [-1, 576, 1, 1]               0
Conv2dStaticSamePadding-176             [-1, 24, 1, 1]          13,848
MemoryEfficientSwish-177             [-1, 24, 1, 1]               0
        Identity-178             [-1, 24, 1, 1]               0
Conv2dStaticSamePadding-179            [-1, 576, 1, 1]          14,400
        Identity-180            [-1, 576, 8, 8]               0
Conv2dStaticSamePadding-181             [-1, 96, 8, 8]          55,296
     BatchNorm2d-182             [-1, 96, 8, 8]             192
     MBConvBlock-183             [-1, 96, 8, 8]               0
        Identity-184             [-1, 96, 8, 8]               0
Conv2dStaticSamePadding-185            [-1, 576, 8, 8]          55,296
     BatchNorm2d-186            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-187            [-1, 576, 8, 8]               0
       ZeroPad2d-188          [-1, 576, 10, 10]               0
Conv2dStaticSamePadding-189            [-1, 576, 8, 8]           5,184
     BatchNorm2d-190            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-191            [-1, 576, 8, 8]               0
        Identity-192            [-1, 576, 1, 1]               0
Conv2dStaticSamePadding-193             [-1, 24, 1, 1]          13,848
MemoryEfficientSwish-194             [-1, 24, 1, 1]               0
        Identity-195             [-1, 24, 1, 1]               0
Conv2dStaticSamePadding-196            [-1, 576, 1, 1]          14,400
        Identity-197            [-1, 576, 8, 8]               0
Conv2dStaticSamePadding-198             [-1, 96, 8, 8]          55,296
     BatchNorm2d-199             [-1, 96, 8, 8]             192
     MBConvBlock-200             [-1, 96, 8, 8]               0
        Identity-201             [-1, 96, 8, 8]               0
Conv2dStaticSamePadding-202            [-1, 576, 8, 8]          55,296
     BatchNorm2d-203            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-204            [-1, 576, 8, 8]               0
       ZeroPad2d-205          [-1, 576, 10, 10]               0
Conv2dStaticSamePadding-206            [-1, 576, 8, 8]           5,184
     BatchNorm2d-207            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-208            [-1, 576, 8, 8]               0
        Identity-209            [-1, 576, 1, 1]               0
Conv2dStaticSamePadding-210             [-1, 24, 1, 1]          13,848
MemoryEfficientSwish-211             [-1, 24, 1, 1]               0
        Identity-212             [-1, 24, 1, 1]               0
Conv2dStaticSamePadding-213            [-1, 576, 1, 1]          14,400
        Identity-214            [-1, 576, 8, 8]               0
Conv2dStaticSamePadding-215             [-1, 96, 8, 8]          55,296
     BatchNorm2d-216             [-1, 96, 8, 8]             192
     MBConvBlock-217             [-1, 96, 8, 8]               0
        Identity-218             [-1, 96, 8, 8]               0
Conv2dStaticSamePadding-219            [-1, 576, 8, 8]          55,296
     BatchNorm2d-220            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-221            [-1, 576, 8, 8]               0
       ZeroPad2d-222          [-1, 576, 12, 12]               0
Conv2dStaticSamePadding-223            [-1, 576, 8, 8]          14,400
     BatchNorm2d-224            [-1, 576, 8, 8]           1,152
MemoryEfficientSwish-225            [-1, 576, 8, 8]               0
        Identity-226            [-1, 576, 1, 1]               0
Conv2dStaticSamePadding-227             [-1, 24, 1, 1]          13,848
MemoryEfficientSwish-228             [-1, 24, 1, 1]               0
        Identity-229             [-1, 24, 1, 1]               0
Conv2dStaticSamePadding-230            [-1, 576, 1, 1]          14,400
        Identity-231            [-1, 576, 8, 8]               0
Conv2dStaticSamePadding-232            [-1, 136, 8, 8]          78,336
     BatchNorm2d-233            [-1, 136, 8, 8]             272
     MBConvBlock-234            [-1, 136, 8, 8]               0
        Identity-235            [-1, 136, 8, 8]               0
Conv2dStaticSamePadding-236            [-1, 816, 8, 8]         110,976
     BatchNorm2d-237            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-238            [-1, 816, 8, 8]               0
       ZeroPad2d-239          [-1, 816, 12, 12]               0
Conv2dStaticSamePadding-240            [-1, 816, 8, 8]          20,400
     BatchNorm2d-241            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-242            [-1, 816, 8, 8]               0
        Identity-243            [-1, 816, 1, 1]               0
Conv2dStaticSamePadding-244             [-1, 34, 1, 1]          27,778
MemoryEfficientSwish-245             [-1, 34, 1, 1]               0
        Identity-246             [-1, 34, 1, 1]               0
Conv2dStaticSamePadding-247            [-1, 816, 1, 1]          28,560
        Identity-248            [-1, 816, 8, 8]               0
Conv2dStaticSamePadding-249            [-1, 136, 8, 8]         110,976
     BatchNorm2d-250            [-1, 136, 8, 8]             272
     MBConvBlock-251            [-1, 136, 8, 8]               0
        Identity-252            [-1, 136, 8, 8]               0
Conv2dStaticSamePadding-253            [-1, 816, 8, 8]         110,976
     BatchNorm2d-254            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-255            [-1, 816, 8, 8]               0
       ZeroPad2d-256          [-1, 816, 12, 12]               0
Conv2dStaticSamePadding-257            [-1, 816, 8, 8]          20,400
     BatchNorm2d-258            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-259            [-1, 816, 8, 8]               0
        Identity-260            [-1, 816, 1, 1]               0
Conv2dStaticSamePadding-261             [-1, 34, 1, 1]          27,778
MemoryEfficientSwish-262             [-1, 34, 1, 1]               0
        Identity-263             [-1, 34, 1, 1]               0
Conv2dStaticSamePadding-264            [-1, 816, 1, 1]          28,560
        Identity-265            [-1, 816, 8, 8]               0
Conv2dStaticSamePadding-266            [-1, 136, 8, 8]         110,976
     BatchNorm2d-267            [-1, 136, 8, 8]             272
     MBConvBlock-268            [-1, 136, 8, 8]               0
        Identity-269            [-1, 136, 8, 8]               0
Conv2dStaticSamePadding-270            [-1, 816, 8, 8]         110,976
     BatchNorm2d-271            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-272            [-1, 816, 8, 8]               0
       ZeroPad2d-273          [-1, 816, 12, 12]               0
Conv2dStaticSamePadding-274            [-1, 816, 8, 8]          20,400
     BatchNorm2d-275            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-276            [-1, 816, 8, 8]               0
        Identity-277            [-1, 816, 1, 1]               0
Conv2dStaticSamePadding-278             [-1, 34, 1, 1]          27,778
MemoryEfficientSwish-279             [-1, 34, 1, 1]               0
        Identity-280             [-1, 34, 1, 1]               0
Conv2dStaticSamePadding-281            [-1, 816, 1, 1]          28,560
        Identity-282            [-1, 816, 8, 8]               0
Conv2dStaticSamePadding-283            [-1, 136, 8, 8]         110,976
     BatchNorm2d-284            [-1, 136, 8, 8]             272
     MBConvBlock-285            [-1, 136, 8, 8]               0
        Identity-286            [-1, 136, 8, 8]               0
Conv2dStaticSamePadding-287            [-1, 816, 8, 8]         110,976
     BatchNorm2d-288            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-289            [-1, 816, 8, 8]               0
       ZeroPad2d-290          [-1, 816, 12, 12]               0
Conv2dStaticSamePadding-291            [-1, 816, 8, 8]          20,400
     BatchNorm2d-292            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-293            [-1, 816, 8, 8]               0
        Identity-294            [-1, 816, 1, 1]               0
Conv2dStaticSamePadding-295             [-1, 34, 1, 1]          27,778
MemoryEfficientSwish-296             [-1, 34, 1, 1]               0
        Identity-297             [-1, 34, 1, 1]               0
Conv2dStaticSamePadding-298            [-1, 816, 1, 1]          28,560
        Identity-299            [-1, 816, 8, 8]               0
Conv2dStaticSamePadding-300            [-1, 136, 8, 8]         110,976
     BatchNorm2d-301            [-1, 136, 8, 8]             272
     MBConvBlock-302            [-1, 136, 8, 8]               0
        Identity-303            [-1, 136, 8, 8]               0
Conv2dStaticSamePadding-304            [-1, 816, 8, 8]         110,976
     BatchNorm2d-305            [-1, 816, 8, 8]           1,632
MemoryEfficientSwish-306            [-1, 816, 8, 8]               0
       ZeroPad2d-307          [-1, 816, 12, 12]               0
Conv2dStaticSamePadding-308            [-1, 816, 4, 4]          20,400
     BatchNorm2d-309            [-1, 816, 4, 4]           1,632
MemoryEfficientSwish-310            [-1, 816, 4, 4]               0
        Identity-311            [-1, 816, 1, 1]               0
Conv2dStaticSamePadding-312             [-1, 34, 1, 1]          27,778
MemoryEfficientSwish-313             [-1, 34, 1, 1]               0
        Identity-314             [-1, 34, 1, 1]               0
Conv2dStaticSamePadding-315            [-1, 816, 1, 1]          28,560
        Identity-316            [-1, 816, 4, 4]               0
Conv2dStaticSamePadding-317            [-1, 232, 4, 4]         189,312
     BatchNorm2d-318            [-1, 232, 4, 4]             464
     MBConvBlock-319            [-1, 232, 4, 4]               0
        Identity-320            [-1, 232, 4, 4]               0
Conv2dStaticSamePadding-321           [-1, 1392, 4, 4]         322,944
     BatchNorm2d-322           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-323           [-1, 1392, 4, 4]               0
       ZeroPad2d-324           [-1, 1392, 8, 8]               0
Conv2dStaticSamePadding-325           [-1, 1392, 4, 4]          34,800
     BatchNorm2d-326           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-327           [-1, 1392, 4, 4]               0
        Identity-328           [-1, 1392, 1, 1]               0
Conv2dStaticSamePadding-329             [-1, 58, 1, 1]          80,794
MemoryEfficientSwish-330             [-1, 58, 1, 1]               0
        Identity-331             [-1, 58, 1, 1]               0
Conv2dStaticSamePadding-332           [-1, 1392, 1, 1]          82,128
        Identity-333           [-1, 1392, 4, 4]               0
Conv2dStaticSamePadding-334            [-1, 232, 4, 4]         322,944
     BatchNorm2d-335            [-1, 232, 4, 4]             464
     MBConvBlock-336            [-1, 232, 4, 4]               0
        Identity-337            [-1, 232, 4, 4]               0
Conv2dStaticSamePadding-338           [-1, 1392, 4, 4]         322,944
     BatchNorm2d-339           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-340           [-1, 1392, 4, 4]               0
       ZeroPad2d-341           [-1, 1392, 8, 8]               0
Conv2dStaticSamePadding-342           [-1, 1392, 4, 4]          34,800
     BatchNorm2d-343           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-344           [-1, 1392, 4, 4]               0
        Identity-345           [-1, 1392, 1, 1]               0
Conv2dStaticSamePadding-346             [-1, 58, 1, 1]          80,794
MemoryEfficientSwish-347             [-1, 58, 1, 1]               0
        Identity-348             [-1, 58, 1, 1]               0
Conv2dStaticSamePadding-349           [-1, 1392, 1, 1]          82,128
        Identity-350           [-1, 1392, 4, 4]               0
Conv2dStaticSamePadding-351            [-1, 232, 4, 4]         322,944
     BatchNorm2d-352            [-1, 232, 4, 4]             464
     MBConvBlock-353            [-1, 232, 4, 4]               0
        Identity-354            [-1, 232, 4, 4]               0
Conv2dStaticSamePadding-355           [-1, 1392, 4, 4]         322,944
     BatchNorm2d-356           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-357           [-1, 1392, 4, 4]               0
       ZeroPad2d-358           [-1, 1392, 8, 8]               0
Conv2dStaticSamePadding-359           [-1, 1392, 4, 4]          34,800
     BatchNorm2d-360           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-361           [-1, 1392, 4, 4]               0
        Identity-362           [-1, 1392, 1, 1]               0
Conv2dStaticSamePadding-363             [-1, 58, 1, 1]          80,794
MemoryEfficientSwish-364             [-1, 58, 1, 1]               0
        Identity-365             [-1, 58, 1, 1]               0
Conv2dStaticSamePadding-366           [-1, 1392, 1, 1]          82,128
        Identity-367           [-1, 1392, 4, 4]               0
Conv2dStaticSamePadding-368            [-1, 232, 4, 4]         322,944
     BatchNorm2d-369            [-1, 232, 4, 4]             464
     MBConvBlock-370            [-1, 232, 4, 4]               0
        Identity-371            [-1, 232, 4, 4]               0
Conv2dStaticSamePadding-372           [-1, 1392, 4, 4]         322,944
     BatchNorm2d-373           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-374           [-1, 1392, 4, 4]               0
       ZeroPad2d-375           [-1, 1392, 8, 8]               0
Conv2dStaticSamePadding-376           [-1, 1392, 4, 4]          34,800
     BatchNorm2d-377           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-378           [-1, 1392, 4, 4]               0
        Identity-379           [-1, 1392, 1, 1]               0
Conv2dStaticSamePadding-380             [-1, 58, 1, 1]          80,794
MemoryEfficientSwish-381             [-1, 58, 1, 1]               0
        Identity-382             [-1, 58, 1, 1]               0
Conv2dStaticSamePadding-383           [-1, 1392, 1, 1]          82,128
        Identity-384           [-1, 1392, 4, 4]               0
Conv2dStaticSamePadding-385            [-1, 232, 4, 4]         322,944
     BatchNorm2d-386            [-1, 232, 4, 4]             464
     MBConvBlock-387            [-1, 232, 4, 4]               0
        Identity-388            [-1, 232, 4, 4]               0
Conv2dStaticSamePadding-389           [-1, 1392, 4, 4]         322,944
     BatchNorm2d-390           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-391           [-1, 1392, 4, 4]               0
       ZeroPad2d-392           [-1, 1392, 8, 8]               0
Conv2dStaticSamePadding-393           [-1, 1392, 4, 4]          34,800
     BatchNorm2d-394           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-395           [-1, 1392, 4, 4]               0
        Identity-396           [-1, 1392, 1, 1]               0
Conv2dStaticSamePadding-397             [-1, 58, 1, 1]          80,794
MemoryEfficientSwish-398             [-1, 58, 1, 1]               0
        Identity-399             [-1, 58, 1, 1]               0
Conv2dStaticSamePadding-400           [-1, 1392, 1, 1]          82,128
        Identity-401           [-1, 1392, 4, 4]               0
Conv2dStaticSamePadding-402            [-1, 232, 4, 4]         322,944
     BatchNorm2d-403            [-1, 232, 4, 4]             464
     MBConvBlock-404            [-1, 232, 4, 4]               0
        Identity-405            [-1, 232, 4, 4]               0
Conv2dStaticSamePadding-406           [-1, 1392, 4, 4]         322,944
     BatchNorm2d-407           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-408           [-1, 1392, 4, 4]               0
       ZeroPad2d-409           [-1, 1392, 6, 6]               0
Conv2dStaticSamePadding-410           [-1, 1392, 4, 4]          12,528
     BatchNorm2d-411           [-1, 1392, 4, 4]           2,784
MemoryEfficientSwish-412           [-1, 1392, 4, 4]               0
        Identity-413           [-1, 1392, 1, 1]               0
Conv2dStaticSamePadding-414             [-1, 58, 1, 1]          80,794
MemoryEfficientSwish-415             [-1, 58, 1, 1]               0
        Identity-416             [-1, 58, 1, 1]               0
Conv2dStaticSamePadding-417           [-1, 1392, 1, 1]          82,128
        Identity-418           [-1, 1392, 4, 4]               0
Conv2dStaticSamePadding-419            [-1, 384, 4, 4]         534,528
     BatchNorm2d-420            [-1, 384, 4, 4]             768
     MBConvBlock-421            [-1, 384, 4, 4]               0
        Identity-422            [-1, 384, 4, 4]               0
Conv2dStaticSamePadding-423           [-1, 2304, 4, 4]         884,736
     BatchNorm2d-424           [-1, 2304, 4, 4]           4,608
MemoryEfficientSwish-425           [-1, 2304, 4, 4]               0
       ZeroPad2d-426           [-1, 2304, 6, 6]               0
Conv2dStaticSamePadding-427           [-1, 2304, 4, 4]          20,736
     BatchNorm2d-428           [-1, 2304, 4, 4]           4,608
MemoryEfficientSwish-429           [-1, 2304, 4, 4]               0
        Identity-430           [-1, 2304, 1, 1]               0
Conv2dStaticSamePadding-431             [-1, 96, 1, 1]         221,280
MemoryEfficientSwish-432             [-1, 96, 1, 1]               0
        Identity-433             [-1, 96, 1, 1]               0
Conv2dStaticSamePadding-434           [-1, 2304, 1, 1]         223,488
        Identity-435           [-1, 2304, 4, 4]               0
Conv2dStaticSamePadding-436            [-1, 384, 4, 4]         884,736
     BatchNorm2d-437            [-1, 384, 4, 4]             768
     MBConvBlock-438            [-1, 384, 4, 4]               0
        Identity-439            [-1, 384, 4, 4]               0
Conv2dStaticSamePadding-440           [-1, 1536, 4, 4]         589,824
     BatchNorm2d-441           [-1, 1536, 4, 4]           3,072
MemoryEfficientSwish-442           [-1, 1536, 4, 4]               0
AdaptiveAvgPool2d-443           [-1, 1536, 1, 1]               0
         Dropout-444                 [-1, 1536]               0
          Linear-445                   [-1, 15]          23,055
================================================================
Total params: 10,719,287
Trainable params: 10,719,287
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.19
Forward/backward pass size (MB): 138.51
Params size (MB): 40.89
Estimated Total Size (MB): 179.59
----------------------------------------------------------------

 

output처럼 model에 대한 모든 정보를 한눈에 살펴볼 수 있다. 또한 model을 전부 학습시키고 나온 weight의 size 또한 179.59MB로 예측을 해준다.

우리가 transfer learning을 진행할 데이터의 class가 15개이므로 마지막 fc layer의 shape도 15인 것을 확인할 수 있다.

 

 

2. train.py

우리가 지금까지 해왔던 작업들을 실질적으로 학습시킬 시간이다.

 

import os
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import time
import copy

from torch.optim import lr_scheduler
from torchvision import datasets
from efficientnet_pytorch import EfficientNet
from torch.utils.tensorboard import SummaryWriter

우선 해당하는 모듈들을 모두 import 해준다.

몇 가지 모듈들을 살펴보자면 torch.optim은 optimizer가 있는 모듈이고, 여기 lr_scheduler는 validation dataset을 학습시킬 때 조정해주는 scheduler이다. torchvision의 dataset과  transform은 data를 load 해주고, transform을 진행한다.

그리고 마지막에 있는 tensorboard는 학습 중 실시간으로 acc와 loss 같은 값들을 그래프에 찍어서 보여주는 역할을 한다.

 

batch_size = 16
epochs = 30
data_dir = './data/class_15_data/'
writer = SummaryWriter('./runs/experiment1/')

기본적인 경로와 batch size, epoch을 설정하여 준다.

batch size는 본인의 gpu와 데이터 수량에 맞게 설명하면 된다.

epoch은 처음에는 50번을 진행하였으나 더 이상 성능이 올라가지를 않으므로 30으로 설정하였다.

data_dir은 본인의 dataset이 위치한 경로를 넣으면 되고, SummaryWriter를 통해 위에서 설명한 tensorboard의 로그를 찍을 파일의 경로를 설정한다.

 

data_transforms = {'train': transforms.Compose([
    transforms.Resize(200,200),
    transforms.RandomRotation(30),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.Resize(200,200)
    ])}

그다음은 data transform이다. 여기서 사용되는 메서드들은 이전 시간에 같이 review 하였다.

역시나 주의할 점은 data augmentation은 train시에만 사용하도록 하자

image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

다음은 data를 load 하는 부분이다. 이 부분 또한 이전 시간에 작성하였던 코드와 내용이 같다.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = EfficientNet.from_pretrained('efficientnet-b3', num_classes=15)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

since = time.time()

best_model_weights = copy.deepcopy(model.state_dict())
best_acc = 0.0

다음은 model에 대한 정의이다.

우선 device를 cuda가 설치되어있다면 cuda를 사용하도록 하자. cpu도 물론 있겠지만 기대를 하면 안 된다.

 

그리고 위에서 설명했던 바와 같이 model을 선언하고, optimizer는 가장 무난하게 Adam을 사용한다. 여기서 learning rate는 어차피  validation을 진행할 때 자동적으로 조절을 함으로 크게 신경 쓸 필요가 없다.

criterion은 CrossEntropy를 사용하도록 하자. scheduler는 7 stop에 1번씩 갱신을 하도록 설정을 하고, best_model_weights를 통해 잠시 뒤인 학습 과정에서 validation의 acc가 가장 높을 때에만 weight를 저장하자.

 

for epoch in range(epochs):
    print('Epoch {}/{}'.format(epoch, epochs - 1))
    print('-' * 10)

    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()

        else:
            model.eval()

        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        if phase == 'train':
            scheduler.step()

        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects.double() / dataset_sizes[phase]
        # writer.add_graph('epoch loss', epoch_loss, epoch)
        # writer.add_graph('epoch acc', epoch_acc, epoch)

        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            phase, epoch_loss, epoch_acc))

        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_weights = copy.deepcopy(model.state_dict())

학습 코드는 일반적인 pytorch에 관한 학습 코드이다. 딱히 다른 점이 없어 코드에 대해 궁금한 점이 있으면 pytorch tutorial을 참고하는 것이 좋다.

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))

torch.save(best_model_weights, './weights/best_weights_b5_class_15.pth')

그리고 마지막으로는 acc가 가장 좋은 weight를 저장해준다. 여기서 주의할 점은 해당 경로의 해당 폴더가 없으면 save를 하지 못하고 마지막에 오류가 생기면서 종료될 수 있으므로 이런 참사를 피하기 위해 미리 폴더를 생성해두자.

 

왼쪽 이미지처럼 gpu의 사용률이 올라가고, 오른쪽 화면과 같이 epoch이 뜨면 학습이 잘되어가고 있다는 신호이다.

참고로 왼쪽 명령어는 watch -d -n 0.5 nvidia-smi이다. 이를 사용하면 실시간으로 gpu 사용량을 확인할 수 있다.

학습 결과 나의 모델에서는 78%의 정확도를 가진 weight가 저장되었다.

반응형