W tym artykule poznasz praktyczne zastosowanie funkcji `argmax()` z biblioteki NumPy, która pozwala na odnalezienie indeksu elementu o największej wartości w tablicach.
NumPy to niezwykle przydatna biblioteka Pythona, szczególnie w dziedzinie obliczeń naukowych. Udostępnia ona struktury danych w postaci tablic N-wymiarowych, oferujące znacznie większą efektywność niż standardowe listy Pythona. Często podczas pracy z tablicami NumPy pojawia się potrzeba wyznaczenia największej wartości. Niekiedy jednak, bardziej interesujące jest znalezienie indeksu, pod którym ta maksymalna wartość się znajduje.
Funkcja `argmax()` umożliwia szybkie i proste określenie indeksu maksymalnego elementu zarówno w tablicach jednowymiarowych, jak i wielowymiarowych. Zobaczmy, jak to działa w praktyce.
Jak zlokalizować indeks największej wartości w tablicy NumPy?
Aby w pełni skorzystać z tego poradnika, powinieneś mieć zainstalowane środowisko Python oraz bibliotekę NumPy. Kod można uruchamiać w interaktywnej konsoli Pythona (REPL) lub w notatniku Jupyter.
Na początku zaimportujmy bibliotekę NumPy, używając standardowego aliasu `np`.
import numpy as np
Do wyznaczenia maksymalnej wartości w tablicy (opcjonalnie wzdłuż konkretnej osi) możesz wykorzystać funkcję `np.max()`.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Wyjście 10
W powyższym przykładzie, `np.max(array_1)` zwraca wartość 10, co jest prawidłowym wynikiem.
Załóżmy, że teraz chcesz poznać indeks, pod którym w tablicy występuje ta maksymalna wartość. Możesz do tego podejść dwuetapowo:
- Wyszukaj maksymalny element.
- Znajdź indeks tego elementu.
W tablicy `array_1`, wartość maksymalna (10) znajduje się na indeksie 4 (pamiętając o indeksowaniu od zera). Pierwszy element ma indeks 0, drugi 1, i tak dalej.
Do odnalezienia indeksu, na którym znajduje się wartość maksymalna, można wykorzystać funkcję `np.where()`. `np.where(warunek)` zwraca tablicę indeksów, dla których warunek jest prawdziwy.
Należy odwołać się do tablicy i pobrać element z pierwszego indeksu. Aby dowiedzieć się, gdzie leży maksymalna wartość, ustawiamy warunek `array_1==10`, pamiętając, że 10 jest maksymalną wartością w `array_1`.
print(int(np.where(array_1==10)[0])) # Wyjście 4
Chociaż wykorzystaliśmy `np.where()` tylko z warunkiem, nie jest to zalecane podejście do korzystania z tej funkcji.
📑 **Warto wiedzieć**: Funkcja NumPy `where()`:
`np.where(warunek, x, y)` zwraca:
- elementy z `x`, jeśli warunek jest spełniony, oraz
- elementy z `y`, jeśli warunek jest fałszywy.
Zatem, łącząc funkcje `np.max()` i `np.where()` jesteśmy w stanie odnaleźć element o największej wartości, a następnie indeks, pod którym się on znajduje.
Jednak zamiast takiego dwuetapowego postępowania, możemy użyć funkcji `argmax()` z NumPy, która bezpośrednio zwraca indeks maksymalnej wartości w tablicy.
Składnia funkcji NumPy `argmax()`
Ogólna struktura składni funkcji `argmax()` w NumPy wygląda następująco:
np.argmax(tablica, oś, out) # NumPy zaimportowaliśmy jako np
W powyższej składni:
- `tablica` to dowolna prawidłowa tablica NumPy.
- `oś` to parametr opcjonalny. W przypadku tablic wielowymiarowych, parametr `oś` pozwala na znalezienie indeksu maksimum wzdłuż konkretnej osi.
- `out` to kolejny parametr opcjonalny. Można przypisać do niego tablicę NumPy, w której zostaną zapisane wyniki działania funkcji `argmax()`.
Warto wspomnieć: Od wersji NumPy 1.22.0 dostępny jest dodatkowy parametr `keepdims`. Gdy określimy parametr `oś` w wywołaniu `argmax()`, tablica jest redukowana wzdłuż tej osi. Ustawienie `keepdims` na `True` sprawia, że zwracana tablica zachowuje taki sam kształt jak tablica wejściowa.
Wykorzystanie `argmax()` do odnalezienia indeksu elementu o największej wartości
#1. Spróbujmy użyć funkcji `argmax()` do zlokalizowania indeksu maksymalnego elementu w tablicy `array_1`.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Wyjście 4
Funkcja `argmax()` zwraca 4, co jest prawidłowym wynikiem!
#2. Jeżeli zdefiniujemy tablicę `array_1` tak, że wartość 10 występuje dwukrotnie, funkcja `argmax()` wskaże jedynie indeks pierwszego z tych wystąpień.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Wyjście 4
W dalszych przykładach będziemy pracować z tablicą `array_1` zdefiniowaną w przykładzie pierwszym.
Użycie `argmax()` w tablicy 2D
Zmieńmy strukturę tablicy `array_1` na tablicę dwuwymiarową, składającą się z dwóch wierszy i czterech kolumn.
array_2 = array_1.reshape(2,4) print(array_2) # Wyjście [[ 1 5 7 2] [10 9 8 4]]
W tablicy dwuwymiarowej, oś 0 odnosi się do wierszy, natomiast oś 1 do kolumn. Indeksowanie w tablicach NumPy startuje od zera. Zatem indeksy wierszy i kolumn dla `array_2` przedstawiają się następująco:
Wywołajmy teraz funkcję `argmax()` na tablicy dwuwymiarowej `array_2`.
print(np.argmax(array_2)) # Wyjście 4
Mimo że `argmax()` została zastosowana do tablicy dwuwymiarowej, wciąż zwraca ona wartość 4. Jest to ten sam wynik, jaki otrzymaliśmy dla tablicy jednowymiarowej `array_1`.
Dlaczego tak się dzieje?
Dzieje się tak dlatego, że nie zdefiniowaliśmy wartości parametru `oś`. W takim przypadku, funkcja `argmax()` domyślnie zwraca indeks maksymalnego elementu wzdłuż spłaszczonej tablicy.
Czym jest tablica spłaszczona? Jeżeli mamy tablicę N-wymiarową o kształcie `d1 x d2 x … x dN`, gdzie `d1, d2, … dN` to wymiary tablicy, to spłaszczona tablica jest długą jednowymiarową tablicą o rozmiarze `d1 * d2 * … * dN`.
Aby zobaczyć, jak wygląda spłaszczona tablica dla `array_2`, można wywołać metodę `flatten()`, jak pokazano poniżej:
array_2.flatten() # Wyjście array([ 1, 5, 7, 2, 10, 9, 8, 4])
Indeks maksymalnego elementu wzdłuż wierszy (oś = 0)
Przejdźmy do poszukiwania indeksu maksymalnego elementu wzdłuż wierszy (oś = 0).
np.argmax(array_2,axis=0) # Wyjście array([1, 1, 1, 1])
Wynik ten może być na pierwszy rzut oka niezrozumiały, ale wyjaśnimy, jak on działa.
Ustawiliśmy parametr `oś` na zero (oś=0), ponieważ chcemy znaleźć indeks maksymalnego elementu wzdłuż wierszy. W konsekwencji, funkcja `argmax()` zwraca numer wiersza, w którym znajduje się element o największej wartości – dla każdej z kolumn.
Aby to lepiej zrozumieć, posłużmy się wizualizacją.
Z powyższego schematu i wyniku `argmax()` wynika:
- W pierwszej kolumnie (o indeksie 0), maksymalna wartość (10) leży w drugim wierszu (indeks = 1).
- W drugiej kolumnie (o indeksie 1), maksymalna wartość (9) leży w drugim wierszu (indeks = 1).
- W trzeciej i czwartej kolumnie (o indeksach 2 i 3), maksymalne wartości (8 i 4) również leżą w drugim wierszu (indeks = 1).
Właśnie dlatego wynikiem jest tablica `[1, 1, 1, 1]`, ponieważ maksymalny element wzdłuż wierszy znajduje się w drugim wierszu (dla wszystkich kolumn).
Indeks maksymalnego elementu wzdłuż kolumn (oś = 1)
Teraz użyjemy funkcji `argmax()` do znalezienia indeksu maksymalnego elementu wzdłuż kolumn.
Uruchom poniższy kod i przeanalizuj wynik.
np.argmax(array_2,axis=1)
array([2, 0])
Czy potrafisz zinterpretować ten rezultat?
Ustawiliśmy `oś=1`, aby obliczyć indeks maksymalnego elementu wzdłuż kolumn.
Funkcja `argmax()` dla każdego wiersza zwraca numer kolumny, w której występuje element o największej wartości.
Oto wizualne wyjaśnienie:
Z powyższego schematu i rezultatu `argmax()` wynika:
- W pierwszym wierszu (o indeksie 0) maksymalna wartość (7) leży w trzeciej kolumnie (indeks = 2).
- W drugim wierszu (o indeksie 1) maksymalna wartość (10) leży w pierwszej kolumnie (indeks = 0).
Mam nadzieję, że teraz rozumiesz, co oznacza wynik `array([2, 0])`.
Wykorzystanie opcjonalnego parametru `out` w `argmax()`
Możesz użyć opcjonalnego parametru `out` w funkcji `argmax()` z NumPy, aby zapisać wynik w tablicy NumPy.
Zainicjujmy tablicę zer, aby pomieścić wynik poprzedniego wywołania `argmax()` – czyli indeksy maksymalnych wartości wzdłuż kolumn (oś = 1).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Powróćmy teraz do przykładu odnajdywania indeksu maksymalnej wartości wzdłuż kolumn (oś = 1) i ustawmy `out` na `out_arr`, którą przed chwilą zdefiniowaliśmy.
np.argmax(array_2,axis=1,out=out_arr)
Jak widać, interpreter Pythona zgłasza błąd `TypeError`, ponieważ `out_arr` został domyślnie zainicjowany jako tablica elementów zmiennoprzecinkowych.
TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Zatem, przy definiowaniu parametru `out` w postaci tablicy, należy upewnić się, że tablica wyjściowa ma odpowiedni kształt i typ danych. Ponieważ indeksy w tablicy zawsze są liczbami całkowitymi, powinniśmy przy definicji tablicy wyjściowej ustawić parametr `dtype` na `int`.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Wyjście [0 0]
Teraz możemy bez obaw wywołać funkcję `argmax()` z parametrami `oś` i `out` i tym razem wszystko zadziała prawidłowo.
np.argmax(array_2,axis=1,out=out_arr)
Wynik działania `argmax()` jest teraz dostępny w tablicy `out_arr`.
print(out_arr) # Wyjście [2 0]
Podsumowanie
Mam nadzieję, że ten poradnik pomógł Ci zrozumieć, jak korzystać z funkcji `argmax()` biblioteki NumPy. Przykłady kodu możesz uruchomić w notesie Jupyter.
Spójrzmy jeszcze raz na to, czego się nauczyliśmy.
- Funkcja `argmax()` z NumPy zwraca indeks elementu o największej wartości w tablicy. Jeżeli element maksymalny występuje więcej niż raz w tablicy `a`, `np.argmax(a)` zwróci indeks pierwszego wystąpienia tego elementu.
- Pracując z tablicami wielowymiarowymi, możemy użyć parametru opcjonalnego `oś` do uzyskania indeksu elementu maksymalnego wzdłuż danej osi. Przykładowo, w tablicy dwuwymiarowej, ustawienie `oś = 0` i `oś = 1` pozwala na uzyskanie indeksu maksymalnego elementu odpowiednio wzdłuż wierszy i kolumn.
- Jeżeli chcemy przechować zwróconą wartość w innej tablicy, możemy użyć parametru opcjonalnego `out` przypisując mu tablicę wyjściową. Pamiętać jednak należy, że tablica wyjściowa powinna mieć właściwy kształt i typ danych.
W następnej kolejności warto zapoznać się ze szczegółowym przewodnikiem po zbiorach w Pythonie.