#include <TensorFlowLite.h>
#include "main_functions.h"
#include "mnist_model_data.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/system_setup.h"
#include "tensorflow/lite/schema/schema_generated.h"
const int kInputTensorSize = 1 * 784;
const int kNumClass = 10;
// Globals, used for compatibility with Arduino-style sketches.
namespace
{
const tflite::Model *model = nullptr;
tflite::MicroInterpreter *interpreter = nullptr;
TfLiteTensor *input = nullptr;
constexpr int kTensorArenaSize = 100 * 1024;
// Keep aligned to 16 bytes for CMSIS
alignas(16) uint8_t tensor_arena[kTensorArenaSize];
} // namespace
// The name of this function is important for Arduino compatibility.
void setup()
{
Serial.begin(9600);
tflite::InitializeTarget();
// Map the model into a usable data structure. This doesn't involve any
// copying or parsing, it's a very lightweight operation.
model = tflite::GetModel(g_person_detect_model_data);
if (model->version() != TFLITE_SCHEMA_VERSION)
{
MicroPrintf(
"Model provided is schema version %d not equal "
"to supported version %d.",
model->version(), TFLITE_SCHEMA_VERSION);
return;
}
static tflite::MicroMutableOpResolver<10> micro_op_resolver;
micro_op_resolver.AddShape();
micro_op_resolver.AddStridedSlice();
micro_op_resolver.AddPack();
micro_op_resolver.AddMaxPool2D();
micro_op_resolver.AddFullyConnected();
micro_op_resolver.AddAveragePool2D();
micro_op_resolver.AddConv2D();
micro_op_resolver.AddDepthwiseConv2D();
micro_op_resolver.AddReshape();
micro_op_resolver.AddSoftmax();
// static tflite::AllOpsResolver resolver;
// Build an interpreter to run the model with.
// NOLINTNEXTLINE(runtime-global-variables)
static tflite::MicroInterpreter static_interpreter(
model, micro_op_resolver, tensor_arena, kTensorArenaSize);
interpreter = &static_interpreter;
// Allocate memory from the tensor_arena for the model's tensors.
TfLiteStatus allocate_status = interpreter->AllocateTensors();
if (allocate_status != kTfLiteOk)
{
MicroPrintf("AllocateTensors() failed");
return;
}
// Get information about the memory area to use for the model's input.
input = interpreter->input(0);
}
// The name of this function is important for Arduino compatibility.
void loop()
{
float x_test[kInputTensorSize] = {...};
for(int i =0; i < kInputTensorSize; i++){
input->data.f[i] = x_test[i];
}
// Run the model on this input and make sure it succeeds.
if (kTfLiteOk != interpreter->Invoke())
{
MicroPrintf("Invoke failed.");
}
TfLiteTensor *output = interpreter->output(0);
int predicated_class = 0;
float max_score = -1;
for (int i = 0; i < kNumClass; i++)
{
float score = output->data.f[i];
if (score > max_score)
{
predicated_class = i;
max_score = score;
}
}
Serial.print("Class: ");
Serial.println(predicated_class);
}
아두이노 Nano를 위한 mnist.ino 파일은 위와 같다.
이전 포스트에서 구축한 ML 모델을 기반으로 Getmodel함수를 사용해 모델을 불러오고, 이 모델을 기반으로 인터프리터 모델을 불러온다.
입력 x_test의 경우 전처리 시킨 변환된 입력을 사용하였으며 아래처럼 mnist 파일을 이용해 test를 진행하고, 올바르게 추론하는지 확인하였다.
총 소스코드는 위와같은 흐름도를 따라가며, 최종적으로 Input에 따른 추론 class를 시리얼 모니터를 통해 출력한다.
출력은 위처럼 예측된 입력값의 class를 나타내며, 아래 프로파일러는 이후에 설명할 time_profile에 관한 파라미터 값을 나타낸다.
'임베디드 딥러닝' 카테고리의 다른 글
[임베디드 딥러닝] gprof 프로파일러를 이용해 딥러닝 최적화하기 - 3 (0) | 2024.08.28 |
---|---|
[임베디드 딥러닝] gprof 프로파일러를 이용해 딥러닝 최적화하기 - 2 (0) | 2024.08.28 |
[임베디드 딥러닝] gprof 프로파일러를 이용해 딥러닝 최적화하기 - 1 (0) | 2024.08.27 |
[임베디드 딥러닝] 아두이노Nano에서 손글씨 인식모델 만들기 - 1 (0) | 2024.08.16 |
[임베디드 딥러닝] Tensorflow Lite (2) | 2024.08.16 |