みらいテックラボ

音声・画像認識や機械学習など, 週末プログラマである管理人が興味のある技術の紹介や実際にトライしてみた様子などメモしていく.

UnityでBarracuda + YOLOv5を試してみる(2)

Unityで物体検出を試してみようと調査していたら, Barracudaを使うことでonnx形式のモデルを扱えることが分かった.
そこで, Barracuda + YOLOv5で物体検知を試すことにしたのだが, いくつか注意すべきポイントがあったので少しまとめておく.


関連記事:


[開発環境]

  • Unity Editor 2021.3.19f1
  • Barracuda 3.0.0
  • YOLOv5 v7.0
  • onnx 1.11.0


前回[1]は前処理まで記載したので, 今回は物体検出処理および全体処理について記す.


1. 物体検出処理
入力画像をモデルに入力することで, 物体検出自体は容易に行える.
しかし, モデルは同一の物体に対して, 同じクラスと判定された互いに重なり合う候補領域を複数そのまま返してくるので, NMS(Non-Maximum Suppression)により重複している領域を抑制する必要がある.

1.1 モデルの出力
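モデルの出力は, 前回記したように[1, 1, 85, 6300]である. 6300個の候補領域それぞれが85個の値を持ち, 先頭から順に, バウンディングボックスの中心座標と幅・高さ(cx, cy, w, h)の4値, 物体らしさの信頼度(objectness)の1値, 80クラス各々のクラス信頼度の80値が並んでいる.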

1.2 NMS[2]
YOLOv5のpython実装コードでも, 後処理としてNMSが実装されているので, そのコードを参考にしつつC#でNMSを実装してみた.

[コード]

using System;
using System.Collections.Generic;
using Unity.Barracuda;
using UnityEngine;

/// <summary>
/// Runs a YOLOv5 ONNX model through Barracuda and converts its raw output
/// into a list of detections (bounding box, label, score) filtered by NMS.
/// </summary>
public class MLEngine 
{
    private readonly int _inputWidth;
    private readonly int _inputHeight;
    private readonly IWorker _worker;
    //
    // Model : yolov5n.onnx
    // NodeArg(name='images', type='tensor(float)', shape=[1, 3, 320, 320])
    // NodeArg(name='output0', type='tensor(float)', shape=[1, 6300, 85])
    // ⇒ barracuda model
    //    inputs[0].shape : [1, 1, 1, 1, 1, 320, 320, 3]
    //    outputs.shape : [1, 1, 85, 6300]
    //
    // Each candidate idx is read as outputs[0, 0, k, idx] where:
    //   k = 0..3 : box center x, center y, width, height
    //              (presumably in model-input pixels — confirm against the caller's scaling)
    //   k = 4    : objectness confidence
    //   k = 5..  : per-class confidences (NUM_CLASS entries, same order as _labels)
    //
    private readonly string[] _labels = {
        "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", 
        "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
        "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
        "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", 
        "wine glass", "cup", "fork", "knife", "spoon",  "bowl", "banana", "apple", "sandwich", "orange", 
        "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", 
        "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", 
        "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"};
    private const int NUM_CLASS = 80;

    /// <summary>
    /// Loads the model and creates a compute worker for it.
    /// </summary>
    /// <param name="modelAsset">The imported ONNX model asset.</param>
    public MLEngine(NNModel modelAsset)
    {
        var runtimeModel = ModelLoader.Load(modelAsset);
        var inputShape = runtimeModel.inputs[0].shape;
        Debug.Log($"Input shape: {string.Join(",", inputShape)}");
        // 8-D Barracuda shape: index 5 is height, index 6 is width (see the comment above).
        _inputWidth = inputShape[6];
        _inputHeight = inputShape[5];
        _worker = WorkerFactory.CreateWorker(WorkerFactory.Type.ComputePrecompiled, runtimeModel);
    }

    /// <summary>
    /// Runs inference on <paramref name="input"/> and returns the detections that survive NMS.
    /// </summary>
    /// <param name="input">Preprocessed image tensor. It is not disposed here; the caller owns it.</param>
    /// <param name="confThres">Minimum objectness, and minimum objectness * class confidence, for a candidate.</param>
    /// <param name="iouThres">IoU at or above which a lower-scoring box of the same label is suppressed.</param>
    /// <returns>The kept detections, highest score first.</returns>
    public IList<Candidate> Execute(Tensor input, float confThres, float iouThres)
    {
        _worker.Execute(input);
        // PeekOutput hands back the worker's own output tensor, so it is not disposed here.
        var outputs = _worker.PeekOutput();  
        Debug.Log($"Output shape: {string.Join(",",  outputs.shape)}");
        var cands = ParseOutputs(outputs, confThres);
        var preds = NonMaxSuppression(cands, confThres, iouThres);
        return preds;
    }

    /// <summary>Releases the Barracuda worker. Call once when the engine is no longer used.</summary>
    public void Destory()
    {
        _worker.Dispose();
    }

    /// <summary>
    /// Decodes every candidate in <paramref name="outputs"/> whose objectness and
    /// combined score reach <paramref name="threshold"/>.
    /// </summary>
    private IList<Candidate> ParseOutputs(Tensor outputs, float threshold)
    {
        var cands = new List<Candidate>();
        // Candidates lie along the channel axis. Reading the count from the tensor
        // (instead of hard-coding 6300) also supports models exported at other
        // input sizes, e.g. 640x640 -> 25200 candidates.
        int numCandidates = outputs.channels;
        for (int i = 0; i < numCandidates; i++)
        {
            float boxConf = GetConfidence(outputs, i);
            if (boxConf < threshold) 
            {
                continue;
            }
            (int classId, float classConf) = GetBestClass(outputs, i);
            if (classId < 0)
            {
                // Every class score was NaN; there is no usable label.
                continue;
            }
            float score = boxConf * classConf;
            if (score < threshold)
            {
                continue;
            }
            // Decode the box only for candidates that are actually kept.
            cands.Add(new Candidate{
                bbox = GetBoundingBox(outputs, i),
                label = _labels[classId],
                score = score
            });
        }
        return cands;
    }

    /// <summary>Returns the objectness confidence of candidate <paramref name="idx"/>.</summary>
    private float GetConfidence(Tensor outputs, int idx)
    {
        return outputs[0, 0, 4, idx];
    }

    /// <summary>Converts the (cx, cy, w, h) of candidate <paramref name="idx"/> into corner form.</summary>
    private BoundingBox GetBoundingBox(Tensor outputs, int idx)
    {
        float cx = outputs[0, 0, 0, idx];
        float cy = outputs[0, 0, 1, idx];
        float halfW = outputs[0, 0, 2, idx] / 2;
        float halfH = outputs[0, 0, 3, idx] / 2;
        return new BoundingBox
        {
            x1 = cx - halfW,
            y1 = cy - halfH,
            x2 = cx + halfW,
            y2 = cy + halfH
        };
    }

    /// <summary>
    /// Finds the class with the highest confidence for candidate <paramref name="idx"/>.
    /// </summary>
    /// <returns>(class index, its confidence); class index is -1 if no score compared greater than -1 (e.g. all NaN).</returns>
    private ValueTuple<int, float> GetBestClass(Tensor outputs, int idx)
    {
        int classId = -1;
        float maxConf = -1.0f;
        for (int i = 0; i < NUM_CLASS; i++) 
        {
            float conf = outputs[0, 0, 5 + i, idx];
            if (conf > maxConf)
            {
                maxConf = conf;
                classId = i;
            }
        }
        return (classId, maxConf);
    }

    /// <summary>
    /// Greedy per-label non-maximum suppression. It repeatedly keeps the highest-scoring
    /// remaining candidate and drops every remaining candidate with the same label
    /// whose IoU with it is at least <paramref name="iouThres"/>.
    /// </summary>
    /// <param name="cands">Candidates to filter. This list is not modified.</param>
    /// <param name="confThres">Unused; candidates are already score-filtered in ParseOutputs.
    /// Kept for signature compatibility.</param>
    /// <param name="iouThres">Suppression IoU threshold (inclusive).</param>
    /// <returns>Kept candidates in descending score order.</returns>
    private IList<Candidate> NonMaxSuppression(IList<Candidate> cands, float confThres, float iouThres)
    {
        // Visit candidates once, from highest to lowest score. Ties keep their
        // original order (index tie-break), which matches picking the first maximum.
        var order = new List<int>(cands.Count);
        for (int i = 0; i < cands.Count; i++)
        {
            order.Add(i);
        }
        order.Sort((a, b) =>
        {
            int byScore = cands[b].score.CompareTo(cands[a].score);
            return byScore != 0 ? byScore : a.CompareTo(b);
        });

        var suppressed = new bool[cands.Count];
        IList<Candidate> newCands = new List<Candidate>();
        for (int oi = 0; oi < order.Count; oi++)
        {
            int best = order[oi];
            if (suppressed[best])
            {
                continue;
            }
            Candidate cand = cands[best];
            newCands.Add(cand);

            // Suppress lower-ranked candidates of the same label that overlap too much.
            for (int oj = oi + 1; oj < order.Count; oj++)
            {
                int other = order[oj];
                if (suppressed[other] || cand.label != cands[other].label)
                {
                    continue;
                }
                if (Iou(cand.bbox, cands[other].bbox) >= iouThres)
                {
                    suppressed[other] = true;
                }
            }
        }
        return newCands;
    }

    /// <summary>
    /// Intersection-over-union of two corner-form boxes. Returns 0 when they do not overlap.
    /// </summary>
    private float Iou(BoundingBox boxA, BoundingBox boxB)
    {
        // Fast path for identical boxes (per BoundingBox's == semantics).
        if (boxA == boxB)
        {
            return 1.0f;
        }
        float interW = Mathf.Min(boxA.x2, boxB.x2) - Mathf.Max(boxA.x1, boxB.x1);
        float interH = Mathf.Min(boxA.y2, boxB.y2) - Mathf.Max(boxA.y1, boxB.y1);
        if (interW <= 0.0f || interH <= 0.0f)
        {
            return 0.0f;
        }
        float intersection = interW * interH;
        float union = (boxA.x2 - boxA.x1) * (boxA.y2 - boxA.y1) + (boxB.x2 - boxB.x1) * (boxB.y2 - boxB.y1) - intersection;
        // Guard against malformed (inverted) boxes producing a non-positive union.
        return union > 0.0f ? intersection / union : 0.0f;
    }
}


2. 全体制御
WebCAMからの画像を取得し, 物体検知して結果を表示する部分のコードを記載しておく.

[コード]

using System.Collections;
using System.Collections.Generic;
using Unity.Barracuda;
using UnityEngine;
using UnityEngine.UI;

/// <summary>
/// Streams the first available web camera into a RawImage, runs YOLOv5
/// detection on each new camera frame and draws the detections with IMGUI.
/// </summary>
public class CameraManager : MonoBehaviour
{
    [SerializeField] RawImage rawImage;
    [SerializeField] float confThreshold = 0.50f;
    [SerializeField] float iouThreshold = 0.45f;
    private WebCamTexture _webCamTexture;

    // Parameters
    private const int CAM_WIDTH = 1280; // 1920;
    private const int CAM_HEIGHT = 720; // 1080;
    private const int CAM_FPS = 30;
    // Side length of the square image passed to the model.
    private const int IMG_SIZE = 320;
    // Factors mapping detection coordinates (IMG_SIZE space) to GUI coordinates.
    private float _scaleX;
    private float _scaleY;
    // Camera size the current scale factors were computed from (0 = not computed yet).
    private int _scaledCamWidth;
    private int _scaledCamHeight;

    // YOLOv5
    [SerializeField] NNModel modelAsset;
    private MLEngine _engine;
    private IList<Candidate> _preds;
    // Drawing
    private Texture2D _lineTexture;
    private GUIStyle _guiStyle;
#if UNITY_ANDROID 
    // Android Tab 2560x1530
    private float _dispScale = 1.333333f;
    private int _dispOffsetX = 0;
    private int _dispOffsetY = 48;
#else
    private float _dispScale = 1.0f;
    private int _dispOffsetX = 0;
    private int _dispOffsetY = 0;
#endif
    private bool _isWorking = false;

    // Start is called before the first frame update
    void Start()
    {
        WebCamDevice[] devices = WebCamTexture.devices;
        if (devices.Length == 0)
        {
            // Nothing to capture from: disable this component so Update/OnGUI stop running.
            Debug.LogError("No camera device found.");
            enabled = false;
            return;
        }
        _webCamTexture = new WebCamTexture(devices[0].name, CAM_WIDTH, CAM_HEIGHT, CAM_FPS);
        rawImage.texture = _webCamTexture;
        _webCamTexture.Play();

        Debug.Log($"Screen Width : {Screen.width}, height : {Screen.height}");
        // Right after Play() the camera may still report a placeholder size (commonly
        // 16x16) until the first frame arrives, so Update refreshes the scale later too.
        UpdateScaleIfNeeded();
        Debug.Log($"Offset x : {_dispOffsetX}, y : {_dispOffsetY}");

        // Object Detection (YOLOv5)
        _engine = new MLEngine(modelAsset);

        // Line texture is white so DrawRectangle can tint it with GUI.color.
        _lineTexture = new Texture2D(1, 1);
        _lineTexture.SetPixel(0, 0, Color.white);
        _lineTexture.Apply();
        // GUI Style
        _guiStyle = new GUIStyle();
        _guiStyle.fontSize = 24;
        _guiStyle.normal.textColor = Color.red;
    }

    // Update is called once per frame
    void Update()
    {
        // Run inference only when the camera delivered a new frame.
        if (_isWorking || _webCamTexture == null || !_webCamTexture.didUpdateThisFrame) return;
        UpdateScaleIfNeeded();

        _isWorking = true;
        Tensor input = null;
        try
        {
            input = PreProcessor.PreProcImage(_webCamTexture, IMG_SIZE);
            _preds = _engine.Execute(input, confThreshold, iouThreshold);
        }
        finally
        {
            // Always release the input tensor and the guard, even if inference threw.
            // Otherwise a single failure would stop detection for good.
            input?.Dispose();
            _isWorking = false;
        }
    }

    // Releases the camera, the inference worker and the drawing texture.
    void OnDestroy()
    {
        if (_webCamTexture != null)
        {
            _webCamTexture.Stop();
        }
        _engine?.Destory();
        _engine = null;
        if (_lineTexture != null)
        {
            Destroy(_lineTexture);
        }
    }

    // Recomputes the model-to-GUI scale whenever the camera's reported size changes.
    private void UpdateScaleIfNeeded()
    {
        int camWidth = _webCamTexture.width;
        int camHeight = _webCamTexture.height;
        if (camWidth == _scaledCamWidth && camHeight == _scaledCamHeight) return;

        _scaledCamWidth = camWidth;
        _scaledCamHeight = camHeight;
        _scaleX = camWidth * _dispScale / IMG_SIZE;
        _scaleY = camHeight * _dispScale / IMG_SIZE;
        Debug.Log($"Camera Width : {camWidth}, height : {camHeight}");
        Debug.Log($"Disp Scale : {_dispScale}, xscale : {_scaleX}, yscale : {_scaleY}");
    }

    private void OnGUI() 
    {
        if (_preds == null || _preds.Count == 0) return;
        // Textures and labels are only rendered during Repaint; skip the other IMGUI events.
        if (Event.current.type != EventType.Repaint) return;

        for (int i = 0; i < _preds.Count; i++) 
        {
            DrawBoundingBox(_preds[i]);
        }
    }
    
    // Draws one detection: its frame plus a "label: score" caption above it.
    void DrawBoundingBox(Candidate pred)
    {
        int x = (int)(pred.bbox.x1 * _scaleX) + _dispOffsetX;
        int y = (int)(pred.bbox.y1 * _scaleY) + _dispOffsetY;
        int width = (int)((pred.bbox.x2 - pred.bbox.x1) * _scaleX);
        int height = (int)((pred.bbox.y2 - pred.bbox.y1) * _scaleY);
        Debug.Log($"{pred.label}, Box[{pred.bbox.x1}, {pred.bbox.y1}, {pred.bbox.x2}, {pred.bbox.y2}] -> Box[{x}, {y}, {width}, {height}]");
        DrawRectangle(new Rect(x, y, width, height), 3, Color.red);
        DrawLabel(new Rect(x + 10, y - 30, 200, 30), $"{pred.label}: {pred.score:F3}");
    }

    // Draws the outline of `area` with lines `frameWidth` pixels thick in `color`.
    void DrawRectangle(Rect area, int frameWidth, Color color)
    {
        Color previous = GUI.color;
        GUI.color = color;
        // top, bottom, left, right
        GUI.DrawTexture(new Rect(area.x, area.y, area.width, frameWidth), _lineTexture);
        GUI.DrawTexture(new Rect(area.x, area.yMax - frameWidth, area.width, frameWidth), _lineTexture);
        GUI.DrawTexture(new Rect(area.x, area.y, frameWidth, area.height), _lineTexture);
        GUI.DrawTexture(new Rect(area.xMax - frameWidth, area.y, frameWidth, area.height), _lineTexture);
        GUI.color = previous;
    }

    // Draws `text` at `pos` using the shared red label style.
    void DrawLabel(Rect pos, string text)
    {
        GUI.Label(pos, text, _guiStyle);
    }
}

これで一応, Unity上で物体検出モデルYOLOv5を使用できるようになった.
また, Android端末で動作検証した結果, 内蔵カメラの画像で物体検出できることも確認した.
ただし, Android端末では処理が重く, 5-10fps程度であった.

----
参照URL:
[1] UnityでBarracuda + YOLOv5を試してみる(1) - みらいテックラボ
[2] ultralytics/yolov5: YOLOv5 in PyTorch > ONNX > CoreML > TFLite