NumPyでフィルタリングを高速に行う方法

例えば、OpenCVで2値化済みの画像から黒色(つまり値が0)のピクセルの数を数えることを考えます。

一番分かりやすいのは単純にfor文で回して数を数える方法です。bin_imgは2値化済みの画像でnumpy.ndarray型の2次元配列とします。

count = 0
for color in bin_img.flat:
    count += 1 if color == 0 else 0

また、せっかくのPythonなので、filterやreduceを用いて関数型言語っぽく書くと、よりスマートになります。

count = len(filter(lambda x: x == 0, bin_img.flat))
count = reduce(lambda m, x: m+1 if x == 0 else m, bin_img.flat, 0)

問題点

さて、ここまでの実装ですが、分かりやすくスマートである一方、イテレーションと演算がPythonインタプリタ上で行われるため、残念ながら速度が遅いのがネックです。

じゃあ、もっと良い方法があるのか、という話になるのですが、これまた残念ながらNumPyでは任意のソート*1やフィルタリングを行うためのメソッドは提供されていません。というのも、そもそもNumPyは各種演算をネイティブバイナリに任せることで高速化を図っているので、ユーザが比較関数などをPythonで記述しては意味が無いのですね。しかしながら、代わりに、NumPyは基本的な演算メソッドを大量に用意してくれています。

Universal functions (ufunc) — NumPy v1.7 Manual (DRAFT)

Boolean Arrayを利用したフィルタリング

ようやく本題となりますが、ここではNumPyのBoolean Arrayによるマスクを利用してフィルタリングを行うことで高速化してみたいと思います。コードは超シンプルです。

count = len(bin_img[bin_img==0])

一見、不思議な気持ちになる記述ですが、これで問題ありません。bin_img==0numpy.equalエイリアスのようですが、これにより、True/Falseのマップが作成されます。そのマップを利用してbin_imgをフィルタするという仕組みです。

ベンチマーク

せっかくなので、適当なベンチマークも取ってみました。

コードは以下の通りです。

import numpy as np

def test1(a):
    count = 0
    for x in a.flat:
        count += 1 if x == 0 else 0
    return count

def test2(a):
    return len(filter(lambda x: x == 0, a.flat))

def test3(a):
    return reduce(lambda m, x: m+1 if x == 0 else m, a.flat, 0)

def test4(a):
    return len(a[a==0])

def main():
    for i in range(200):
        size = i + 1
        a = np.identity(size)
        ans = size**2 - size
        assert test1(a) == ans
        assert test2(a) == ans
        assert test3(a) == ans
        assert test4(a) == ans

if __name__ == "__main__":
    main()

プロファイリングはcProfileで、

$ python -m cProfile -s cumulative filter-test.py > filter-test.log

結果の一部抜粋です。

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
...
      200    0.211    0.001    4.188    0.021 filter-test.py:10(test2)
      200    0.001    0.000    4.101    0.021 filter-test.py:13(test3)
...
      200    3.259    0.016    3.259    0.016 filter-test.py:4(test1)
...
      200    0.134    0.001    0.134    0.001 filter-test.py:16(test4)

「filter・reduce>for文>>NumPy」という結果が得られました。順当と言った所だとは思いますが、しかしNumPy早いですねー。

参考文献

*1:比較関数をユーザが書くという意味で「任意」の