임베디드 딥러닝

[임베디드 딥러닝] 아두이노Nano에서 손글씨 인식모델 만들기 - 2

다락공방 2024. 8. 27. 20:06

 

#include <TensorFlowLite.h>

#include "main_functions.h"
#include "mnist_model_data.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"

const int kInputTensorSize = 1 * 784;
const int kNumClass = 10;

// Globals, used for compatibility with Arduino-style sketches.
namespace
{
  const tflite::Model *model = nullptr;
  tflite::MicroInterpreter *interpreter = nullptr;
  TfLiteTensor *input = nullptr;
  constexpr int kTensorArenaSize = 100 * 1024;
  // Keep aligned to 16 bytes for CMSIS
  alignas(16) uint8_t tensor_arena[kTensorArenaSize];
} // namespace

// The name of this function is important for Arduino compatibility.
void setup()
{
  Serial.begin(9600);
  tflite::InitializeTarget();

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing, it's a very lightweight operation.
  model = tflite::GetModel(g_person_detect_model_data);
  if (model->version() != TFLITE_SCHEMA_VERSION)
  {
    MicroPrintf(
        "Model provided is schema version %d not equal "
        "to supported version %d.",
        model->version(), TFLITE_SCHEMA_VERSION);
    return;
  }

  static tflite::MicroMutableOpResolver<10> micro_op_resolver;
  micro_op_resolver.AddShape();
  micro_op_resolver.AddStridedSlice();
  micro_op_resolver.AddPack();
  micro_op_resolver.AddMaxPool2D();
  micro_op_resolver.AddFullyConnected();
  micro_op_resolver.AddAveragePool2D();
  micro_op_resolver.AddConv2D();
  micro_op_resolver.AddDepthwiseConv2D();
  micro_op_resolver.AddReshape();
  micro_op_resolver.AddSoftmax();

  // static tflite::AllOpsResolver resolver;

  // Build an interpreter to run the model with.
  // NOLINTNEXTLINE(runtime-global-variables)
  static tflite::MicroInterpreter static_interpreter(
      model, micro_op_resolver, tensor_arena, kTensorArenaSize);
  interpreter = &static_interpreter;

  // Allocate memory from the tensor_arena for the model's tensors.
  TfLiteStatus allocate_status = interpreter->AllocateTensors();
  if (allocate_status != kTfLiteOk)
  {
    MicroPrintf("AllocateTensors() failed");
    return;
  }

  // Get information about the memory area to use for the model's input.
  input = interpreter->input(0);
}

// The name of this function is important for Arduino compatibility.
void loop()
{

  float x_test[kInputTensorSize] = {...};

  for(int i =0; i < kInputTensorSize; i++){
    input->data.f[i] = x_test[i];
  }

  // Run the model on this input and make sure it succeeds.

  if (kTfLiteOk != interpreter->Invoke())
  {
    MicroPrintf("Invoke failed.");
  }

  TfLiteTensor *output = interpreter->output(0);

  int predicated_class = 0;
  float max_score = -1;
  for (int i = 0; i < kNumClass; i++)
  {
    float score = output->data.f[i];
    if (score > max_score)
    {
      predicated_class = i;
      max_score = score;
    }
  }

  Serial.print("Class: ");
  Serial.println(predicated_class);
}

 

아두이노 Nano를 위한 mnist.ino 파일은 위와 같다.  

이전 포스트에서 구축한  ML 모델을 기반으로 Getmodel함수를 사용해 모델을 불러오고, 이 모델을 기반으로 인터프리터 모델을 불러온다. 

입력 x_test의 경우 전처리 시킨 변환된 입력을 사용하였으며 아래처럼 mnist 파일을 이용해 test를 진행하고, 올바르게 추론하는지 확인하였다. 

 

 

총 소스코드는 위와같은 흐름도를 따라가며, 최종적으로 Input에 따른 추론 class를 시리얼 모니터를 통해 출력한다. 

 

 

출력은 위처럼 예측된 입력값의 class를 나타내며, 아래 프로파일러는 이후에 설명할 time_profile에 관한 파라미터 값을 나타낸다.