मेरा कोड यहां बहुत खराब प्रदर्शन कर रहा है। स्लाइडर पर चीजें बदलते समय मुझे मुश्किल से 10 एफपीएस से अधिक मिलता है। माना कि मैं matplotlib से बहुत अच्छी तरह वाकिफ नहीं हूं, लेकिन क्या कोई बता सकता है कि मैं क्या गलत कर रहा हूं और इसे कैसे ठीक किया जाए?

नोट: मैं सबसे खराब स्थिति में बहुत सारे डेटा को संभाल रहा हूं, लगभग 3 * 100000 अंक ... यह भी सुनिश्चित नहीं है कि इसकी आवश्यकता है लेकिन मैं 'TkAgg' बैकएंड पर चल रहा हूं।

यहाँ मेरा कोड है (यह एक SIR महामारी विज्ञान गणितीय मॉडल की साजिश रचने और चलाने के लिए एक कोड है):

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import matplotlib.patches as patches

p = 1                                                       #population
i = 0.01*p                                                  #infected
s = p-i                                                     #susceptible
r = 0                                                       #recovered/removed

a = 3.2                                                     #transmission parameter
b = 0.23                                                    #recovery parameter

initialTime = 0
deltaTime = 0.001                                           #smaller the delta, better the approximation to a real derivative
maxTime = 10000                                             #more number of points, better is the curve generated

def sPrime(oldS, oldI, transmissionRate):                   #differential equations being expressed as functions to
    return -1*((transmissionRate*oldS*oldI)/p)              #calculate rate of change between time intervals of the
                                                            #different quantities i.e susceptible, infected and recovered/removed
def iPrime(oldS, oldI, transmissionRate, recoveryRate):             
    return (((transmissionRate*oldS)/p)-recoveryRate)*oldI

def rPrime(oldI, recoveryRate):
    return recoveryRate*oldI

maxTimeInitial = maxTime

def genData(transRate, recovRate, maxT):
    global a, b, maxTimeInitial
    a = transRate
    b = recovRate
    maxTimeInitial = maxT

    sInitial = s
    iInitial = i
    rInitial = r

    time = []
    sVals = []
    iVals = []
    rVals = []

    for t in range(initialTime, maxTimeInitial+1):              #generating the data through a loop
        time.append(t)
        sVals.append(sInitial)
        iVals.append(iInitial)
        rVals.append(rInitial)

        newDeltas = (sPrime(sInitial, iInitial, transmissionRate=a), iPrime(sInitial, iInitial, transmissionRate=a, recoveryRate=b), rPrime(iInitial, recoveryRate=b))
        sInitial += newDeltas[0]*deltaTime
        iInitial += newDeltas[1]*deltaTime
        rInitial += newDeltas[2]*deltaTime

        if sInitial < 0 or iInitial < 0 or rInitial < 0:        #as soon as any of these value become negative, the data generated becomes invalid
            break                                               #according to the SIR model, we assume all values of S, I and R are always positive.

    return (time, sVals, iVals, rVals)

fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.4, top=0.94)

plt.title('SIR epidemiology curves for a disease')

plt.xlim(0, maxTime+1)
plt.ylim(0, p*1.4)

plt.xlabel('Time (t)')
plt.ylabel('Population (p)')

initialData = genData(a, b, maxTimeInitial)

susceptible, = ax.plot(initialData[0], initialData[1], label='Susceptible', color='b')
infected, = ax.plot(initialData[0], initialData[2], label='Infected', color='r')
recovered, = ax.plot(initialData[0], initialData[3], label='Recovered/Removed', color='g')

plt.legend()

transmissionAxes = plt.axes([0.125, 0.25, 0.775, 0.03], facecolor='white')
recoveryAxes = plt.axes([0.125, 0.2, 0.775, 0.03], facecolor='white')
timeAxes = plt.axes([0.125, 0.15, 0.775, 0.03], facecolor='white')

transmissionSlider = Slider(transmissionAxes, 'Transmission parameter', 0, 10, valinit=a, valstep=0.01)
recoverySlider = Slider(recoveryAxes, 'Recovery parameter', 0, 10, valinit=b, valstep=0.01)
timeSlider = Slider(timeAxes, 'Max time', 0, 100000, valinit=maxTime, valstep=1, valfmt="%i")

def updateTransmission(newVal):
    newData = genData(newVal, b, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateRecovery(newVal):
    newData = genData(a, newVal, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_O$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateMaxTime(newVal):
    global susceptible, infected, recovered

    newData = genData(a, b, int(newVal.item()))

    del ax.lines[:3]

    susceptible, = ax.plot(newData[0], newData[1], label='Susceptible', color='b')
    infected, = ax.plot(newData[0], newData[2], label='Infected', color='r')
    recovered, = ax.plot(newData[0], newData[3], label='Recovered/Removed', color='g')

transmissionSlider.on_changed(updateTransmission)
recoverySlider.on_changed(updateRecovery)
timeSlider.on_changed(updateMaxTime)

resetAxes = plt.axes([0.8, 0.025, 0.1, 0.05])
resetButton = Button(resetAxes, 'Reset', color='white')

r_o = plt.text(0.1, 1.5, r'$R_O$={:.2f}'.format(a/b), fontsize=12)

def reset(event):
    transmissionSlider.reset()
    recoverySlider.reset()
    timeSlider.reset()

resetButton.on_clicked(reset)

plt.show()

0
Prithvidiamond 30 मार्च 2020, 10:39

1 उत्तर

सबसे बढ़िया उत्तर

गति के लिए scipy.integrate.odeint जैसे उचित ODE सॉल्वर का उपयोग करें। फिर आप आउटपुट के लिए बड़े समय के चरणों का उपयोग कर सकते हैं। एक अंतर्निहित सॉल्वर जैसे odeint या solve_ivp method="Radau" के साथ, समन्वय विमान जो सटीक समाधान में सीमाएं हैं, संख्यात्मक समाधान में भी सीमाएं होंगी, ताकि मान कभी भी नकारात्मक न हों।

प्लॉट छवि के वास्तविक रिज़ॉल्यूशन से मिलान करने के लिए प्लॉट किए गए डेटा सेट को कम करें। ३०० अंक से १००० अंक का अंतर अभी भी दिखाई दे सकता है, १००० अंक से ५००० अंक तक कोई दृश्य अंतर नहीं होगा, शायद वास्तविक अंतर भी नहीं।

Matplotlib धीमी गति से अजगर पुनरावृत्ति का उपयोग करके, एक दृश्य पेड़ के माध्यम से वस्तुओं के रूप में अपनी छवियों को खींचता है। यदि ड्रॉ करने के लिए कुछ 10000 से अधिक ऑब्जेक्ट हैं, तो यह बहुत धीमा हो जाता है, इसलिए विवरणों की संख्या को इस संख्या तक सीमित करना सबसे अच्छा है।

ODE सॉल्वर के लिए कोड

ODE को हल करने के लिए मैंने Solve_ivp का उपयोग किया, लेकिन अगर odeint का उपयोग किया जाता है तो इससे कोई फर्क नहीं पड़ता,

def SIR_prime(t,SIR,trans, recov): # solver expects t argument, even if not used
    S,I,R = SIR
    dS = (-trans*I/p) * S 
    dI = (trans*S/p-recov) * I
    dR = recov*I
    return [dS, dI, dR]

def genData(transRate, recovRate, maxT):
    SIR = solve_ivp(SIR_prime, [0,maxT], [s,i,r], args=(transRate, recovRate), method="Radau", dense_output=True)
    time = np.linspace(0,SIR.t[-1],1001)
    sVals, iVals, rVals = SIR.sol(time)
    return (time, sVals, iVals, rVals)

प्लॉट अद्यतन प्रक्रियाओं के लिए सुव्यवस्थित कोड

कोई बहुत से डुप्लिकेट कोड को हटा सकता है। मैंने एक लाइन भी जोड़ी है ताकि टाइम एक्सिस मैक्सटाइम वेरिएबल के साथ बदल जाए, ताकि कोई वास्तव में ज़ूम इन कर सके

def updateTransmission(newVal):
    global trans_rate
    trans_rate = newVal
    updatePlot()

def updateRecovery(newVal):
    global recov_rate
    recov_rate = newVal
    updatePlot()

def updateMaxTime(newVal):
    global maxTime
    maxTime = newVal
    updatePlot()

def updatePlot():
    newData = genData(trans_rate, recov_rate, maxTime)

    susceptible.set_data(newData[0],newData[1])
    infected.set_data(newData[0],newData[2])
    recovered.set_data(newData[0],newData[3])

    ax.set_xlim(0, maxTime+1)

    r_o.set_text(r'$R_O$={:.2f}'.format(trans_rate/recov_rate))

    fig.canvas.draw_idle()

बीच और आसपास का कोड वही रहता है।

1
Lutz Lehmann 2 अप्रैल 2020, 08:56