/*
 * ai.cc
 *
 * Created on: Feb 27, 2022
 * Author: pgj
 */
  7. #include "ImC/ai.h"
  8. #include "tensorflow/lite/micro/all_ops_resolver.h"
  9. #include "tensorflow/lite/micro/examples/hello_world/constants.h"
  10. #include "tensorflow/lite/micro/examples/hello_world/output_handler.h"
  11. #include "tensorflow/lite/micro/micro_error_reporter.h"
  12. #include "tensorflow/lite/micro/micro_interpreter.h"
  13. #include "tensorflow/lite/micro/recording_micro_interpreter.h"
  14. #include "tensorflow/lite/micro/system_setup.h"
  15. #include "tensorflow/lite/schema/schema_generated.h"
  16. #include "stm32l496xx.h"
  17. #include "ImC/new_model.h"
// File-local TFLite Micro state shared by setup(), perform_inference() and
// print_output(). Anonymous namespace keeps these symbols internal to this
// translation unit.
namespace {
tflite::ErrorReporter* error_reporter = nullptr;  // set in setup() to a static reporter
const tflite::Model* model = nullptr;             // flatbuffer model handle from GetModel()
tflite::MicroInterpreter* interpreter = nullptr;  // points at a function-static in setup()
TfLiteTensor* input = nullptr;                    // input tensor (quantized, data.uint8)
TfLiteTensor* output_tensor = nullptr;            // output tensor (read as float via data.f)
}  // namespace
  25. uint8_t* setup(uint8_t* tensor_arena, int kTensorArenaSize, float *scale, int32_t* zero_point) {
  26. tflite::InitializeTarget();
  27. model = tflite::GetModel(resnet_quant_tflite);
  28. static tflite::MicroErrorReporter micro_error_reporter;
  29. error_reporter = &micro_error_reporter;
  30. static tflite::AllOpsResolver resolver;
  31. static tflite::MicroInterpreter static_interpreter( model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
  32. interpreter = &static_interpreter;
  33. TfLiteStatus allocate_status = interpreter->AllocateTensors();
  34. if (allocate_status != kTfLiteOk) {
  35. TF_LITE_REPORT_ERROR(error_reporter, "AllocateTensors() failed");
  36. return NULL;
  37. }
  38. input = interpreter->input(0);
  39. output_tensor = interpreter->output(0);
  40. *scale = interpreter->input(0)->params.scale;
  41. *zero_point = interpreter->input(0)->params.zero_point;
  42. return input->data.uint8;
  43. }
// NOTE(review): declared as an array of *pointers* (uint8_t*), yet the
// disabled code below reads dst_image[i][j][k] as a pixel byte value —
// presumably the defining translation unit declares
// uint8_t dst_image[120][160][3]; confirm, since a mismatched extern
// declaration is undefined behavior.
extern uint8_t* dst_image[120][160][3];
/* Disabled: quantizes the 120x160x3 RGB frame into the int8 input tensor
 * using the input tensor's quantization scale/zero-point.
void preprocess_input(){
printf("[AI ] input size: %d bytes \r\n", input->bytes);
uint32_t index = 0;
float temp = 0.0;
for(uint32_t i = 0; i < 120; ++i){
for(uint32_t j = 0; j < 160; ++j){
for(uint32_t k = 0; k < 3; k++){
temp = (dst_image[i][j][k] / 255.0) / interpreter->input(0)->params.scale + interpreter->input(0)->params.zero_point;
input->data.int8[index++] = (int8_t) ( temp);
}
}
}
}
*/
  59. void print_output(){
  60. uint32_t max_index = 0;
  61. float max_value = output_tensor->data.f[0];
  62. // output_tensor->bytes << int , bytes/4 << float
  63. // printf("Output size is %d\r\n", output_tensor->bytes);
  64. for(uint32_t i = 0; i < (uint32_t)(output_tensor->bytes/4.0); ++i){
  65. if(output_tensor->data.f[i] > max_value){
  66. max_value = output_tensor->data.f[i];
  67. max_index = i;
  68. }
  69. // printf("[AI ] class %d: %f\r\n", (uint8_t)i, output_tensor->data.f[i]);
  70. }
  71. if(max_index >= 10 || max_index < 0){
  72. max_index = -1;
  73. }
  74. switch (max_index) {
  75. case 1:
  76. printf("[AI ] OK Person, %f\r\n", output_tensor->data.f[1]);
  77. break;
  78. default:
  79. printf("[AI ] OK NOT person, %f\r\n", output_tensor->data.f[1]);
  80. break;
  81. }
  82. }
  83. int perform_inference() {
  84. TfLiteStatus invoke_status = interpreter->Invoke();
  85. if (invoke_status != kTfLiteOk) {
  86. TF_LITE_REPORT_ERROR(error_reporter, "[AI ] FAILED invoke\n");
  87. return -1;
  88. }
  89. return 0;
  90. }