Jak używać funkcji NumPy argmax() w Pythonie?

W tym artykule poznasz praktyczne zastosowanie funkcji `argmax()` z biblioteki NumPy, która pozwala na odnalezienie indeksu elementu o największej wartości w tablicach.

NumPy to niezwykle przydatna biblioteka Pythona, szczególnie w dziedzinie obliczeń naukowych. Udostępnia ona struktury danych w postaci tablic N-wymiarowych, oferujące znacznie większą efektywność niż standardowe listy Pythona. Często podczas pracy z tablicami NumPy pojawia się potrzeba wyznaczenia największej wartości. Niekiedy jednak, bardziej interesujące jest znalezienie indeksu, pod którym ta maksymalna wartość się znajduje.

Funkcja `argmax()` umożliwia szybkie i proste określenie indeksu maksymalnego elementu zarówno w tablicach jednowymiarowych, jak i wielowymiarowych. Zobaczmy, jak to działa w praktyce.

Jak zlokalizować indeks największej wartości w tablicy NumPy?

Aby w pełni skorzystać z tego poradnika, powinieneś mieć zainstalowane środowisko Python oraz bibliotekę NumPy. Kod można uruchamiać w interaktywnej konsoli Pythona (REPL) lub w notatniku Jupyter.

Na początku zaimportujmy bibliotekę NumPy, używając standardowego aliasu `np`.

import numpy as np

Do wyznaczenia maksymalnej wartości w tablicy (opcjonalnie wzdłuż konkretnej osi) możesz wykorzystać funkcję `np.max()`.

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Wyjście
10

W powyższym przykładzie, `np.max(array_1)` zwraca wartość 10, co jest prawidłowym wynikiem.

Załóżmy, że teraz chcesz poznać indeks, pod którym w tablicy występuje ta maksymalna wartość. Możesz do tego podejść dwuetapowo:

  • Wyszukaj maksymalny element.
  • Znajdź indeks tego elementu.

W tablicy `array_1`, wartość maksymalna (10) znajduje się na indeksie 4 (pamiętając o indeksowaniu od zera). Pierwszy element ma indeks 0, drugi 1, i tak dalej.

Do odnalezienia indeksu, na którym znajduje się wartość maksymalna, można wykorzystać funkcję `np.where()`. `np.where(warunek)` zwraca tablicę indeksów, dla których warunek jest prawdziwy.

Należy odwołać się do tablicy i pobrać element z pierwszego indeksu. Aby dowiedzieć się, gdzie leży maksymalna wartość, ustawiamy warunek `array_1==10`, pamiętając, że 10 jest maksymalną wartością w `array_1`.

print(int(np.where(array_1==10)[0]))

# Wyjście
4

Chociaż wykorzystaliśmy `np.where()` tylko z warunkiem, nie jest to zalecane podejście do korzystania z tej funkcji.

📑 **Warto wiedzieć**: Funkcja NumPy `where()`:
`np.where(warunek, x, y)` zwraca:

  • elementy z `x`, jeśli warunek jest spełniony, oraz
  • elementy z `y`, jeśli warunek jest fałszywy.

Zatem, łącząc funkcje `np.max()` i `np.where()` jesteśmy w stanie odnaleźć element o największej wartości, a następnie indeks, pod którym się on znajduje.

Jednak zamiast takiego dwuetapowego postępowania, możemy użyć funkcji `argmax()` z NumPy, która bezpośrednio zwraca indeks maksymalnej wartości w tablicy.

Składnia funkcji NumPy `argmax()`

Ogólna struktura składni funkcji `argmax()` w NumPy wygląda następująco:

np.argmax(tablica, oś, out)
# NumPy zaimportowaliśmy jako np

W powyższej składni:

  • `tablica` to dowolna prawidłowa tablica NumPy.
  • `oś` to parametr opcjonalny. W przypadku tablic wielowymiarowych, parametr `oś` pozwala na znalezienie indeksu maksimum wzdłuż konkretnej osi.
  • `out` to kolejny parametr opcjonalny. Można przypisać do niego tablicę NumPy, w której zostaną zapisane wyniki działania funkcji `argmax()`.

Warto wspomnieć: Od wersji NumPy 1.22.0 dostępny jest dodatkowy parametr `keepdims`. Gdy określimy parametr `oś` w wywołaniu `argmax()`, tablica jest redukowana wzdłuż tej osi. Ustawienie `keepdims` na `True` sprawia, że zwracana tablica zachowuje taki sam kształt jak tablica wejściowa.

Wykorzystanie `argmax()` do odnalezienia indeksu elementu o największej wartości

#1. Spróbujmy użyć funkcji `argmax()` do zlokalizowania indeksu maksymalnego elementu w tablicy `array_1`.

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# Wyjście
4

Funkcja `argmax()` zwraca 4, co jest prawidłowym wynikiem!

#2. Jeżeli zdefiniujemy tablicę `array_1` tak, że wartość 10 występuje dwukrotnie, funkcja `argmax()` wskaże jedynie indeks pierwszego z tych wystąpień.

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# Wyjście
4

W dalszych przykładach będziemy pracować z tablicą `array_1` zdefiniowaną w przykładzie pierwszym.

Użycie `argmax()` w tablicy 2D

Zmieńmy strukturę tablicy `array_1` na tablicę dwuwymiarową, składającą się z dwóch wierszy i czterech kolumn.

array_2 = array_1.reshape(2,4)
print(array_2)

# Wyjście
[[ 1  5  7  2]
 [10  9  8  4]]

W tablicy dwuwymiarowej, oś 0 odnosi się do wierszy, natomiast oś 1 do kolumn. Indeksowanie w tablicach NumPy startuje od zera. Zatem indeksy wierszy i kolumn dla `array_2` przedstawiają się następująco:

Wywołajmy teraz funkcję `argmax()` na tablicy dwuwymiarowej `array_2`.

print(np.argmax(array_2))

# Wyjście
4

Mimo że `argmax()` została zastosowana do tablicy dwuwymiarowej, wciąż zwraca ona wartość 4. Jest to ten sam wynik, jaki otrzymaliśmy dla tablicy jednowymiarowej `array_1`.

Dlaczego tak się dzieje?

Dzieje się tak dlatego, że nie zdefiniowaliśmy wartości parametru `oś`. W takim przypadku, funkcja `argmax()` domyślnie zwraca indeks maksymalnego elementu wzdłuż spłaszczonej tablicy.

Czym jest tablica spłaszczona? Jeżeli mamy tablicę N-wymiarową o kształcie `d1 x d2 x … x dN`, gdzie `d1, d2, … dN` to wymiary tablicy, to spłaszczona tablica jest długą jednowymiarową tablicą o rozmiarze `d1 * d2 * … * dN`.

Aby zobaczyć, jak wygląda spłaszczona tablica dla `array_2`, można wywołać metodę `flatten()`, jak pokazano poniżej:

array_2.flatten()

# Wyjście
array([ 1,  5,  7,  2, 10,  9,  8,  4])

Indeks maksymalnego elementu wzdłuż wierszy (oś = 0)

Przejdźmy do poszukiwania indeksu maksymalnego elementu wzdłuż wierszy (oś = 0).

np.argmax(array_2,axis=0)

# Wyjście
array([1, 1, 1, 1])

Wynik ten może być na pierwszy rzut oka niezrozumiały, ale wyjaśnimy, jak on działa.

Ustawiliśmy parametr `oś` na zero (oś=0), ponieważ chcemy znaleźć indeks maksymalnego elementu wzdłuż wierszy. W konsekwencji, funkcja `argmax()` zwraca numer wiersza, w którym znajduje się element o największej wartości – dla każdej z kolumn.

Aby to lepiej zrozumieć, posłużmy się wizualizacją.

Z powyższego schematu i wyniku `argmax()` wynika:

  • W pierwszej kolumnie (o indeksie 0), maksymalna wartość (10) leży w drugim wierszu (indeks = 1).
  • W drugiej kolumnie (o indeksie 1), maksymalna wartość (9) leży w drugim wierszu (indeks = 1).
  • W trzeciej i czwartej kolumnie (o indeksach 2 i 3), maksymalne wartości (8 i 4) również leżą w drugim wierszu (indeks = 1).

Właśnie dlatego wynikiem jest tablica `[1, 1, 1, 1]`, ponieważ maksymalny element wzdłuż wierszy znajduje się w drugim wierszu (dla wszystkich kolumn).

Indeks maksymalnego elementu wzdłuż kolumn (oś = 1)

Teraz użyjemy funkcji `argmax()` do znalezienia indeksu maksymalnego elementu wzdłuż kolumn.

Uruchom poniższy kod i przeanalizuj wynik.

np.argmax(array_2,axis=1)
array([2, 0])

Czy potrafisz zinterpretować ten rezultat?

Ustawiliśmy `oś=1`, aby obliczyć indeks maksymalnego elementu wzdłuż kolumn.

Funkcja `argmax()` dla każdego wiersza zwraca numer kolumny, w której występuje element o największej wartości.

Oto wizualne wyjaśnienie:

Z powyższego schematu i rezultatu `argmax()` wynika:

  • W pierwszym wierszu (o indeksie 0) maksymalna wartość (7) leży w trzeciej kolumnie (indeks = 2).
  • W drugim wierszu (o indeksie 1) maksymalna wartość (10) leży w pierwszej kolumnie (indeks = 0).

Mam nadzieję, że teraz rozumiesz, co oznacza wynik `array([2, 0])`.

Wykorzystanie opcjonalnego parametru `out` w `argmax()`

Możesz użyć opcjonalnego parametru `out` w funkcji `argmax()` z NumPy, aby zapisać wynik w tablicy NumPy.

Zainicjujmy tablicę zer, aby pomieścić wynik poprzedniego wywołania `argmax()` – czyli indeksy maksymalnych wartości wzdłuż kolumn (oś = 1).

out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]

Powróćmy teraz do przykładu odnajdywania indeksu maksymalnej wartości wzdłuż kolumn (oś = 1) i ustawmy `out` na `out_arr`, którą przed chwilą zdefiniowaliśmy.

np.argmax(array_2,axis=1,out=out_arr)

Jak widać, interpreter Pythona zgłasza błąd `TypeError`, ponieważ `out_arr` został domyślnie zainicjowany jako tablica elementów zmiennoprzecinkowych.

TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except TypeError:

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

Zatem, przy definiowaniu parametru `out` w postaci tablicy, należy upewnić się, że tablica wyjściowa ma odpowiedni kształt i typ danych. Ponieważ indeksy w tablicy zawsze są liczbami całkowitymi, powinniśmy przy definicji tablicy wyjściowej ustawić parametr `dtype` na `int`.

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# Wyjście
[0 0]

Teraz możemy bez obaw wywołać funkcję `argmax()` z parametrami `oś` i `out` i tym razem wszystko zadziała prawidłowo.

np.argmax(array_2,axis=1,out=out_arr)

Wynik działania `argmax()` jest teraz dostępny w tablicy `out_arr`.

print(out_arr)
# Wyjście
[2 0]

Podsumowanie

Mam nadzieję, że ten poradnik pomógł Ci zrozumieć, jak korzystać z funkcji `argmax()` biblioteki NumPy. Przykłady kodu możesz uruchomić w notesie Jupyter.

Spójrzmy jeszcze raz na to, czego się nauczyliśmy.

  • Funkcja `argmax()` z NumPy zwraca indeks elementu o największej wartości w tablicy. Jeżeli element maksymalny występuje więcej niż raz w tablicy `a`, `np.argmax(a)` zwróci indeks pierwszego wystąpienia tego elementu.
  • Pracując z tablicami wielowymiarowymi, możemy użyć parametru opcjonalnego `oś` do uzyskania indeksu elementu maksymalnego wzdłuż danej osi. Przykładowo, w tablicy dwuwymiarowej, ustawienie `oś = 0` i `oś = 1` pozwala na uzyskanie indeksu maksymalnego elementu odpowiednio wzdłuż wierszy i kolumn.
  • Jeżeli chcemy przechować zwróconą wartość w innej tablicy, możemy użyć parametru opcjonalnego `out` przypisując mu tablicę wyjściową. Pamiętać jednak należy, że tablica wyjściowa powinna mieć właściwy kształt i typ danych.

W następnej kolejności warto zapoznać się ze szczegółowym przewodnikiem po zbiorach w Pythonie.