先日作った熱画像をリアルタイムで見れるシステムを改造し、ディープラーニングしてみました。
ジャンケンの3つのジェスチャーを、赤外線アレイモジュールからの入力データをもとに推論します。
ジャンケンの画像データを使いディープラーニングする、というアイデアは雑誌Interface1月号の記事のまねをさせていただきました。他にもいろいろと参考にさせていただきました。
なお、雑誌Interface1月号の記事ではソニーのSPRESENSEボード、および通常のカメラボードを使って行っています。
ディープラーニングのツールとしてNeural Network Console(NNC)のクラウドを使い学習させました。
学習済みモデルをNNPファイルとしてダウンロードし、NNPファイルをC言語ファイルに変換後、ESP32に組み込みました。
【2019/02/04 追記】NNP以外にNNB(Cランタイムフォーマット)でもESP32に組み込むことができることを確認しましたので、それも追記しました。
以下のステップを行いました。
最終的なファイル構成は以下の構成です。
platformio.ini // IDE設定ファイル src ─ main.cpp // メインプログラム lib ┬ MLX90640_API ┬ MLX90640_API.cpp // 以下のファイルはSparkFunにあるサンプルから取得 │ └ MLX90640_API.h ├ MLX90640_I2C_Driver ┬ MLX90640_I2C_Driver.cpp │ └ MLX90640_I2C_Driver.h ├ MainRuntime_inference ┬ MainRuntime_inference.c // 推論するプログラム、NNPから変換(NNPを使う場合) │ └ MainRuntime_inference.h └ MainRuntime_parameters┬ MainRuntime_parameters.c // 重みの配列、NNPから変換(NNPを使う場合) └ MainRuntime_parameters.h data ┬ index.html // SPIFFSに置く静的コンテンツ ├ result.nnb // (NNBを使う場合) ├ app.js ├ app.css ├ ace.js.gz // 以降のファイルはSPIFFSEditor使用時に使う ├ ext-searchbox.js.gz ├ mode-css.js.gz ├ mode-html.js.gz ├ mode-javascript.js.gz └ worker-html.js.gz work ┬ image_data_generator.py // WebSocketクライアントでデータ収集するプログラム ├ check.htm // 作成したデータを確認するHTML ├ image_data_generator.py // データを水増し(Data Augmentation)するプログラム └ make_csv.py // データ変換しCSVファイル作成するプログラム
データ収集のやり方は、雑誌Interface1月号と同様に連写でジャンケンの同じジェスチャーをポーズを変えて撮影し、行いました。
データファイルの書き出しは、Pythonプログラムで作成したWebSocketクライアントで行いました。ブラウザでジェスチャーを確認しつつ、もう一つのWebSocketクライアントで接続してデータの収集・書き出しをします。
実行時の引数で、結果のラベルを指定します。
「グー」を0、「チョキ」を1、「パー」を2とします。
以下を実行し、その間カメラの前で「グー」のジェスチャーを、ポーズを変えてデータ収集します。
各ジェスチャー毎に500枚、計1500枚撮影しました。
0.5秒に1枚撮れるので、500×0.5=250秒 の3セットで10分ちょっとかかりました。
データは、あとでブラウザで確認するためJSONデータで出力します。
$ python collect_data.py 0
・collect_data.py
import websocket
import struct
from datetime import datetime
import json
import sys
args = sys.argv
if len(args) >= 2:
y = int(args[1])
else:
y = 9
path = "./temperature_%s.json" % datetime.now().strftime("%Y%m%d%H%M%S")
counter = 0
def on_message(ws, message):
global counter
tmps = [chunk[0] for chunk in struct.iter_unpack('<f', message)]
data= [datetime.now().timestamp(), y, tmps]
if counter%50 == 0:
print('counter: %s' % counter)
counter += 1
with open(path, mode='a') as f:
f.write(json.dumps(data) + ",\n")
def on_error(ws, error):
print(error)
def on_close(ws):
print("### closed ###")
def on_open(ws):
print("### open ###")
if __name__ == "__main__":
# websocket.enableTrace(True)
ws = websocket.WebSocketApp("ws://esp32.local/ws",
on_message = on_message,
on_error = on_error,
on_close = on_close)
ws.on_open = on_open
ws.run_forever()
上記で作成したデータファイル3つを結合します。その際、ファイルの頭に”[“を挿入、ファイルの最後の”,”を”]”に置換します。以下のようなかんじのコマンドを実行します。
$ cat temperature_201901* | sed -e '$s/.$/]/' -e '1s/^/[/' > temperature_201901.json
収集したデータをブラウザで確認します。
以下のようなHTMLファイルを作成し、ダブルクリックしてブラウザに表示します。前回作成したapp.js、app.cssを利用しています。
「ファイルを選択」で、作成したデータファイルを選択します。
画面上部のスライダーまたは「←」「→」キーで戻る・進むができます。
・check.htm
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="Content-type" content="text/html; charset=utf-8">
<meta name="viewport" content="width=350,initial-scale=0.5">
<title>赤外線アレイカメラ MLX90640</title>
<link rel="stylesheet" type="text/css" href="../data/app.css" >
<link rel="stylesheet" type="text/css" href="https://code.jquery.com/ui/1.12.1/themes/base/jquery-ui.css" >
<style>
a#download{
position: relative;
display: inline-block;
font-weight: bold;
padding: 0.25em 0.5em;
text-decoration: none;
color: #00BCD4;
background: #ECECEC;
transition: .4s;
}
a#download:hover{
background: #00bcd4;
color: white;
}
#slider {
width: 640px;
margin: 10px 0;
}
#pos {
display: inline-block;
}
#labeling {
height: 60px;
}
#dt, #yw {
margin: 0 10px;
}
#yw {
display: inline-block;
}
</style>
</head>
<body id="body" onload="onBodyLoad()">
<div id="container">
<div id="labeling">
<form method="post" enctype="multipart/form-data">
<input type="file" id="file" accept="application/json">
<!-- <span><a id="download" href="#" download="test.txt" onclick="tmps.handleDownload()">ダウンロード</a></span> -->
<div id="pos">
<span id="slider-pos"></span>
<span id="dt"></span>
<span id="yw" style="display: none;">
<select id="y">
<option value="0">グー</option>
<option value="1">チョキ</option>
<option value="2">パー</option>
<!-- <option value="3">3</option>
<option value="4">4</option>
<option value="5">5</option>
<option value="6">6</option>
<option value="7">7</option>
<option value="8">8</option>
<option value="9">9</option> -->
</select>
</span>
</div>
</form>
<div id='slider'></div>
</div>
<canvas id="canvas" width="32" height="24"></canvas>
<div id="scale"></div>
<div id="scale-divisions">
<div id="min-tmp-division"><span id="min-down" class="divisionBtn">◀</span><span id="min-tmp"></span><span id="min-up" class="divisionBtn">▶</span></div>
<div id="max-tmp-division"><span id="max-down" class="divisionBtn">◀</span><span id="max-tmp"></span><span id="max-up" class="divisionBtn">▶</span></div>
</div>
<div id="messages"></div>
</div>
<script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
<script src="https://code.jquery.com/ui/1.12.1/jquery-ui.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/moment.js/2.23.0/moment.min.js"></script>
<script src="../data/app.js"></script>
<script>
const tmps = {
obj:null,
index: 0,
orgFileName: null,
handleDownload: function () {
var content = JSON.stringify(this.obj);
var blob = new Blob([ content ], { "type" : "text/plain" });
var d = ge("download");
d.download = this.orgFileName;
d.href = window.URL.createObjectURL(blob);
},
draw: function() {
const data = this.obj[this.index];
const timestamp = data[0];
const y = data[1];
const temperature = data[2];
this.setDt(timestamp);
cv.draw(temperature);
$("#y").val(y);
},
setDt: function(timestamp) {
ge("dt").innerHTML = moment.unix(timestamp).format();
},
setY: function (value) {
value = parseInt(value);
if (value != this.obj[this.index][1]) {
console.log(`y is ${value}`);
this.obj[this.index][1] = value;
}
}
};
const slider = {
slider: null,
setUp: (length)=> {
this.slider = $("#slider").slider({
value:0,
min:0,
max:length - 1,
step:1,
change: function (e, ui) {
tmps.index = ui.value;
tmps.draw();
}
});
},
setValue: (value) => {
this.slider.slider("value", value);
}
};
onBodyLoad = function(){
const reader = new FileReader();
//HTMLを初期化し、新たなファイルを文字列として読込む
file.addEventListener('change', function(e) {
tmps.orgFileName = e.target.files[0].name;
reader.readAsText(e.target.files[0]);
});
//ファイルをオブジェクト化して表示
reader.onload = function(e) {
try {
tmps.obj = JSON.parse(e.target.result);
} catch (err1) {
console.log(err1);
try {
tmps.obj = JSON.parse('[' + e.target.result.slice( 0, -2 ) + ']');
} catch (err2) {
alert("jsonに誤りがあります。");
console.log(e2);
return;
}
}
console.log(tmps.obj);
slider.setUp(tmps.obj.length);
$("#yw").show();
tmps.draw();
};
cv.createScale();
cv.createCanvas();
}
$(function() {
$("#y").change(function() {
const value = $(this).val();
tmps.setY(value);
});
$(document).on('keydown', function(e) {
if (!tmps.obj) return;
console.log(`pressed keyCode:${e.keyCode}`);
if (48 <= e.keyCode && e.keyCode <= 57) {
const value = e.keyCode - 48;
$("#y").val(value);
tmps.setY(value);
}
switch( e.keyCode ) {
case 37:
if (tmps.index > 0) {
console.log("戻る");
tmps.index--;
tmps.draw();
slider.setValue(tmps.index);
}
break;
case 39:
// 進む
if (tmps.index < tmps.obj.length) {
console.log("進む");
tmps.index++;
tmps.draw();
slider.setValue(tmps.index);
}
break;
}
});
});
</script>
</body>
</html>
確認画面
KerasのImageDataGeneratorを使って、データの水増しを行います。画像を回転や水平/垂直方向に移動・ズーム等してデータを10倍にします。
このとき、ついでに以下のようにデータを加工しています。「28度以下は0」というのは、環境によってはうまくいかないかもしれないです。
これも後で確認するためjsonファイルで出力しときます。
・image_data_generator.py
import json
import numpy as np
import scipy.stats
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime
data_file = '<入力ファイルを設定する>'
with open(data_file) as f:
df = json.load(f)
X = [data[2] for data in df]
X = np.array(X)
# 28度以下を0、35度以上を1とし、0以上1以下の値に変換
max = 35
min = 28
X = np.where(X >max, 1, (X - min)/(max - min))
X = np.where(X < 0, 0, X)
y = [data[1] for data in df]
y = np.array(y)
y = y.reshape(y.shape[0], 1)
datagen = ImageDataGenerator(
rotation_range=90, # 整数.画像をランダムに回転する回転範囲
width_shift_range=0.1, # 浮動小数点数(横幅に対する割合).ランダムに水平シフトする範囲
height_shift_range=0.1, # 浮動小数点数(縦幅に対する割合).ランダムに垂直シフトする範囲
fill_mode='constant', # 入力画像の境界周りを埋めるモード
cval=0.0, # constantで0.0で埋める
horizontal_flip=True, # 真理値.水平方向に入力をランダムに反転します
vertical_flip=True, # 真理値.垂直方向に入力をランダムに反転します
zoom_range=0.3 # 浮動小数点数または[lower,upper].ランダムにズームする範囲.浮動小数点数
) # randomly flip images
# ImageDataGeneratorの入力値は(samples, channels, height, width)なのでそれに合わせる
X = X.reshape(X.shape[0], 24, 32, 1)
# 空のnumpyを作成し、ループでappendする
new_X = np.empty((0, 24, 32, 1), float)
new_Y = np.empty((0, 1), int)
counter = 0
for x_batch, y_batch in datagen.flow(X, y, batch_size=32):
new_X = np.append(new_X, x_batch, axis=0)
new_Y = np.append(new_Y, y_batch, axis=0)
counter += 1
if counter >= (X.shape[0] * 10 / 32): # 10倍に水増し
break
# Xの値を0=<X<=1 から28<=X<=35に戻す
new_X = new_X*(max - min) + min
new_X = new_X.reshape(new_X.shape[0], 32*24)
data = []
for (each_x, each_y) in zip(new_X, new_Y):
data.append([
datetime.now().timestamp(),
each_y[0],
each_x
])
class MyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
else:
return super(MyEncoder, self).default(obj)
path = "image_data_generator_%s.json" % datetime.now().strftime("%Y%m%d%H%M%S")
with open(path, mode='w') as f:
f.write(json.dumps(data, cls = MyEncoder))
ImageDataGeneratorで生成した画像の例
上記で作成したJSONファイルからNNCにアップロードするCSVファイルを作成します。
・make_csv.py
import json
import numpy as np
import scipy.stats
from sklearn.model_selection import train_test_split
data_file = '<入力ファイルを設定する>'
with open(data_file) as f:
df = json.load(f)
X = [data[2] for data in df]
X = np.array(X)
# 28度以下を0、35度以上を1とし、0以上1以下の値に変換
max = 35
min = 28
X = np.where(X >max, 1, (X - min)/(max - min))
X = np.where(X < 0, 0, X)
y = [data[1] for data in df]
y = np.array(y)
y = y.reshape(y.shape[0], 1)
index = ['x__{0}'.format(i) for i in range(0, 768)]
index.append('y')
header = ','.join(index)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
train_data = np.hstack((X_train, y_train))
test_data = np.hstack((X_test, y_test))
np.savetxt('<trainの出力ファイルを設定する>', train_data, fmt='%.3f', delimiter=',', header=header, comments='')
np.savetxt('<testの出力ファイルを設定する>', test_data, fmt='%.3f', delimiter=',', header=header, comments='')
NNC専用のアップローダーでtrainとtestデータのCSVをアップロードします。
サンプルにあったimage_recognition.MNIST.LeNetのモデルをほぼそのまま使いました。
インプットは768個(32×24)の1次元、出力は3つです。
はまったのが、ESP32のメモリのサイズです。
最初にサンプルにあったパラメータのままで作成したら、生成されたESP32に実装する重みパラメータが書かれたファイル(MainRuntime_parameters.c)が2MB近くありました。そのためESP32のコンパイルのリンク時に以下のエラーが発生しました。
region `dram0_0_seg' overflowed by 225088 bytes
Convolution(畳み込み)のフィルター数やフィルターサイズを調整して、最終的にMainRuntime_parameters.cのサイズを30KBにしたら、上記エラーがでなくなりました。
学習した結果の精度(Acuacy)ですが、水増し(Data Augmentation)しなかった場合は、0.9865ですが、水増しした場合、0.8768となりました。データを加工しすぎかもしれません。
■NNPを使った実装の場合
NNPファイルをダウンロードします。
以下のコマンドで展開します。
$ nnabla_cli convert -O CSRC -b 1 result_evaluate.nnp <出力ディレクトリ>
以下の5つのファイルができます。
・GNUmakefile
・MainRuntime_example.c // サンプルプログラム
・MainRuntime_inference.c // 推論するプログラム
・MainRuntime_inference.h
・MainRuntime_parameters.c // 重みパラメータの配列
・MainRuntime_parameters.h
MainRuntime_inference.c、MainRuntime_inference.h、MainRuntime_parameters.c、ainRuntime_parameters.hをlibディレクトリ以下に配置します。
これらはnnabla-c-runtimeライブラリに依存しますので、platformio.iniに以下のように追記します。
lib_deps = https://github.com/me-no-dev/ESPAsyncWebServer.git, https://github.com/sony/nnabla-c-runtime.git
推論を実装したプログラム(NNP版)です。
・main.c
#include <Wire.h>
#include "MLX90640_API.h"
#include "MLX90640_I2C_Driver.h"
#include <WiFi.h>
#include <ESPmDNS.h>
#include <ArduinoOTA.h>
// #include <FS.h>
#include <ESPAsyncWebServer.h>
#include <SPIFFS.h>
#include <SPIFFSEditor.h>
#include "MainRuntime_inference.h"
#include "MainRuntime_parameters.h"
// WIFI設定
const char* ssid = "*******";
const char* password = "******";
// mDNS
const char *hostName = "esp32";
// SPIFFSEditorの認証
const char *http_username = "admin";
const char *http_password = "admin";
// MLX90640
const byte MLX90640_address = 0x33; //Default 7-bit unshifted address of the MLX90640
#define TA_SHIFT 8 //Default shift for MLX90640 in open air
float mlx90640To[768];
paramsMLX90640 mlx90640;
// SKETCH BEGIN
AsyncWebServer server(80);
AsyncWebSocket ws("/ws");
void *_context = NULL;
int mode = 0;
// template <class T, size_t N>
// void standard(T (&data)[N])
// {
// float ave = std::accumulate(std::begin(data), std::end(data), 0.0) / N;
// float sd = sqrt(std::inner_product(std::begin(data), std::end(data), std::begin(data), 0.0) / N - ave * ave);
// std::for_each(std::begin(data), std::end(data), [&ave, &sd](float &temperature) {
// temperature = (temperature - ave) / sd;
// });
// }
// 28度以下を0、35度以上を1とし、0以上1以下の値に変換
template <class T, size_t N>
int normalize(T (&data)[N], float min, float max)
{
int cnt = 0;
std::for_each(std::begin(data), std::end(data), [&min, &max, &cnt](float &temperature) {
if (min > temperature)
{
temperature = 0;
cnt++;
} else if (max < temperature) {
temperature = 1;
} else {
temperature = (temperature - min)/(max - min);
}
});
return cnt;
}
int predict(float *data, float r)
{
memcpy(nnablart_mainruntime_input_buffer(_context, 0), data, 768*4);
nnablart_mainruntime_inference(_context);
float *probs = nnablart_mainruntime_output_buffer(_context, 0);
Serial.printf("predict %.3f, %.3f, %.3f\n", probs[0], probs[1], probs[2]);
for (int cl = 0; cl < NNABLART_MAINRUNTIME_OUTPUT0_SIZE; cl++)
{
if (probs[cl] > r)
{
return cl;
}
}
return 9;
}
void onWsEvent(AsyncWebSocket *server, AsyncWebSocketClient *client, AwsEventType type, void *arg, uint8_t *data, size_t len)
{
if (type == WS_EVT_CONNECT)
{
Serial.printf("ws[%s][%u] connect\n", server->url(), client->id());
}
else if (type == WS_EVT_DISCONNECT)
{
Serial.printf("ws[%s][%u] disconnect\n", server->url(), client->id());
}
else if (type == WS_EVT_ERROR)
{
Serial.printf("ws[%s][%u] error(%u): %s\n", server->url(), client->id(), *((uint16_t *)arg), (char *)data);
}
}
//Returns true if the MLX90640 is detected on the I2C bus
boolean isConnected()
{
Wire.beginTransmission((uint8_t)MLX90640_address);
if (Wire.endTransmission() != 0)
return (false); //Sensor did not ACK
return (true);
}
void setUpMLX90640()
{
if (isConnected() == false)
{
Serial.println("MLX90640 not detected at default I2C address. Please check wiring. Freezing.");
while (1)
;
}
//Get device parameters - We only have to do this once
int status;
uint16_t eeMLX90640[832];
status = MLX90640_DumpEE(MLX90640_address, eeMLX90640);
if (status != 0)
Serial.println("Failed to load system parameters");
status = MLX90640_ExtractParameters(eeMLX90640, &mlx90640);
if (status != 0)
Serial.println("Parameter extraction failed");
Serial.println(status);
//Once params are extracted, we can release eeMLX90640 array
//MLX90640_SetRefreshRate(MLX90640_address, 0x02); //Set rate to 2Hz
MLX90640_SetRefreshRate(MLX90640_address, 0x03); //Set rate to 4Hz
//MLX90640_SetRefreshRate(MLX90640_address, 0x07); //Set rate to 64Hz
}
void setUpOTA()
{
ArduinoOTA.onStart([]() { Serial.println("Update Start"); });
ArduinoOTA.onEnd([]() { Serial.println("Update End"); });
ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
Serial.printf("Progress: %u%%\r", (progress / (total / 100)));
});
ArduinoOTA.onError([](ota_error_t error) {
Serial.println("OTA ERROR");
});
ArduinoOTA.setHostname(hostName);
ArduinoOTA.begin();
}
void setup()
{
Wire.begin();
Serial.begin(115200);
Serial.setDebugOutput(true);
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED)
{
delay(1000);
Serial.println("Connecting to WiFi..");
}
Serial.println("WiFi connected!");
//OTA
setUpOTA();
// mDNS
if (!MDNS.begin(hostName))
{
Serial.println("Error setting up MDNS responder!");
while (1)
{
delay(1000);
}
}
SPIFFS.begin(true);
ws.onEvent(onWsEvent);
server.addHandler(&ws);
// SPIFFSにあるファイルをブラウザで/editから編集できる
server.addHandler(new SPIFFSEditor(SPIFFS, http_username, http_password));
// SPIFFS
server.serveStatic("/", SPIFFS, "/").setDefaultFile("index.htm");
// predictモード開始
server.on("/start", HTTP_GET, [](AsyncWebServerRequest *request) {
_context = nnablart_mainruntime_allocate_context(MainRuntime_parameters);
mode = 1;
request->send(200, "text/plain", String("ok"));
});
// predictモード終了
server.on("/stop", HTTP_GET, [](AsyncWebServerRequest *request) {
mode = 0;
nnablart_mainruntime_free_context(_context);
request->send(200, "text/plain", String("ok"));
});
server.onNotFound([](AsyncWebServerRequest *request) {
Serial.printf("NOT_FOUND: ");
request->send(404);
});
server.begin();
// MLX90640の初期設定
setUpMLX90640();
}
void loop()
{
ArduinoOTA.handle();
// WebSocket接続してない時は何もしない
if (ws.count() <= 0)
{
return;
}
long startTime = millis();
for (byte x = 0; x < 2; x++)
{
uint16_t mlx90640Frame[834];
MLX90640_GetFrameData(MLX90640_address, mlx90640Frame);
// float vdd = MLX90640_GetVdd(mlx90640Frame, &mlx90640);
float Ta = MLX90640_GetTa(mlx90640Frame, &mlx90640);
float tr = Ta - TA_SHIFT; //Reflected temperature based on the sensor ambient temperature
float emissivity = 0.95;
MLX90640_CalculateTo(mlx90640Frame, &mlx90640, emissivity, tr, mlx90640To);
}
long calculatedTime = millis();
AsyncWebSocketMessageBuffer *buffer = ws.makeBuffer((uint8_t *)&mlx90640To, sizeof(mlx90640To));
ws.binaryAll(buffer); // バイナリー(uint8_tの配列)で全クライアントに送信
int top_class = 9;
if (mode == 1)
{
// 28度以下を0、35度以上を1とし、0以上1以下の値に変換
int cnt = normalize(mlx90640To, 28, 35);
Serial.printf("predict mode: count of below 28C: %d\n", cnt);
// 28度以下が768ドット中の700ドット以上の場合は、predictしない
if (cnt < 700)
{
// 精度が0.5以上の場合のみ、結果を返す
top_class = predict(mlx90640To, 0.5);
}
ws.textAll("result:" + String(top_class));
}
long finishedTime = millis();
Serial.printf("calculated secs:%.2f, finished secs:%.2f, top_class: %d\n", (float)(calculatedTime - startTime) / 1000, (float)(finishedTime - startTime) / 1000, top_class);
}
■NNBを使った実装の場合
NNBファイルをダウンロードします。
SPIFFSでファイルを読み込めるようにするため、dataディレクトリ以下にNNBファイルを置きます。
nnabla-c-runtimeライブラリに依存しますので、platformio.iniに以下のように追記します。
lib_deps = https://github.com/me-no-dev/ESPAsyncWebServer.git, https://github.com/sony/nnabla-c-runtime.git
推論を実装したプログラム(NNB版)です。
NNBファイルをcharの配列として読み込んで、それをnn_network_t構造体にキャストしています。
NNBファイルのサイズは98KBだと大丈夫でしたが、136KBだと読み込み時にエラーが発生しました。使えるヒープサイズは150KB以上あるのですが。連続したヒープメモリが必要だからかもしれません。
#include <Wire.h>
#include "MLX90640_API.h"
#include "MLX90640_I2C_Driver.h"
#include <WiFi.h>
#include <ESPmDNS.h>
#include <ArduinoOTA.h>
#include <FS.h>
#include <ESPAsyncWebServer.h>
#include <SPIFFS.h>
#include <SPIFFSEditor.h>
// #include "MainRuntime_inference.h"
// #include "MainRuntime_parameters.h"
#include <nnablart/network.h>
#include <nnablart/runtime.h>
// WIFI設定
const char *ssid = "******";
const char *password = "******";
// mDNS
const char *hostName = "esp32";
// SPIFFSEditorの認証
const char *http_username = "admin";
const char *http_password = "admin";
// MLX90640
const byte MLX90640_address = 0x33; //Default 7-bit unshifted address of the MLX90640
#define TA_SHIFT 8 //Default shift for MLX90640 in open air
float mlx90640To[768];
paramsMLX90640 mlx90640;
// SKETCH BEGIN
AsyncWebServer server(80);
AsyncWebSocket ws("/ws");
void *_context = NULL;
int mode = 0;
char *nnb;
// template <class T, size_t N>
// void standard(T (&data)[N])
// {
// float ave = std::accumulate(std::begin(data), std::end(data), 0.0) / N;
// float sd = sqrt(std::inner_product(std::begin(data), std::end(data), std::begin(data), 0.0) / N - ave * ave);
// std::for_each(std::begin(data), std::end(data), [&ave, &sd](float &temperature) {
// temperature = (temperature - ave) / sd;
// });
// }
// 28度以下を0、35度以上を1とし、0以上1以下の値に変換
template <class T, size_t N>
int normalize(T (&data)[N], float min, float max)
{
int cnt = 0;
std::for_each(std::begin(data), std::end(data), [&min, &max, &cnt](float &temperature) {
if (min > temperature)
{
temperature = 0;
cnt++;
} else if (max < temperature) {
temperature = 1;
} else {
temperature = (temperature - min)/(max - min);
}
});
return cnt;
}
int predict(float *data, float r)
{
memcpy(rt_input_buffer(_context, 0), data, 768*4);
rt_forward(_context);
// Serial.printf("num:%d, ", rt_num_of_output(_context));
// Serial.printf("size:%d\n", rt_output_size(_context, 0));
float *probs = (float *)rt_output_buffer(_context, 0);
Serial.printf("predict %.3f, %.3f, %.3f\n", probs[0], probs[1], probs[2]);
for (int cl = 0; cl < 3; cl++)
{
if (probs[cl] > r)
{
return cl;
}
}
return 9;
}
void onWsEvent(AsyncWebSocket *server, AsyncWebSocketClient *client, AwsEventType type, void *arg, uint8_t *data, size_t len)
{
if (type == WS_EVT_CONNECT)
{
Serial.printf("ws[%s][%u] connect\n", server->url(), client->id());
}
else if (type == WS_EVT_DISCONNECT)
{
Serial.printf("ws[%s][%u] disconnect\n", server->url(), client->id());
}
else if (type == WS_EVT_ERROR)
{
Serial.printf("ws[%s][%u] error(%u): %s\n", server->url(), client->id(), *((uint16_t *)arg), (char *)data);
}
}
//Returns true if the MLX90640 is detected on the I2C bus
boolean isConnected()
{
Wire.beginTransmission((uint8_t)MLX90640_address);
if (Wire.endTransmission() != 0)
return (false); //Sensor did not ACK
return (true);
}
void setUpMLX90640()
{
if (isConnected() == false)
{
Serial.println("MLX90640 not detected at default I2C address. Please check wiring. Freezing.");
while (1)
;
}
//Get device parameters - We only have to do this once
int status;
uint16_t eeMLX90640[832];
status = MLX90640_DumpEE(MLX90640_address, eeMLX90640);
if (status != 0)
Serial.println("Failed to load system parameters");
status = MLX90640_ExtractParameters(eeMLX90640, &mlx90640);
if (status != 0)
Serial.println("Parameter extraction failed");
Serial.println(status);
//Once params are extracted, we can release eeMLX90640 array
//MLX90640_SetRefreshRate(MLX90640_address, 0x02); //Set rate to 2Hz
MLX90640_SetRefreshRate(MLX90640_address, 0x03); //Set rate to 4Hz
//MLX90640_SetRefreshRate(MLX90640_address, 0x07); //Set rate to 64Hz
}
void setUpOTA()
{
ArduinoOTA.onStart([]() { Serial.println("Update Start"); });
ArduinoOTA.onEnd([]() { Serial.println("Update End"); });
ArduinoOTA.onProgress([](unsigned int progress, unsigned int total) {
Serial.printf("Progress: %u%%\r", (progress / (total / 100)));
});
ArduinoOTA.onError([](ota_error_t error) {
Serial.println("OTA ERROR");
});
ArduinoOTA.setHostname(hostName);
ArduinoOTA.begin();
}
void setup()
{
Wire.begin();
Serial.begin(115200);
Serial.setDebugOutput(true);
WiFi.begin(ssid, password);
while (WiFi.status() != WL_CONNECTED)
{
delay(1000);
Serial.println("Connecting to WiFi..");
}
Serial.println("WiFi connected!");
//OTA
setUpOTA();
// mDNS
if (!MDNS.begin(hostName))
{
Serial.println("Error setting up MDNS responder!");
while (1)
{
delay(1000);
}
}
SPIFFS.begin(true);
ws.onEvent(onWsEvent);
server.addHandler(&ws);
// SPIFFSにあるファイルをブラウザで/editから編集できる
server.addHandler(new SPIFFSEditor(SPIFFS, http_username, http_password));
// SPIFFS
server.serveStatic("/", SPIFFS, "/").setDefaultFile("index.htm");
// predictモード開始
server.on("/start", HTTP_GET, [](AsyncWebServerRequest *request) {
// _context = nnablart_mainruntime_allocate_context(MainRuntime_parameters);
/* READ FILE */
File fp = SPIFFS.open("/result.nnb", FILE_READ); // 読み取り
Serial.printf("file size:%d\n", fp.size());
rt_return_value_t ret = rt_allocate_context(&_context);
nnb = (char *)malloc(fp.size()); // malloc使わないとバッファーオーバーフローエラーが発生した
fp.readBytes(nnb, fp.size());
fp.close();
nn_network_t *net = (nn_network_t *)nnb;
ret = rt_initialize_context(_context, net);
Serial.println(ret);
mode = 1;
request->send(200, "text/plain", String("ok"));
});
// predictモード終了
server.on("/stop", HTTP_GET, [](AsyncWebServerRequest *request) {
mode = 0;
rt_free_context(&_context);
free(nnb);
request->send(200, "text/plain", String("ok"));
});
server.onNotFound([](AsyncWebServerRequest *request) {
Serial.printf("NOT_FOUND: ");
request->send(404);
});
server.begin();
// MLX90640の初期設定
setUpMLX90640();
}
void loop()
{
ArduinoOTA.handle();
// WebSocket接続してない時は何もしない
if (ws.count() <= 0)
{
return;
}
long startTime = millis();
for (byte x = 0; x < 2; x++)
{
uint16_t mlx90640Frame[834];
MLX90640_GetFrameData(MLX90640_address, mlx90640Frame);
// float vdd = MLX90640_GetVdd(mlx90640Frame, &mlx90640);
float Ta = MLX90640_GetTa(mlx90640Frame, &mlx90640);
float tr = Ta - TA_SHIFT; //Reflected temperature based on the sensor ambient temperature
float emissivity = 0.95;
MLX90640_CalculateTo(mlx90640Frame, &mlx90640, emissivity, tr, mlx90640To);
}
long calculatedTime = millis();
AsyncWebSocketMessageBuffer *buffer = ws.makeBuffer((uint8_t *)&mlx90640To, sizeof(mlx90640To));
ws.binaryAll(buffer); // バイナリー(uint8_tの配列)で全クライアントに送信
int top_class = 9;
if (mode == 1)
{
// 28度以下を0、35度以上を1とし、0以上1以下の値に変換
int cnt = normalize(mlx90640To, 28, 35);
Serial.printf("predict mode: count of below 28C: %d\n", cnt);
// 28度以下が768ドット中の700ドット以上の場合は、predictしない
if (cnt < 700)
{
// 精度が0.5以上の場合のみ、結果を返す
top_class = predict(mlx90640To, 0.5);
}
ws.textAll("result:" + String(top_class));
}
long finishedTime = millis();
Serial.printf("calculated secs:%.2f, finished secs:%.2f, top_class: %d\n", (float)(calculatedTime - startTime) / 1000, (float)(finishedTime - startTime) / 1000, top_class);
}
推論の結果をブラウザで表示するためHTML,js、cssを以下のように修正しました。
・index.htm
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="Content-type" content="text/html; charset=utf-8">
<meta name="viewport" content="width=350,initial-scale=0.5">
<title>赤外線アレイカメラ MLX90640</title>
<link rel="stylesheet" type="text/css" href="app.css" >
</head>
<body id="body" onload="onBodyLoad()">
<div id="container">
<div id="header"><span><a id="predict" href="#">判定する</a></span>
</div>
<canvas id="canvas" width="32" height="24"></canvas>
<div id="scale"></div>
<div id="scale-divisions">
<div id="min-tmp-division"><span id="min-down" class="divisionBtn">◀</span><span id="min-tmp"></span><span id="min-up" class="divisionBtn">▶</span></div>
<div id="max-tmp-division"><span id="max-down" class="divisionBtn">◀</span><span id="max-tmp"></span><span id="max-up" class="divisionBtn">▶</span></div>
</div>
<div id="messages"></div>
<div id="result"></div>
</div>
<script src="app.js"></script>
</body>
</html>
・app.js
const ge = (s) => { return document.getElementById(s); }
const ce = (s) => { return document.createElement(s); }
const gc = (s) => { return document.getElementsByClassName(s); }
const addMessage = (m) => {
// メッセージ表示
// console.log(m);
const msg = ce("div");
msg.innerText = m;
ge("messages").appendChild(msg);
ge("messages").append
}
skt = {
ws: null,
start: function () {
// WebSocketを開始
ws = ws = new WebSocket('ws://' + document.location.host + '/ws', ['arduino']);
ws.binaryType = "arraybuffer";
ws.onopen = (e) => {
addMessage("Connected");
};
ws.onclose = (e) => {
addMessage("Disconnected");
};
ws.onerror = (e) => {
console.log("ws error", e);
addMessage("Error");
};
ws.onmessage = (e) => {
if (e.data instanceof ArrayBuffer) {
// バイナリーデータの場合
this.parseTemparatures(e.data);
} else {
console.log(predict.mode);
const result = ge("result");
const resultStrings = { "0": "グー", "1": "チョキ", "2": "パー" };
const m = e.data.match(/^result:([0-2])/);
if (m) {
result.innerHTML = resultStrings[m[1]];
} else {
result.innerHTML = "";
}
}
};
},
parseTemparatures: (data) => {
// Uint8Arrayにセットされた4バイトのfloatの配列をFloat32Arrayの型付き配列にセットし、描画します。
const dv = new DataView(data);
const byteSize = 4;
const tmps = new Float32Array(data.byteLength / byteSize);
for (let i = 0; i < tmps.length; i++) {
tmps[i] = dv.getFloat32(i * byteSize, true);
}
cv.draw(tmps);
}
};
const HSVtoRGB = (h, s, v) => {
// HSVからRGBに変換 パラメータh,s,vは0以上1以下
let r, g, b, i, f, p, q, t;
i = Math.floor(h * 6);
f = h * 6 - i;
p = v * (1 - s);
q = v * (1 - f * s);
t = v * (1 - (1 - f) * s);
switch (i % 6) {
case 0: r = v, g = t, b = p; break;
case 1: r = q, g = v, b = p; break;
case 2: r = p, g = v, b = t; break;
case 3: r = p, g = q, b = v; break;
case 4: r = t, g = p, b = v; break;
case 5: r = v, g = p, b = q; break;
}
return {
r: Math.round(r * 255),
g: Math.round(g * 255),
b: Math.round(b * 255)
};
}
const rgb = {
min: 20, // 表示する最低温度
max: 35, // 表示する最高温度
get: function (tmp) {
// 温度から色(RGB)を取得
let rate = 1 - (tmp - this.min) / (this.max - this.min);
if (rate < 0) {
rate = 0;
} else if (rate > 1) {
rate = 1;
}
// const h = 0.7*rate;
const h = (Math.tanh(rate * 2 - 1.5) + 1) / 2 - 0.04; // 適当
return HSVtoRGB(h, 1, 1);
}
};
const cv = {
canvas: null,
content: null,
imageData: null,
createCanvas: function () {
// Canvasを作成
this.canvas = ge('canvas');
this.context = this.canvas.getContext('2d');
this.imageData = this.context.createImageData(32, 24);
},
createScale: function () {
// スケールを作成
const scale = ge('scale');
let color, t, span;
for (let i = 0; i < 100; i++) {
t = i * (rgb.max - rgb.min) / 100 + rgb.min;
span = ce('span');
color = rgb.get(t);
span.style.backgroundColor = span.style.color = 'rgb(' + color.r + ',' + color.g + ',' + color.b + ')';
scale.appendChild(span);
}
this.createDivisions();
},
createDivisions: () => {
// 目盛り作成
ge("min-tmp").textContent = rgb.min;
ge("max-tmp").textContent = rgb.max;
var scaleDivisions = ge('scale-divisions');
const divisions = gc('division');
while (divisions.length > 0) {
scaleDivisions.removeChild(divisions[0]);
}
let div;
for (let temp = rgb.min + 5; temp < rgb.max; temp += 5) {
div = ce('div');
div.innerText = temp;
div.classList.add("division");
div.style.left = 640 * (temp - rgb.min) / (rgb.max - rgb.min) - 7 + 'px';
scaleDivisions.appendChild(div);
}
},
draw: function (tmps) {
// 描画
const data = this.imageData.data; // RGBA の順番のデータを含んだ 1次元配列。それぞれの値は 0 ~ 255 の範囲となります。
if (data.length / 4 != tmps.length) {
alert(なにかおかしいです);
return;
}
let tmp, color, j, mirror;
const maxValue = 255;
for (let i = 0; i < tmps.length; i++) {
if (true) {
j = 4 * i;
} else {
// 左右反転させる
mirror = (31 - i % 32) + parseInt(i / 32) * 32;
j = 4 * mirror;
}
tmp = tmps[i];
color = rgb.get(tmp);
data[j] = color.r;
data[j + 1] = color.g;
data[j + 2] = color.b;
data[j + 3] = maxValue;
}
this.context.putImageData(this.imageData, 0, 0);
}
};
const divisionBtns = gc('divisionBtn');
for (let i = 0; i < divisionBtns.length; i++) {
divisionBtns[i].addEventListener('click', function () {
// 目盛り変更
const d = 5; // 目盛りの間隔
switch (this.id) {
case 'min-down':
if (rgb.min >= 5) rgb.min -= d;
break;
case 'min-up':
if (rgb.min <= rgb.max - 2 * d) rgb.min += d;
break;
case 'max-down':
if (rgb.min <= rgb.max - 2 * d) rgb.max -= d;
break;
case 'max-up':
if (rgb.max <= 90) rgb.max += d;
break;
}
cv.createDivisions();
});
}
const predict = {
mode: false,
sw: function () {
const elem = ge("predict");
const result = ge("result");
if (this.mode) {
fetch("stop");
elem.classList.remove('predictiong');
elem.innerHTML = "判定する";
result.style.display = "none";
this.mode = false;
} else {
fetch("start");
elem.classList.add('predictiong');
elem.innerHTML = "判定中";
result.style.display = "block";
this.mode = true;
}
}
}
let onBodyLoad = function () {
cv.createScale();
skt.start();
cv.createCanvas();
ge("predict").onclick = predict.sw;
}
・app.css
body {
display: flex;
justify-content: center;
align-items: center;
background-color: black;
color: #ffffff;
font-size: 12px;
}
#container {
position: relative;
}
#canvas {
background: #666;
width: 640px;
height: 480px;
}
#scale, #scale-divisions {
width: 100%;
height: 24px;
}
#scale-divisions {
position: relative;
}
#min-tmp-division {
position: absolute;
left: -20px;
}
#max-tmp-division {
position: absolute;
right: -20px;
}
.division {
position: absolute;
}
#scale {
display: flex;
}
#scale span {
display: block;
width: 1%;
height: 23px;
}
#messages {
overflow-y: auto;
}
.divisionBtn {
cursor: pointer;
}
.divisionBtn:hover{
color: #99b2ce;
}
#result {
position: absolute;
right: 1px;
top: 1px;
height: 18px;
width: 50px;
background: #faf62d;
padding: 3px 10px;
color: #000000;
font-weight: bold;
display: none;
top: 30px;
}
#header {
text-align: right;
padding-bottom: 5px;
}
a#predict{
position: relative;
display: inline-block;
font-weight: bold;
padding: 0.25em 0.5em;
text-decoration: none;
color: rgb(242, 248, 250);
background: rgb(20, 85, 7);
transition: .4s;
width: 80px;
text-align: center;
}
a#predict.predictiong {
background: rgb(230, 33, 7);
}
a#predict:hover{
background: #5bf15b;
color: white;
}
@media only screen and (max-device-width: 480px) {
body {
font-size:24px;
}
#messages {
margin-top:20px;
}
}
上記プログラムを実行した結果、だいたい正確に推論するようでした。
ただし背景に熱を持つもの、モニタとか蛍光灯等があると、ノイズが入ってうまくいかないです。