Schnelle Überprüfung auf NaN in NumPy

120

Ich suche nach dem schnellsten Weg, um das Auftreten von NaN ( np.nan) in einem NumPy-Array zu überprüfen X. np.isnan(X)kommt nicht in Frage, da es ein boolesches Array von Formen bildet X.shape, das möglicherweise gigantisch ist.

Ich habe es versucht np.nan in X, aber das scheint nicht zu funktionieren, weil np.nan != np.nan. Gibt es eine schnelle und speichereffiziente Möglichkeit, dies überhaupt zu tun?

(Für diejenigen, die fragen würden, "wie gigantisch": Ich kann es nicht sagen. Dies ist eine Eingabevalidierung für den Bibliothekscode.)

Fred Foo
quelle
Funktioniert die Überprüfung der Benutzereingaben in diesem Szenario nicht? Wie bei der Überprüfung auf NaN vor dem Einfügen
Woot4Moo
@ Woot4Moo: Nein, die Bibliothek verwendet NumPy-Arrays oder scipy.sparse-Matrizen als Eingabe.
Fred Foo
2
Wenn Sie dies viel tun, habe ich gute Dinge über Bottleneck ( pypi.python.org/pypi/Bottleneck ) gehört
matt

Antworten:

160

Rays Lösung ist gut. Allerdings auf meiner Maschine handelt es sich um 2,5 - fach schneller verwenden numpy.sumanstelle von numpy.min:

In [13]: %timeit np.isnan(np.min(x))
1000 loops, best of 3: 244 us per loop

In [14]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 97.3 us per loop

Im Gegensatz dazu minist sumkeine Verzweigung erforderlich, was auf moderner Hardware eher teuer ist. Dies ist wahrscheinlich der Grund, warum sumes schneller geht.

Bearbeiten Der obige Test wurde mit einem einzigen NaN direkt in der Mitte des Feldes durchgeführt.

Es ist interessant festzustellen, dass dies minin Gegenwart von NaN langsamer ist als in Abwesenheit. Es scheint auch langsamer zu werden, wenn sich NaNs dem Anfang des Arrays nähern. Andererseits sumscheint der Durchsatz konstant zu sein, unabhängig davon, ob es NaNs gibt und wo sie sich befinden:

In [40]: x = np.random.rand(100000)

In [41]: %timeit np.isnan(np.min(x))
10000 loops, best of 3: 153 us per loop

In [42]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 95.9 us per loop

In [43]: x[50000] = np.nan

In [44]: %timeit np.isnan(np.min(x))
1000 loops, best of 3: 239 us per loop

In [45]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 95.8 us per loop

In [46]: x[0] = np.nan

In [47]: %timeit np.isnan(np.min(x))
1000 loops, best of 3: 326 us per loop

In [48]: %timeit np.isnan(np.sum(x))
10000 loops, best of 3: 95.9 us per loop
NPE
quelle
1
np.minist schneller, wenn das Array keine NaNs enthält, was meine erwartete Eingabe ist. Aber ich habe beschlossen, dieses trotzdem zu akzeptieren, weil es fängt infund neginfauch.
Fred Foo
2
Dies fängt nur ab infoder -infwenn die Eingabe beide enthält, und es gibt Probleme, wenn die Eingabe große, aber endliche Werte enthält, die beim Addieren überlaufen.
user2357112 unterstützt Monica
4
min und max müssen nicht für Gleitkommadaten auf sse-fähigen x86-Chips verzweigen. Ab Numpy sind 1,8 Minuten nicht langsamer als die Summe, bei meinem AMD-Phänomen sind es sogar 20% schneller.
Jtaylor
1
Auf meinem Intel Core i5 ist Numpy 1.9.2 unter OSX np.sumimmer noch etwa 30% schneller als np.min.
Matthew Brett
np.isnan(x).any(0)ist etwas schneller als np.sumund np.minauf meinem Computer, obwohl es möglicherweise zu unerwünschtem Caching kommt.
jsignell
28

Ich denke np.isnan(np.min(X))sollte tun was du willst.

Strahl
quelle
Hmmm ... das ist immer O (n), wenn es O (1) sein könnte (für einige Arrays).
user48956
17

Auch wenn es eine akzeptierte Antwort gibt, möchte ich Folgendes demonstrieren (mit Python 2.7.2 und Numpy 1.6.0 unter Vista):

In []: x= rand(1e5)
In []: %timeit isnan(x.min())
10000 loops, best of 3: 200 us per loop
In []: %timeit isnan(x.sum())
10000 loops, best of 3: 169 us per loop
In []: %timeit isnan(dot(x, x))
10000 loops, best of 3: 134 us per loop

In []: x[5e4]= NaN
In []: %timeit isnan(x.min())
100 loops, best of 3: 4.47 ms per loop
In []: %timeit isnan(x.sum())
100 loops, best of 3: 6.44 ms per loop
In []: %timeit isnan(dot(x, x))
10000 loops, best of 3: 138 us per loop

Daher kann der wirklich effiziente Weg stark vom Betriebssystem abhängen. Auf jeden dot(.)Fall scheint die Basis die stabilste zu sein.

Essen
quelle
1
Ich vermute, es hängt nicht so sehr vom Betriebssystem ab, sondern von der zugrunde liegenden BLAS-Implementierung und dem C-Compiler. Vielen Dank, aber ein Punktprodukt läuft nur ein bisschen häufiger über, wenn es xgroße Werte enthält, und ich möchte auch nach Informationen suchen.
Fred Foo
1
Nun, Sie können das Punktprodukt immer mit denen machen und verwenden isfinite(.). Ich wollte nur auf die enorme Leistungslücke hinweisen. Danke
essen
Das gleiche gilt für meine Maschine.
Kawing-Chiu
1
Clever, nein? Wie Fred Foo vorschlägt, sind Effizienzgewinne des auf Dot-Produkten basierenden Ansatzes mit ziemlicher Sicherheit auf eine lokale NumPy-Installation zurückzuführen, die mit einer optimierten BLAS-Implementierung wie ATLAS, MKL oder OpenBLAS verknüpft ist. Dies ist beispielsweise bei Anaconda der Fall. Angesichts dessen wird dieses Punktprodukt über alle verfügbaren Kerne parallelisiert . Das Gleiche gilt nicht für die min- oder sum-basierten Ansätze, die auf einen einzelnen Kern beschränkt sind. Ergo, diese Leistungslücke.
Cecil Curry
16

Hier gibt es zwei allgemeine Ansätze:

  • Überprüfen Sie jedes Array-Element auf nanund nehmen Sie any.
  • Wenden Sie eine kumulative Operation an, bei der nans (like sum) erhalten bleibt, und überprüfen Sie das Ergebnis.

Während der erste Ansatz sicherlich der sauberste ist, kann die starke Optimierung einiger der kumulativen Operationen (insbesondere derjenigen, die wie in BLAS ausgeführt werden dot) diese recht schnell machen. Beachten Sie, dass dotwie bei einigen anderen BLAS-Operationen unter bestimmten Bedingungen Multithreading ausgeführt wird. Dies erklärt den Geschwindigkeitsunterschied zwischen verschiedenen Maschinen.

Geben Sie hier die Bildbeschreibung ein

import numpy
import perfplot


def min(a):
    return numpy.isnan(numpy.min(a))


def sum(a):
    return numpy.isnan(numpy.sum(a))


def dot(a):
    return numpy.isnan(numpy.dot(a, a))


def any(a):
    return numpy.any(numpy.isnan(a))


def einsum(a):
    return numpy.isnan(numpy.einsum("i->", a))


perfplot.show(
    setup=lambda n: numpy.random.rand(n),
    kernels=[min, sum, dot, any, einsum],
    n_range=[2 ** k for k in range(20)],
    logx=True,
    logy=True,
    xlabel="len(a)",
)
Nico Schlömer
quelle
4
  1. benutze .any ()

    if numpy.isnan(myarray).any()

  2. numpy.isfinite vielleicht besser als isnan für die Überprüfung

    if not np.isfinite(prop).all()

woso
quelle
3

Wenn Sie sich wohl fühlen Es ermöglicht die Erzeugung eines schnellen Kurzschlusses (stoppt, sobald ein NaN gefunden wird):

import numba as nb
import math

@nb.njit
def anynan(array):
    array = array.ravel()
    for i in range(array.size):
        if math.isnan(array[i]):
            return True
    return False

Wenn dies nicht NaNder Fall ist, ist die Funktion möglicherweise langsamer als np.min. Ich denke, das liegt daran, dass np.minMultiprocessing für große Arrays verwendet wird:

import numpy as np
array = np.random.random(2000000)

%timeit anynan(array)          # 100 loops, best of 3: 2.21 ms per loop
%timeit np.isnan(array.sum())  # 100 loops, best of 3: 4.45 ms per loop
%timeit np.isnan(array.min())  # 1000 loops, best of 3: 1.64 ms per loop

Wenn sich jedoch ein NaN im Array befindet, insbesondere wenn seine Position bei niedrigen Indizes liegt, ist es viel schneller:

array = np.random.random(2000000)
array[100] = np.nan

%timeit anynan(array)          # 1000000 loops, best of 3: 1.93 µs per loop
%timeit np.isnan(array.sum())  # 100 loops, best of 3: 4.57 ms per loop
%timeit np.isnan(array.min())  # 1000 loops, best of 3: 1.65 ms per loop

Ähnliche Ergebnisse können mit Cython oder einer C-Erweiterung erzielt werden. Diese sind etwas komplizierter (oder leicht verfügbar als bottleneck.anynan), tun aber letztendlich das Gleiche wie meine anynanFunktion.

MSeifert
quelle
1

Damit verbunden ist die Frage, wie das erste Auftreten von NaN gefunden werden kann. Dies ist der schnellste Weg, um mit dem umzugehen, von dem ich weiß:

index = next((i for (i,n) in enumerate(iterable) if n!=n), None)
vitiral
quelle