Code for YAMATOKORIMAYAでは, 「アーバンデータチャレンジ2018 」への取り組みの一つとして, 金魚AI(愛)育成プロジェクトなるものに取り組んでいる.
関連記事:
・金魚って見分けられる? (1)
・金魚って見分けられる? (2)
・金魚って見分けられる? (3)
・金魚って見分けられる? (4)
金魚AI(愛)育成プロジェクトでは金魚データの収集を行っているが, より多くの方にデータ収集に協力してもらおうと, Webを利用したデータ収集の仕組みづくりを検討している.
これに関連し, 今回はFlaskを使って金魚分類をWEB API化した話.
1. データ収集の仕組みづくり
Code for Naraの方が, Node-REDの勉強がてらデータ登録サイトを作成してくれた. (ありがたいことだ...)
画面はこんな感じ.
「ファイルを選択」タッチすると, カメラ撮影か撮影済み画像かが選択できるようになっている.
画像を選択し, 金魚系統図を見て種類を選択して登録するのだ.
種類を選んでもらうことで, 画像データのラベリングをやってもらおうというのだ.
2. 金魚分類のWEB API化[1][2]
ただ, 金魚系統図があるとはいえ, 金魚の種類って30種以上あり, 種類を特定するのは結構難しい.
また, 金魚以外の画像がうっかり混じらないようにということもあり, 金魚分類を使って金魚画像か否かの判定や種類のラベル候補を出せないか, といったアイデアが出てきた.
これを実現するために, Flask + Kerasで金魚分類をWEB API化してみた.
[コード]
import numpy as np import io from PIL import Image import tensorflow as tf from keras import models from keras.preprocessing.image import img_to_array from flask import Flask, redirect, request, jsonify CANDIDATE_MAX = 5 THRESHOLD = 0.5 names = [ 'azumanishiki', 'chakin', 'chinshurin', 'chobi', 'chotengan', 'comet', 'demekin', 'edonishiki', 'hamanishiki', 'kyariko', 'nankin', 'oranda', 'ranchu', 'ryukin', 'sakuranishiki', 'seibungyo', 'shubunkin', 'suihogan', 'tancho', 'tosakin', 'wakin', 'zikin' ] app = Flask(__name__) def load_model(): global model model = models.load_model('./model/MobileNet_size224.h5') model.summary() model._make_predict_function() print('Model loading completed.') def prepare_image(image, target): if image.mode != 'RGB': image = image.convert('RGB') image = image.resize(target) image = img_to_array(image) image = (image - 127.5) / 127.5 image = np.expand_dims(image, axis=0) return image def set_result(res): data = [] print(res) for i in range(CANDIDATE_MAX): no = np.argsort(res)[::-1][i] prob = np.sort(res)[::-1][i] if prob < THRESHOLD: break r = {'label': names[no], 'probability': str(prob)} print(r) data.append(r) return data @app.route('/predict', methods=['POST']) def predict(): data = {'Success': False} if request.method == 'POST' and request.files.get('image'): # 画像データ読み込み image = request.files['image'].read() image = Image.open(io.BytesIO(image)) # 前処理 image = prepare_image(image,target=(224, 224)) # 分類処理 preds = model.predict(image)[0] data['predictions'] = set_result(preds) data['Success'] = True print(data) return jsonify(data) if __name__ == '__main__': load_model() app.run(debug=False, port=5000)
3. 動作確認
以下の画像で動作確認.
[サーバ側]
(tensorflow) aska@aska:~/work/GoldFish/GoldFishRoot$ python app.py Using TensorFlow backend. 2018-11-04 22:02:47.619170: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA 2018-11-04 22:02:47.620473: I tensorflow/core/common_runtime/process_util.cc:69] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance. _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input (InputLayer) (None, 224, 224, 3) 0 _________________________________________________________________ conv1_pad (ZeroPadding2D) (None, 225, 225, 3) 0 _________________________________________________________________ conv1 (Conv2D) (None, 112, 112, 32) 864 _________________________________________________________________ conv1_bn (BatchNormalization (None, 112, 112, 32) 128 _________________________________________________________________ (省略) _________________________________________________________________ conv_pw_13 (Conv2D) (None, 7, 7, 1024) 1048576 _________________________________________________________________ conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024) 4096 _________________________________________________________________ conv_pw_13_relu (ReLU) (None, 7, 7, 1024) 0 _________________________________________________________________ global_average_pooling2d_1 ( (None, 1024) 0 _________________________________________________________________ sequential_1 (Sequential) (None, 22) 22550 ================================================================= Total params: 3,251,414 Trainable params: 3,229,526 Non-trainable params: 21,888 _________________________________________________________________ Model loading completed. * Serving Flask app "app" (lazy loading) * Environment: production WARNING: Do not use the development server in a production environment. Use a production WSGI server instead. * Debug mode: off * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit) [9.7076141e-12 1.7206460e-11 7.2557433e-13 9.9999797e-01 6.7276448e-11 1.1155885e-12 7.2529420e-11 3.7379000e-10 6.2612190e-12 9.5529673e-10 4.1174125e-13 3.0363852e-12 2.8485176e-11 2.8623672e-15 2.2687393e-13 5.8064413e-12 1.5076776e-14 6.2707137e-14 6.9399694e-12 2.0334671e-06 3.3028352e-13 1.0385630e-11] {'label': 'chobi', 'probability': '0.999998'} {'Success': True, 'predictions': [{'label': 'chobi', 'probability': '0.999998'}]}[f:id:moonlight-aska:20181104221559j:plain] 127.0.0.1 - - [04/Nov/2018 22:13:31] "POST /predict HTTP/1.1" 200 -
[クライアント側]
aska@aska:~/work/GoldFish/data$ curl -X POST -F image=@GF08-00030.jpg 'http://localhost:5000/predict' {"Success":true,"predictions":[{"label":"chobi","probability":"0.999998"}]}
金魚分類のWEB API化は一応動作してそうなので, データ収集ツールの種類選択の改良はCode for Naraの方におまかせすることに...
どんなデータ収集ツールになるか楽しみ.
----
参照URL:
[1] Building a simple Keras + deep learning REST API
[2] Kerasにおける"_make_predict_function()"の重要性
PythonでWebサービスを作る - Python3 + Flaskで作るWebアプリケーション開発入門 - その1
|
ゼロからFlaskがよくわかる本: Pythonで作るWebアプリケーション開発入門
|
エキスパートPythonプログラミング 改訂2版 (アスキードワンゴ)
|
PythonでWebサービスを作る - Python3 + Flaskで作るWebアプリケーション開発入門 - その2
|