오늘보다 더 나은 내일의 나에게_

비전공자의 IoT 국비 교육 수강일지 Day_92 본문

비전공자의 코딩일지

비전공자의 IoT 국비 교육 수강일지 Day_92

chan_96 2022. 4. 28. 17:54
728x90

머신러닝

전이학습(Transfer Learning)

- 전이학습이란 다른 데이터 셋을 사용하여 이미 학습한 모델을 유사한 다른 데이터를 인식하는데 사용하는 기법이다.

- 이 방법은 특히 새로 훈련시킬 데이터가 충분히 확보되지 못한 경우에 학습 효율을 높여준다.

- 사전학습모델을 이용하는 방법은 특성 추출(feature extraction)방식과 미세조정(fine-tuning) 방식이 있다.


특성추출방식

 

📌딥러닝(Deep Learning) 실습

VGG16 전이학습

from tensorflow.keras.applications import VGG16

pre_trained_model = VGG16(include_top=False,
                          weights="imagenet",
                          input_shape=(224,224,3))
                          
cnn_model2 = Sequential()
cnn_model2.add(pre_trained_model)
cnn_model2.add(Flatten())
cnn_model2.add(Dense(units=128,activation='relu'))
cnn_model2.add(Dense(units=64,activation='relu'))
cnn_model2.add(Dense(units=3,activation='softmax'))

# 특성추출방식
pre_trained_model.trainable = False

cnn_model2.compile(loss='sparse_categorical_crossentropy',
                   optimizer='Adam',
                   metrics = ['accuracy'])
                   
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.2,
                                                random_state=425)
                                                
cnn_model2.fit(X_train,y_train,epochs=20)

pre = cnn_model2.predict(X_test)
from sklearn.metrics import classification_report
print(classification_report(y_test,np.argmax(pre,axis=1)))

 

# 미세조정방식
pre_trained_model = VGG16(include_top=False,
                          weights="imagenet",
                          input_shape=(224,224,3))

for layer in pre_trained_model.layers:
  #print(layer.name)
  if layer.name == "block5_conv3":
    layer.trainable = True
  else:
    layer.trainable = False
    
cnn_model3 = Sequential()
cnn_model3.add(pre_trained_model)
cnn_model3.add(Flatten())
cnn_model3.add(Dense(units=128,activation='relu'))
cnn_model3.add(Dense(units=64,activation='relu'))
cnn_model3.add(Dense(units=3,activation='softmax'))    

cnn_model3.compile(loss='sparse_categorical_crossentropy',
                   optimizer='Adam',
                   metrics = ['accuracy'])

cnn_model3.fit(X_train,y_train,epochs=20)

데이터 증강

: 모델의 과대적합을 방지하기 위한 기법

from tensorflow.keras.preprocessing.image import ImageDataGenerator

aug = ImageDataGenerator(rotation_range=90,
                         zoom_range=0.2,
                         horizontal_flip=True,
                         height_shift_range=0.2)

from tensorflow.keras.applications import MobileNetV2

pre_trained_model = MobileNetV2(include_top=False,
                          weights="imagenet",
                          input_shape=(224,224,3))
pre_trained_model.trainable = False​

from tensorflow.keras.layers import AveragePooling2D

cnn_model4 = Sequential()
cnn_model4.add(pre_trained_model)
cnn_model4.add(AveragePooling2D())
cnn_model4.add(Flatten())
cnn_model4.add(Dense(units=128,activation='relu'))
cnn_model4.add(Dense(units=64,activation='relu'))
cnn_model4.add(Dense(units=3,activation='softmax'))

cnn_model4.compile(loss="sparse_categorical_crossentropy",
                   optimizer='Adam',
                   metrics=['accuracy'])
                   
cnn_model4.fit(aug.flow(X_train,y_train),
               epochs=50)
               
pre = cnn_model4.predict(X_test)    
print(classification_report(y_test,np.argmax(pre,axis=1)))

안드로이드

실습

✨코드
더보기
MainActivity 코드
package com.example.ex0428;

import androidx.annotation.NonNull;
import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.os.Handler;
import android.os.Message;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.TextView;

public class MainActivity extends AppCompatActivity {

    TextView tvNumber,tvNumber2;
    Button btnStart,btnStart2;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        tvNumber = findViewById(R.id.tvNumber);
        btnStart = findViewById(R.id.btnStart);
        tvNumber2 = findViewById(R.id.tvNumber2);
        btnStart2 = findViewById(R.id.btnStart2);

        btnStart.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                TimerThread thread = new TimerThread(tvNumber);
                thread.start();

                btnStart.setEnabled(false);
            }
        });

        btnStart2.setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View view) {
                TimerThread thread2 = new TimerThread(tvNumber2);
                thread2.start();

                btnStart2.setEnabled(false);
            }
        });
    }// end onCreate

    //Thread(스레드)
    // : 하나의 프로세스 내에서 작업을 처리하는 작은 단위
    // : Main Thread 이외에 작업을 별도로 처리할 때 활용

    //Main Thread의 역할
    // : UI를 업데이트하는 역할

    class TimerThread extends Thread{

        TimerHandler handler = new TimerHandler();
        TextView tv;

        public TimerThread(TextView tv){
            this.tv = tv;
        }

        @Override
        public void run() {
            //실행할 로직 정의
            for(int i = 0;i < 10;i++){
                Log.d("TimmerThread","카운트:"+(i+1));

                //tvNumber.setText(String.valueOf(i+1));

                //Handler에 값을 전달 -> Message 객체
                Message msg = new Message();
                msg.arg1 = i + 1;
                msg.obj = tv; // TextView -> Object로 업캐스팅되서 저장


                //Message객체 전송
                handler.sendMessage(msg);

                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }//end run()
    }//end TimerThread class

    class TimerHandler extends Handler{
        @Override
        public void handleMessage(@NonNull Message msg) {
            //Sub Thread에서 처리한 결과를 UI업데이트 할 때
            //handlerMessage() 안에 정의

            int count = msg.arg1;
            TextView tv = (TextView) msg.obj; //Object타입으로 저장된 TextView객체를 다운캐스팅

            //tvNumber.setText(String.valueOf(count));

            //TimerThread객체생성 시 넘겨받은 Textview객체에 내용을 업데이트
            tv.setText(String.valueOf(count)); //

        }
    }//end TimerHandler class
    
}

 

두더지잡기 게임 실습


✨코드

더보기
MoreActivity
package com.example.ex0428;

import androidx.appcompat.app.AppCompatActivity;

import android.os.Bundle;
import android.widget.ImageView;
import android.widget.TextView;

public class MoreActivity extends AppCompatActivity {

    TextView tvTime, tvCount;
    ImageView[] moreArr = new ImageView[9];

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_more);

        tvTime = findViewById(R.id.tvTime);
        tvCount = findViewById(R.id.tvCount);

        //동적으로 리소스ID접근 후 ImageView 초기화
        for(int i=0;i< moreArr.length;i++){

            //img1 ~ img9까지의 리소스ID 접근
            int resId = getResources().getIdentifier("img"+(i+1),"id",getPackageName());
            moreArr[i] = findViewById(resId);
            
        }


    }//end onCreate
}​
728x90
Comments