みらいテックラボ

音声・画像認識や機械学習など, 週末プログラマである管理人が興味のある技術の紹介や実際にトライしてみた様子などメモしていく.

金魚って見分けられる? (4)

Code for YAMATOKORIMAYAでは, 「アーバンデータチャレンジ2018 」への取り組みの一つとして, 金魚AI(愛)育成プロジェクトなるものに取り組んでいる.

f:id:moonlight-aska:20180729174722j:plain:w400


関連記事:
金魚って見分けられる? (1)
金魚って見分けられる? (2)
金魚って見分けられる? (3)
・金魚って見分けられる? (4)


金魚AI(愛)育成プロジェクトでは金魚データの収集を行っているが, より多くの方にデータ収集に協力してもらおうと, Webを利用したデータ収集の仕組みづくりを検討している.
これに関連し, 今回はFlaskを使って金魚分類をWEB API化した話.


1. データ収集の仕組みづくり
Code for Naraの方が, Node-REDの勉強がてらデータ登録サイトを作成してくれた. (ありがたいことだ...)
画面はこんな感じ.
f:id:moonlight-aska:20181104194856p:plain:w500

「ファイルを選択」タッチすると, カメラ撮影か撮影済み画像かが選択できるようになっている.
f:id:moonlight-aska:20181104195057p:plain:w500

画像を選択し, 金魚系統図を見て種類を選択して登録するのだ.
種類を選んでもらうことで, 画像データのラベリングをやってもらおうというのだ.


2. 金魚分類のWEB API[1][2]
ただ, 金魚系統図があるとはいえ, 金魚の種類って30種以上あり, 種類を特定するのは結構難しい.
また, 金魚以外の画像がうっかり混じらないようにということもあり, 金魚分類を使って金魚画像か否かの判定や種類のラベル候補を出せないか, といったアイデアが出てきた.

これを実現するために, Flask + Kerasで金魚分類をWEB API化してみた.

[コード]

import numpy as np
import io
from PIL import Image

import tensorflow as tf
from keras import models
from keras.preprocessing.image import img_to_array
from flask import Flask, redirect, request, jsonify

CANDIDATE_MAX = 5
THRESHOLD = 0.5

names = [
    'azumanishiki', 'chakin', 'chinshurin', 'chobi', 'chotengan',
    'comet', 'demekin', 'edonishiki', 'hamanishiki', 'kyariko',
    'nankin', 'oranda', 'ranchu', 'ryukin', 'sakuranishiki', 
    'seibungyo', 'shubunkin', 'suihogan', 'tancho', 'tosakin',
    'wakin', 'zikin'
]

app = Flask(__name__)

def load_model():
    global model
    model = models.load_model('./model/MobileNet_size224.h5')
    model.summary()
    model._make_predict_function()
    print('Model loading completed.')

def prepare_image(image, target):
    if image.mode != 'RGB':
        image = image.convert('RGB')
    image = image.resize(target)
    image = img_to_array(image)
    image = (image - 127.5) / 127.5
    image = np.expand_dims(image, axis=0)
    return image

def set_result(res):
    data = []
    print(res)
    for i in range(CANDIDATE_MAX):
        no = np.argsort(res)[::-1][i]
        prob = np.sort(res)[::-1][i]
        if prob < THRESHOLD:
            break
        r = {'label': names[no], 'probability': str(prob)}
        print(r)
        data.append(r)
    return data
    
@app.route('/predict', methods=['POST'])
def predict():
    data = {'Success': False}
    if request.method == 'POST' and request.files.get('image'):
        # 画像データ読み込み
        image = request.files['image'].read()
        image = Image.open(io.BytesIO(image))
        # 前処理
        image = prepare_image(image,target=(224, 224))
        # 分類処理
        preds = model.predict(image)[0]
        data['predictions'] = set_result(preds)
        data['Success'] = True
        print(data)
        
    return jsonify(data)

if __name__ == '__main__':
    load_model()
    app.run(debug=False, port=5000)


3. 動作確認
以下の画像で動作確認.
f:id:moonlight-aska:20181104221559j:plain:w300

[サーバ側]

(tensorflow) aska@aska:~/work/GoldFish/GoldFishRoot$ python app.py
Using TensorFlow backend.
2018-11-04 22:02:47.619170: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-11-04 22:02:47.620473: I tensorflow/core/common_runtime/process_util.cc:69] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input (InputLayer)           (None, 224, 224, 3)       0
_________________________________________________________________
conv1_pad (ZeroPadding2D)    (None, 225, 225, 3)       0
_________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 32)      864
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 32)      128
_________________________________________________________________

(省略)
_________________________________________________________________
conv_pw_13 (Conv2D)          (None, 7, 7, 1024)        1048576
_________________________________________________________________
conv_pw_13_bn (BatchNormaliz (None, 7, 7, 1024)        4096
_________________________________________________________________
conv_pw_13_relu (ReLU)       (None, 7, 7, 1024)        0
_________________________________________________________________
global_average_pooling2d_1 ( (None, 1024)              0
_________________________________________________________________
sequential_1 (Sequential)    (None, 22)                22550
=================================================================
Total params: 3,251,414
Trainable params: 3,229,526
Non-trainable params: 21,888
_________________________________________________________________
Model loading completed.
 * Serving Flask app "app" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: off
 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
[9.7076141e-12 1.7206460e-11 7.2557433e-13 9.9999797e-01 6.7276448e-11
 1.1155885e-12 7.2529420e-11 3.7379000e-10 6.2612190e-12 9.5529673e-10
 4.1174125e-13 3.0363852e-12 2.8485176e-11 2.8623672e-15 2.2687393e-13
 5.8064413e-12 1.5076776e-14 6.2707137e-14 6.9399694e-12 2.0334671e-06
 3.3028352e-13 1.0385630e-11]
{'label': 'chobi', 'probability': '0.999998'}
{'Success': True, 'predictions': [{'label': 'chobi', 'probability': '0.999998'}]}[f:id:moonlight-aska:20181104221559j:plain]
127.0.0.1 - - [04/Nov/2018 22:13:31] "POST /predict HTTP/1.1" 200 -

[クライアント側]

aska@aska:~/work/GoldFish/data$ curl -X POST -F image=@GF08-00030.jpg 'http://localhost:5000/predict'
{"Success":true,"predictions":[{"label":"chobi","probability":"0.999998"}]}

金魚分類のWEB API化は一応動作してそうなので, データ収集ツールの種類選択の改良はCode for Naraの方におまかせすることに...
どんなデータ収集ツールになるか楽しみ.

----
参照URL:
[1] Building a simple Keras + deep learning REST API
[2] Kerasにおける"_make_predict_function()"の重要性






エキスパートPythonプログラミング 改訂2版 (アスキードワンゴ)

エキスパートPythonプログラミング 改訂2版 (アスキードワンゴ)