Plotting multiple graphs with different constant values

I've got a function of one variable (r) and two constants (R and γ). Each constant can take three values, and I would like to make 9 different plots, one for each combination of the two.

The function is given by:

f(r) = 1 + γ(r/R)^(γ+1) - (γ+1)(r/R)^γ   for r ≤ R
f(r) = 0                                 for r > R

Here's what I've come up with so far:

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
plt.rc('font', size=14)
fig, axs = plt.subplots(3, 3, figsize=(15,15))

# Defining function
def f(r, R, γ):
    if r <= R:
        return 1+γ*(r/R)**(γ+1) - (γ+1)*(r/R)**γ
    else:
        return 0

r = np.linspace(0, 10, 100)
γs = [1, 2, 3]
Rs = [2, 4, 6]

for i in range(9):
    for γ in range(3):
        for R in range(3):
            axs[i].plot(r, f(r, Rs[R], γs[γ]))

However, I get this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-faa7cd86fd6b> in <module>
     19     for γ in range(3):
     20         for R in range(3):
---> 21             axs[i].plot(r, f(r, Rs[R], γs[γ]))

AttributeError: 'numpy.ndarray' object has no attribute 'plot'

How do I make Python display 9 plots with f(r, 2, 1), f(r, 4, 1), f(r, 6, 1), f(r, 2, 2), ...?



Solution 1 [1]

In your example code, axs is a 3×3 NumPy array of Axes objects, so axs[0] is itself a NumPy array (the first row, holding three Axes objects), and NumPy arrays have no plot method. Since you only need the variable i to address the Axes objects, you can drop it and iterate directly over the lists and the flattened axes array instead of keeping track of three indices:
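A quick way to confirm this is to inspect what plt.subplots(3, 3) returns. This is a minimal sketch that assumes the non-interactive Agg backend so it runs outside a notebook; the check itself only uses standard NumPy and Matplotlib attributes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes

fig, axs = plt.subplots(3, 3)

# plt.subplots(3, 3) returns a 2-D NumPy array of Axes objects.
print(type(axs).__name__, axs.shape)  # ndarray (3, 3)

# A single integer index selects a whole row, which is still an ndarray
# and has no .plot method: that is the AttributeError in the question.
row = axs[0]
print(isinstance(row, np.ndarray), hasattr(row, "plot"))  # True False

# Two indices, or iterating over axs.flat, yield individual Axes objects.
print(isinstance(axs[0, 0], Axes))                   # True
print(sum(isinstance(ax, Axes) for ax in axs.flat))  # 9
```

Because the loop in the question also runs i up to 8, axs[i] would additionally raise an IndexError for any i >= 3, since the first axis of the array only has length 3.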

from itertools import product as pr
from matplotlib import pyplot as plt

fig, axes = plt.subplots(3, 2, figsize=(8, 12))
# (row, col) pairs, generated in the same row-major order as axes.flat
positions = pr(range(axes.shape[0]), range(axes.shape[1]))
for ax, (row, col) in zip(axes.flat, positions):
    ax.set_title(f"Row: {row}, Column: {col}")

plt.show()

Sample output: [figure: a 3×2 grid of empty subplots, each titled with its row and column index]
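The zip over axes.flat pairs each subplot with the right (row, col) because itertools.product varies its last argument fastest, which is the same row-major order in which .flat walks a NumPy array. A small, plot-free check of that ordering, where the 2×3 integer grid is only an illustrative stand-in for the axes array:

```python
from itertools import product
import numpy as np

# A 2x3 array whose entries equal their own flat positions 0..5.
grid = np.arange(6).reshape(2, 3)

# .flat walks the array in row-major (C) order.
flat_values = [int(v) for v in grid.flat]

# product() varies its last argument fastest, i.e. also row-major order,
# so the k-th (row, col) pair addresses the k-th element of grid.flat.
pairs = list(product(range(grid.shape[0]), range(grid.shape[1])))
print(pairs)  # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
print(all(grid[r, c] == v for (r, c), v in zip(pairs, flat_values)))  # True

# With the question's constants, the first row of panels shares γ = 1.
γs = [1, 2, 3]
Rs = [2, 4, 6]
print(list(product(γs, Rs))[:4])  # [(1, 2), (1, 4), (1, 6), (2, 2)]
```

This is exactly the order asked for in the question: f(r, 2, 1), f(r, 4, 1), f(r, 6, 1), f(r, 2, 2), and so on.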

This uses itertools.product to generate all combinations in the same row-major order in which axes.flat visits the subplots. Applied to your code, the same approach becomes:

fig, axes = plt.subplots(3, 3, figsize=(8, 12))
...
# rows: γ = 1, 2, 3; columns: R = 2, 4, 6
combos = pr(γs, Rs)
for ax, (γ, R) in zip(axes.flat, combos):
    ax.plot(r, f(r, R, γ))
    ax.set_title(f"R = {R}, γ = {γ}")
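One more thing will still break once the axes are fixed: the question's f uses `if r <= R`, and when r is a NumPy array that comparison raises "ValueError: The truth value of an array with more than one element is ambiguous". Below is a minimal, self-contained sketch of the whole thing, assuming a vectorized f written with np.where (a substitution made here, not part of the original answer) and the Agg backend so it runs without a display; in a notebook, drop the matplotlib.use line.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
from itertools import product

def f(r, R, γ):
    """Vectorized version of the question's piecewise function.

    np.where replaces the scalar `if r <= R` test: both branches are
    computed for every r, then selected elementwise.
    """
    x = r / R
    inside = 1 + γ * x**(γ + 1) - (γ + 1) * x**γ
    return np.where(r <= R, inside, 0.0)

r = np.linspace(0, 10, 100)
γs = [1, 2, 3]
Rs = [2, 4, 6]

fig, axes = plt.subplots(3, 3, figsize=(12, 12), sharex=True, sharey=True)
# Panels in order f(r, 2, 1), f(r, 4, 1), f(r, 6, 1), f(r, 2, 2), ...
for ax, (γ, R) in zip(axes.flat, product(γs, Rs)):
    ax.plot(r, f(r, R, γ))
    ax.set_title(f"R = {R}, γ = {γ}")
    ax.set_xlabel("r")

fig.tight_layout()
plt.show()
```

An alternative with the same effect is to keep the scalar f and wrap it with np.vectorize, but np.where avoids a Python-level loop over r.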

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1