Ich möchte ein Streudiagramm erstellen, in dem jeder Punkt durch die räumliche Dichte benachbarter Punkte gefärbt wird.
Ich bin auf eine sehr ähnliche Frage gestoßen, die ein Beispiel dafür mit R zeigt:
R Streudiagramm: Die Symbolfarbe repräsentiert die Anzahl der überlappenden Punkte
Was ist der beste Weg, um mit matplotlib etwas Ähnliches in Python zu erreichen?
python
matplotlib
2964502
quelle
quelle
Antworten:
Zusätzlich zu
hist2d
oderhexbin
wie von @askewchan vorgeschlagen, können Sie dieselbe Methode verwenden, die die akzeptierte Antwort in der von Ihnen verknüpften Frage verwendet.Wenn Sie das tun möchten:
import numpy as np import matplotlib.pyplot as plt from scipy.stats import gaussian_kde # Generate fake data x = np.random.normal(size=1000) y = x * 3 + np.random.normal(size=1000) # Calculate the point density xy = np.vstack([x,y]) z = gaussian_kde(xy)(xy) fig, ax = plt.subplots() ax.scatter(x, y, c=z, s=100, edgecolor='') plt.show()
Wenn Sie möchten, dass die Punkte in der Reihenfolge ihrer Dichte dargestellt werden, sodass die dichtesten Punkte immer oben liegen (ähnlich wie im verknüpften Beispiel), sortieren Sie sie einfach nach den Z-Werten. Ich werde hier auch eine kleinere Markierungsgröße verwenden, da sie etwas besser aussieht:
import numpy as np import matplotlib.pyplot as plt from scipy.stats import gaussian_kde # Generate fake data x = np.random.normal(size=1000) y = x * 3 + np.random.normal(size=1000) # Calculate the point density xy = np.vstack([x,y]) z = gaussian_kde(xy)(xy) # Sort the points by density, so that the densest points are plotted last idx = z.argsort() x, y, z = x[idx], y[idx], z[idx] fig, ax = plt.subplots() ax.scatter(x, y, c=z, s=50, edgecolor='') plt.show()
quelle
plt.colorbar()
, oder wenn Sie es vorziehen, expliziter zu sein, tun Sie diescax = ax.scatter(...)
und dannfig.colorbar(cax)
. Beachten Sie, dass die Einheiten unterschiedlich sind. Diese Methode schätzt die Wahrscheinlichkeitsverteilungsfunktion für die Punkte, sodass die Werte zwischen 0 und 1 liegen (und normalerweise nicht sehr nahe an 1 kommen). Sie können wieder in etwas konvertieren, das näher an der Histogrammzahl liegt, dies erfordert jedoch ein wenig Arbeit (Sie müssen diegaussian_kde
aus den Daten geschätzten Parameter kennen ).Sie könnten ein Histogramm erstellen:
import numpy as np import matplotlib.pyplot as plt # fake data: a = np.random.normal(size=1000) b = a*3 + np.random.normal(size=1000) plt.hist2d(a, b, (50, 50), cmap=plt.cm.jet) plt.colorbar()
quelle
Wenn die Anzahl der Punkte die KDE-Berechnung zu langsam macht, kann die Farbe in np.histogram2d interpoliert werden. [Aktualisierung als Antwort auf Kommentare: Wenn Sie die Farbleiste anzeigen möchten, verwenden Sie plt.scatter () anstelle von ax.scatter () von plt.colorbar ()]:
import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from matplotlib.colors import Normalize from scipy.interpolate import interpn def density_scatter( x , y, ax = None, sort = True, bins = 20, **kwargs ) : """ Scatter plot colored by 2d histogram """ if ax is None : fig , ax = plt.subplots() data , x_e, y_e = np.histogram2d( x, y, bins = bins, density = True ) z = interpn( ( 0.5*(x_e[1:] + x_e[:-1]) , 0.5*(y_e[1:]+y_e[:-1]) ) , data , np.vstack([x,y]).T , method = "splinef2d", bounds_error = False) #To be sure to plot all data z[np.where(np.isnan(z))] = 0.0 # Sort the points by density, so that the densest points are plotted last if sort : idx = z.argsort() x, y, z = x[idx], y[idx], z[idx] ax.scatter( x, y, c=z, **kwargs ) norm = Normalize(vmin = np.min(z), vmax = np.max(z)) cbar = fig.colorbar(cm.ScalarMappable(norm = norm), ax=ax) cbar.ax.set_ylabel('Density') return ax if "__main__" == __name__ : x = np.random.normal(size=100000) y = x * 3 + np.random.normal(size=100000) density_scatter( x, y, bins = [30,30] )
quelle
Zeichnen von> 100.000 Datenpunkten?
Die akzeptierte Antwort mit gaussian_kde () nimmt viel Zeit in Anspruch . Auf meinem Computer dauerten 100.000 Zeilen ungefähr 11 Minuten . Hier werde ich zwei alternative Methoden hinzufügen ( mpl-Scatter-Density und Datashader ) und die gegebenen Antworten mit demselben Datensatz vergleichen.
Im Folgenden habe ich einen Testdatensatz von 100.000 Zeilen verwendet:
import matplotlib.pyplot as plt import numpy as np # Fake data for testing x = np.random.normal(size=100000) y = x * 3 + np.random.normal(size=100000)
Vergleich von Ausgabe- und Rechenzeit
Nachfolgend finden Sie einen Vergleich verschiedener Methoden.
1: mpl-scatter-density
Installation
Beispielcode
import mpl_scatter_density # adds projection='scatter_density' from matplotlib.colors import LinearSegmentedColormap # "Viridis-like" colormap with white background white_viridis = LinearSegmentedColormap.from_list('white_viridis', [ (0, '#ffffff'), (1e-20, '#440053'), (0.2, '#404388'), (0.4, '#2a788e'), (0.6, '#21a784'), (0.8, '#78d151'), (1, '#fde624'), ], N=256) def using_mpl_scatter_density(fig, x, y): ax = fig.add_subplot(1, 1, 1, projection='scatter_density') density = ax.scatter_density(x, y, cmap=white_viridis) fig.colorbar(density, label='Number of points per pixel') fig = plt.figure() using_mpl_scatter_density(fig, x, y) plt.show()
Das Zeichnen dauerte 0,05 Sekunden:
Und das Vergrößern sieht ganz gut aus:
2: datashader
pip install "git+https://github.com/nvictus/datashader.git@mpl"
Code (Quelle der Dsshow hier ):
from functools import partial import datashader as ds from datashader.mpl_ext import dsshow import pandas as pd dyn = partial(ds.tf.dynspread, max_px=40, threshold=0.5) def using_datashader(ax, x, y): df = pd.DataFrame(dict(x=x, y=y)) da1 = dsshow(df, ds.Point('x', 'y'), spread_fn=dyn, aspect='auto', ax=ax) plt.colorbar(da1) fig, ax = plt.subplots() using_datashader(ax, x, y) plt.show()
und das gezoomte Bild sieht gut aus!
3: scatter_with_gaussian_kde
def scatter_with_gaussian_kde(ax, x, y): # https://stackoverflow.com/a/20107592/3015186 # Answer by Joel Kington xy = np.vstack([x, y]) z = gaussian_kde(xy)(xy) ax.scatter(x, y, c=z, s=100, edgecolor='')
4: using_hist2d
import matplotlib.pyplot as plt def using_hist2d(ax, x, y, bins=(50, 50)): # https://stackoverflow.com/a/20105673/3015186 # Answer by askewchan ax.hist2d(x, y, bins, cmap=plt.cm.jet)
5: density_scatter
quelle