mirror of
https://github.com/RIOT-OS/RIOT.git
synced 2024-12-29 04:50:03 +01:00
134 lines
5.2 KiB
C
134 lines
5.2 KiB
C
|
/*
|
||
|
* Copyright (C) 2019 Inria
|
||
|
*
|
||
|
* This file is subject to the terms and conditions of the GNU Lesser
|
||
|
* General Public License v2.1. See the file LICENSE in the top level
|
||
|
* directory for more details.
|
||
|
*/
|
||
|
|
||
|
/**
|
||
|
* @ingroup tests
|
||
|
* @{
|
||
|
*
|
||
|
* @file
|
||
|
* @brief Sample application for ARM CMSIS-NN package
|
||
|
*
|
||
|
* This example is adapted from ARM CMSIS CIFAR10 example to RIOT by Alexandre Abadie
|
||
|
* https://github.com/ARM-software/CMSIS_5/tree/develop/CMSIS/NN/Examples/ARM/arm_nn_examples/cifar10
|
||
|
*
|
||
|
* @author Alexandre Abadie <alexandre.abadie@inria.fr>
|
||
|
*/
|
||
|
|
||
|
#include <stdint.h>
|
||
|
#include <stdio.h>
|
||
|
#include "arm_math.h"
|
||
|
#include "parameter.h"
|
||
|
#include "weights.h"
|
||
|
|
||
|
#include "arm_nnfunctions.h"
|
||
|
|
||
|
#include "blob/input.h"
|
||
|
|
||
|
/* There are 10 different classes of objects in the CIFAR10 dataset */
|
||
|
#define CLASSES_NUMOF 10
|
||
|
|
||
|
/* include the input and weights */
|
||
|
static const q7_t conv1_wt[CONV1_IM_CH * CONV1_KER_DIM * CONV1_KER_DIM * CONV1_OUT_CH] = CONV1_WT;
|
||
|
static const q7_t conv1_bias[CONV1_OUT_CH] = CONV1_BIAS;
|
||
|
|
||
|
static const q7_t conv2_wt[CONV2_IM_CH * CONV2_KER_DIM * CONV2_KER_DIM * CONV2_OUT_CH] = CONV2_WT;
|
||
|
static const q7_t conv2_bias[CONV2_OUT_CH] = CONV2_BIAS;
|
||
|
|
||
|
static const q7_t conv3_wt[CONV3_IM_CH * CONV3_KER_DIM * CONV3_KER_DIM * CONV3_OUT_CH] = CONV3_WT;
|
||
|
static const q7_t conv3_bias[CONV3_OUT_CH] = CONV3_BIAS;
|
||
|
|
||
|
static const q7_t ip1_wt[IP1_DIM * IP1_OUT] = IP1_WT;
|
||
|
static const q7_t ip1_bias[IP1_OUT] = IP1_BIAS;
|
||
|
|
||
|
/* Here the image_data should be the raw uint8 type RGB image in [RGB, RGB, RGB ... RGB] format */
|
||
|
// static const uint8_t image_data[CONV1_IM_CH * CONV1_IM_DIM * CONV1_IM_DIM] = IMG_DATA;
|
||
|
static q7_t output_data[IP1_OUT];
|
||
|
|
||
|
/* vector buffer: max(im2col buffer,average pool buffer, fully connected buffer) */
|
||
|
static q7_t col_buffer[2 * 5 * 5 * 32 * 2];
|
||
|
static q7_t img_buffer1[32 * 32 * 10 * 4];
|
||
|
static q7_t *img_buffer2 = (q7_t *)(img_buffer1 + (32 * 32 * 32));
|
||
|
|
||
|
static const char classes[CLASSES_NUMOF][6] = {
|
||
|
"plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck" };
|
||
|
|
||
|
int main(void)
|
||
|
{
|
||
|
printf("start execution\n");
|
||
|
|
||
|
uint8_t *image_data = (uint8_t *)input;
|
||
|
|
||
|
/* input pre-processing */
|
||
|
int mean_data[3] = INPUT_MEAN_SHIFT;
|
||
|
unsigned int scale_data[3] = INPUT_RIGHT_SHIFT;
|
||
|
for (unsigned i = 0; i < input_len; i += 3) {
|
||
|
img_buffer2[i] = (q7_t)__SSAT( ((((int)image_data[i] - mean_data[0]) << 7) + (0x1 << (scale_data[0] - 1)))
|
||
|
>> scale_data[0], 8);
|
||
|
img_buffer2[i + 1] = (q7_t)__SSAT( ((((int)image_data[i + 1] - mean_data[1]) << 7) + (0x1 << (scale_data[1] - 1)))
|
||
|
>> scale_data[1], 8);
|
||
|
img_buffer2[i + 2] = (q7_t)__SSAT( ((((int)image_data[i + 2] - mean_data[2]) << 7) + (0x1 << (scale_data[2] - 1)))
|
||
|
>> scale_data[2], 8);
|
||
|
}
|
||
|
|
||
|
/* conv1 img_buffer2 -> img_buffer1 */
|
||
|
arm_convolve_HWC_q7_RGB(img_buffer2, CONV1_IM_DIM, CONV1_IM_CH, conv1_wt, CONV1_OUT_CH, CONV1_KER_DIM, CONV1_PADDING,
|
||
|
CONV1_STRIDE, conv1_bias, CONV1_BIAS_LSHIFT, CONV1_OUT_RSHIFT, img_buffer1, CONV1_OUT_DIM,
|
||
|
(q15_t *)col_buffer, NULL);
|
||
|
|
||
|
arm_relu_q7(img_buffer1, CONV1_OUT_DIM * CONV1_OUT_DIM * CONV1_OUT_CH);
|
||
|
|
||
|
/* pool1 img_buffer1 -> img_buffer2 */
|
||
|
arm_maxpool_q7_HWC(img_buffer1, CONV1_OUT_DIM, CONV1_OUT_CH, POOL1_KER_DIM,
|
||
|
POOL1_PADDING, POOL1_STRIDE, POOL1_OUT_DIM, NULL, img_buffer2);
|
||
|
|
||
|
/* conv2 img_buffer2 -> img_buffer1 */
|
||
|
arm_convolve_HWC_q7_fast(img_buffer2, CONV2_IM_DIM, CONV2_IM_CH, conv2_wt, CONV2_OUT_CH, CONV2_KER_DIM,
|
||
|
CONV2_PADDING, CONV2_STRIDE, conv2_bias, CONV2_BIAS_LSHIFT, CONV2_OUT_RSHIFT, img_buffer1,
|
||
|
CONV2_OUT_DIM, (q15_t *)col_buffer, NULL);
|
||
|
|
||
|
arm_relu_q7(img_buffer1, CONV2_OUT_DIM * CONV2_OUT_DIM * CONV2_OUT_CH);
|
||
|
|
||
|
/* pool2 img_buffer1 -> img_buffer2 */
|
||
|
arm_maxpool_q7_HWC(img_buffer1, CONV2_OUT_DIM, CONV2_OUT_CH, POOL2_KER_DIM,
|
||
|
POOL2_PADDING, POOL2_STRIDE, POOL2_OUT_DIM, col_buffer, img_buffer2);
|
||
|
|
||
|
/* conv3 img_buffer2 -> img_buffer1 */
|
||
|
arm_convolve_HWC_q7_fast(img_buffer2, CONV3_IM_DIM, CONV3_IM_CH, conv3_wt, CONV3_OUT_CH, CONV3_KER_DIM,
|
||
|
CONV3_PADDING, CONV3_STRIDE, conv3_bias, CONV3_BIAS_LSHIFT, CONV3_OUT_RSHIFT, img_buffer1,
|
||
|
CONV3_OUT_DIM, (q15_t *)col_buffer, NULL);
|
||
|
|
||
|
arm_relu_q7(img_buffer1, CONV3_OUT_DIM * CONV3_OUT_DIM * CONV3_OUT_CH);
|
||
|
|
||
|
/* pool3 img_buffer-> img_buffer2 */
|
||
|
arm_maxpool_q7_HWC(img_buffer1, CONV3_OUT_DIM, CONV3_OUT_CH, POOL3_KER_DIM,
|
||
|
POOL3_PADDING, POOL3_STRIDE, POOL3_OUT_DIM, col_buffer, img_buffer2);
|
||
|
|
||
|
arm_fully_connected_q7_opt(img_buffer2, ip1_wt, IP1_DIM, IP1_OUT, IP1_BIAS_LSHIFT, IP1_OUT_RSHIFT, ip1_bias,
|
||
|
output_data, (q15_t *)img_buffer1);
|
||
|
|
||
|
arm_softmax_q7(output_data, CLASSES_NUMOF, output_data);
|
||
|
|
||
|
int val = -1;
|
||
|
uint8_t class_idx = 0;
|
||
|
for (unsigned i = 0; i < CLASSES_NUMOF; i++) {
|
||
|
if (output_data[i] > val) {
|
||
|
val = output_data[i];
|
||
|
class_idx = i;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (val > 0) {
|
||
|
printf("Predicted class: %s\n", classes[class_idx]);
|
||
|
}
|
||
|
else {
|
||
|
puts("No match found");
|
||
|
}
|
||
|
|
||
|
return 0;
|
||
|
}
|