Cómo usar la función NumPy argmax() en Python

En este tutorial, aprenderá a usar la función NumPy argmax() para encontrar el índice del elemento máximo en matrices.

NumPy es una poderosa biblioteca para computación científica en Python; proporciona matrices N-dimensionales que tienen un mejor rendimiento que las listas de Python. Una de las operaciones comunes que realizará cuando trabaje con matrices NumPy es encontrar el valor máximo en la matriz. Sin embargo, a veces es posible que desee encontrar el índice en el que se produce el valor máximo.

La función argmax() lo ayuda a encontrar el índice del máximo en matrices tanto unidimensionales como multidimensionales. Procedamos a aprender cómo funciona.

Cómo encontrar el índice del elemento máximo en una matriz NumPy

Para seguir este tutorial, debe tener Python y NumPy instalados. Puede codificar iniciando un REPL de Python o iniciando un cuaderno Jupyter.

Primero, importemos NumPy bajo el alias habitual np.

import numpy as np

Puede usar la función NumPy max() para obtener el valor máximo en una matriz (opcionalmente a lo largo de un eje específico).

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Output
10

En este caso, np.max(array_1) devuelve 10, lo cual es correcto.

Suponga que desea encontrar el índice en el que se produce el valor máximo en la matriz. Puede tomar el siguiente enfoque de dos pasos:

  • Encuentre el elemento máximo.
  • Encuentre el índice del elemento máximo.
  • En array_1, el valor máximo de 10 se produce en el índice 4, después de la indexación cero. El primer elemento está en el índice 0; el segundo elemento está en el índice 1, y así sucesivamente.

    Para encontrar el índice en el que se produce el máximo, puede utilizar la función NumPy where(). np.where(condición) devuelve una matriz de todos los índices donde la condición es verdadera.

    Tendrá que tocar la matriz y acceder al elemento en el primer índice. Para encontrar dónde ocurre el valor máximo, establecemos la condición en array_1==10; recuerda que 10 es el valor máximo en array_1.

    print(int(np.where(array_1==10)[0]))
    
    # Output
    4

    Hemos usado np.where() solo con la condición, pero este no es el método recomendado para usar esta función.

    📑 Nota: Función NumPy where():
    np.where(condición,x,y) devuelve:

    – Elementos de x cuando la condición es Verdadera, y
    – Elementos de y cuando la condición es Falsa.

    Por lo tanto, encadenando las funciones np.max() y np.where(), podemos encontrar el elemento máximo, seguido del índice en el que ocurre.

    En lugar del proceso de dos pasos anterior, puede usar la función NumPy argmax() para obtener el índice del elemento máximo en la matriz.

    Sintaxis de la función NumPy argmax()

    La sintaxis general para usar la función NumPy argmax() es la siguiente:

    np.argmax(array,axis,out)
    # we've imported numpy under the alias np

    En la sintaxis anterior:

    • matriz es cualquier matriz NumPy válida.
    • El eje es un parámetro opcional. Cuando trabaje con matrices multidimensionales, puede usar el parámetro de eje para encontrar el índice de máximo a lo largo de un eje específico.
    • out es otro parámetro opcional. Puede establecer el parámetro out en una matriz NumPy para almacenar la salida de la función argmax().

    Nota: A partir de la versión 1.22.0 de NumPy, hay un parámetro adicional de keepdims. Cuando especificamos el parámetro del eje en la llamada a la función argmax(), la matriz se reduce a lo largo de ese eje. Pero establecer el parámetro keepdims en True garantiza que la salida devuelta tenga la misma forma que la matriz de entrada.

    Usando NumPy argmax() para encontrar el índice del elemento máximo

    #1. Usemos la función NumPy argmax() para encontrar el índice del elemento máximo en array_1.

    array_1 = np.array([1,5,7,2,10,9,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    La función argmax() devuelve 4, ¡lo cual es correcto! ✅

    #2. Si redefinimos array_1 de modo que 10 aparezca dos veces, la función argmax() devuelve solo el índice de la primera aparición.

    array_1 = np.array([1,5,7,2,10,10,8,4])
    print(np.argmax(array_1))
    
    # Output
    4

    Para el resto de los ejemplos, usaremos los elementos de array_1 que definimos en el ejemplo #1.

    Usando NumPy argmax() para encontrar el índice del elemento máximo en una matriz 2D

    Reformemos la matriz NumPy array_1 en una matriz bidimensional con dos filas y cuatro columnas.

    array_2 = array_1.reshape(2,4)
    print(array_2)
    
    # Output
    [[ 1  5  7  2]
     [10  9  8  4]]

    Para una matriz bidimensional, el eje 0 denota las filas y el eje 1 denota las columnas. Las matrices NumPy siguen la indexación cero. Entonces, los índices de las filas y columnas de la matriz NumPy array_2 son los siguientes:

    Ahora, llamemos a la función argmax() en el arreglo bidimensional, arreglo_2.

    print(np.argmax(array_2))
    
    # Output
    4

    Aunque llamamos a argmax() en la matriz bidimensional, todavía devuelve 4. Esto es idéntico a la salida de la matriz unidimensional, array_1 de la sección anterior.

    ¿Por qué pasó esto? 🤔

    Esto se debe a que no hemos especificado ningún valor para el parámetro del eje. Cuando este parámetro de eje no está establecido, de forma predeterminada, la función argmax() devuelve el índice del elemento máximo a lo largo de la matriz plana.

    ¿Qué es una matriz plana? Si hay una matriz N-dimensional de forma d1 x d2 x … x dN, donde d1, d2, hasta dN son los tamaños de la matriz a lo largo de las N dimensiones, entonces la matriz aplanada es una matriz larga unidimensional de tamaño d1 * d2 * … * dN.

    Para verificar cómo se ve la matriz aplanada para array_2, puede llamar al método flatten(), como se muestra a continuación:

    array_2.flatten()
    
    # Output
    array([ 1,  5,  7,  2, 10,  9,  8,  4])

    Índice del elemento máximo a lo largo de las filas (eje = 0)

    Procedamos a encontrar el índice del elemento máximo a lo largo de las filas (eje = 0).

    np.argmax(array_2,axis=0)
    
    # Output
    array([1, 1, 1, 1])

    Esta salida puede ser un poco difícil de comprender, pero entenderemos cómo funciona.

    Hemos establecido el parámetro del eje en cero (eje = 0), ya que nos gustaría encontrar el índice del elemento máximo a lo largo de las filas. Por lo tanto, la función argmax() devuelve el número de fila en el que se encuentra el elemento máximo, para cada una de las tres columnas.

    Vamos a visualizar esto para una mejor comprensión.

    Del diagrama anterior y la salida argmax(), tenemos lo siguiente:

    • Para la primera columna en el índice 0, el valor máximo 10 ocurre en la segunda fila, en el índice = 1.
    • Para la segunda columna en el índice 1, el valor máximo 9 ocurre en la segunda fila, en el índice = 1.
    • Para las columnas tercera y cuarta en el índice 2 y 3, los valores máximos 8 y 4 ocurren en la segunda fila, en el índice = 1.

    Esta es precisamente la razón por la que tenemos la matriz de salida ([1, 1, 1, 1]) porque el elemento máximo a lo largo de las filas se encuentra en la segunda fila (para todas las columnas).

    Índice del elemento máximo a lo largo de las columnas (eje = 1)

    A continuación, usemos la función argmax() para encontrar el índice del elemento máximo a lo largo de las columnas.

    Ejecute el siguiente fragmento de código y observe el resultado.

    np.argmax(array_2,axis=1)
    array([2, 0])

    ¿Puedes analizar la salida?

    Hemos establecido axis = 1 para calcular el índice del elemento máximo a lo largo de las columnas.

    La función argmax() devuelve, para cada fila, el número de columna en el que se produce el valor máximo.

    Aquí hay una explicación visual:

    Del diagrama anterior y la salida argmax(), tenemos lo siguiente:

    • Para la primera fila en el índice 0, el valor máximo 7 ocurre en la tercera columna, en el índice = 2.
    • Para la segunda fila en el índice 1, el valor máximo 10 ocurre en la primera columna, en el índice = 0.

    Espero que ahora entiendas cuál es la salida, array([2, 0]) medio.

    Uso del parámetro de salida opcional en NumPy argmax()

    Puede usar el parámetro opcional out the en la función NumPy argmax() para almacenar la salida en una matriz NumPy.

    Inicialicemos una matriz de ceros para almacenar el resultado de la llamada de función argmax() anterior, para encontrar el índice del máximo a lo largo de las columnas (eje = 1).

    out_arr = np.zeros((2,))
    print(out_arr)
    [0. 0.]

    Ahora, revisemos el ejemplo de encontrar el índice del elemento máximo a lo largo de las columnas (eje = 1) y establecer la salida en out_arr que hemos definido anteriormente.

    np.argmax(array_2,axis=1,out=out_arr)

    Vemos que el intérprete de Python arroja un TypeError, ya que out_arr se inicializó en una matriz de flotantes de forma predeterminada.

    ---------------------------------------------------------------------------
    TypeError                                 Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
         56     try:
    ---> 57         return bound(*args, **kwds)
         58     except TypeError:
    
    TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

    Por lo tanto, al establecer el parámetro out en la matriz de salida, es importante asegurarse de que la matriz de salida tenga la forma y el tipo de datos correctos. Como los índices de la matriz son siempre números enteros, debemos establecer el parámetro dtype en int al definir la matriz de salida.

    out_arr = np.zeros((2,),dtype=int)
    print(out_arr)
    
    # Output
    [0 0]

    Ahora podemos seguir adelante y llamar a la función argmax() con los parámetros de eje y de salida, y esta vez se ejecuta sin errores.

    np.argmax(array_2,axis=1,out=out_arr)

    Ahora se puede acceder a la salida de la función argmax() en la matriz out_arr.

    print(out_arr)
    # Output
    [2 0]

    Conclusión

    Espero que este tutorial te haya ayudado a entender cómo usar la función NumPy argmax(). Puede ejecutar los ejemplos de código en un cuaderno Jupyter.

    Repasemos lo que hemos aprendido.

    • La función NumPy argmax() devuelve el índice del elemento máximo en una matriz. Si el elemento máximo aparece más de una vez en una matriz a, entonces np.argmax(a) devuelve el índice de la primera aparición del elemento.
    • Cuando trabaje con matrices multidimensionales, puede usar el parámetro de eje opcional para obtener el índice del elemento máximo a lo largo de un eje en particular. Por ejemplo, en una matriz bidimensional: al establecer axis = 0 y axis = 1, puede obtener el índice del elemento máximo a lo largo de las filas y columnas, respectivamente.
    • Si desea almacenar el valor devuelto en otra matriz, puede establecer el parámetro de salida opcional en la matriz de salida. Sin embargo, la matriz de salida debe tener una forma y un tipo de datos compatibles.

    A continuación, consulte la guía detallada sobre los conjuntos de Python.