import math
import numpy as np
import scipy.interpolate as ip

# Initial-final mass relation (from Lawlor and MacDonald 2006, MNRAS 371, 263)
# minit: initial stellar masses; mfin: corresponding final (remnant) masses,
# in Msun (mrem() treats the result as a remnant mass in Msun).
# minit is in ascending order and covers 0.8 <= m <= 10.
# fmfin is a linear scipy interp1d with default bounds handling, so it raises
# ValueError for masses outside [min(minit), max(minit)]; mrem() only calls it
# below max(minit).

minit = [0.80, 0.85, 0.90, 0.95, 1.000, 1.200, 1.500, 2.000, 2.500, 3.000, 4.000, 5.000, 6.500, 8.000, 10.0]
mfin  = [0.54, 0.54, 0.54, 0.54, 0.541, 0.557, 0.563, 0.595, 0.632, 0.678, 0.778, 0.874, 0.973, 1.080, 1.11]
fmfin = ip.interp1d(minit, mfin)

# Yields for Type II SNe, from Woosley and Weaver 1995, ApJ Suppl. 101, 181
# mww: initial masses of the tabulated models.
# m16O, mFe: ejected masses of 16O and Fe for each model, in the same units
# as mww (fq16O()/fqFe() divide them by the initial mass to get mass fractions).
# NOTE(review): which WW95 metallicity / explosion series these values come
# from is not stated here -- TODO confirm.

mww = [11.0, 12.0, 13.0, 15.0, 18.0, 19.0, 20.0, 22.0, 25.0, 30.0, 35.0, 40.0]
m16O = [1.36e-01, 2.10e-01, 2.72e-01, 6.80e-01, 1.13, 1.43, 1.94, 2.38, 3.25, 3.65, 3.07, 2.36]
mFe  = [0.0804, 0.0554, 0.1458, 0.1297, 0.0828, 0.1177, 0.1063, 0.2247, 0.1509, 0.318, 0.0918, 0.051]

# Linear interpolators for the yields; valid only on [min(mww), max(mww)].
f16O = ip.interp1d(mww,m16O)
fFe  = ip.interp1d(mww,mFe)

# IMF slope: power-law index of dN/dm used by IMF(); -2.35 is the
# Salpeter (1955) value.
alpha = -2.35

def calcsfr(mg):
    """Return the star formation rate for gas surface density mg.

    Simple Schmidt-Kennicutt law, from Kennicutt (1998):
        log10(SFR) = 1.4 * log10(mg) - 3.5,  then multiplied by 1e-6.

    mg is the gas surface density (in Msun pc^-2). The 1e-6 factor presumably
    converts Kennicutt's per-kpc^2 rate to a per-pc^2 rate, giving
    Msun yr^-1 pc^-2.
    NOTE(review): Kennicutt (1998) quotes a coefficient of 2.5e-4
    (log10 ~ -3.6); the -3.5 zero point is kept as is.

    A non-positive gas density (exhausted reservoir) gives an SFR of 0
    instead of a math domain error from log10.
    """
    if mg <= 0:
        return 0.
    logMg = math.log10(mg)
    logSFR = logMg*1.4 - 3.5
    return 1e-6 * 10**logSFR

def mrem(m):
    """Return the remnant mass (Msun) of a star of initial mass m.

    For min(minit) < m < max(minit), the Lawlor & MacDonald (2006)
    initial-final mass relation is linearly interpolated via fmfin.
    fmfin does not extrapolate, so masses at or below the lowest tabulated
    initial mass get the lowest tabulated final mass; the relation is flat
    there. Previously this case raised ValueError.
    For stars with m >= max(minit) we assume, arbitrarily, a remnant mass of
    2 Msun.

    An alternative is Iben & Tutukov (1984), quoted in Pagel (1997, p. 239):
    mrem = 0.11*m + 0.45 for m < 9.5, and 1.5 otherwise.
    """
    if m >= max(minit):
        return 2.0
    if m <= min(minit):
        # minit is tabulated in ascending order, so mfin[0] belongs to min(minit).
        return mfin[0]
    return float(fmfin(m))

def fq16O(m):
    """Return the mass fraction of 16O ejected by a star of initial mass m.

    The Woosley & Weaver (1995) 16O yield interpolated at m is divided by m.
    Only masses strictly inside (min(mww), max(mww)) contribute; anything
    else, including the two endpoints, yields 0.
    """
    in_table = min(mww) < m < max(mww)
    return f16O(m)/m if in_table else 0.

def fqFe(m):
    """Return the mass fraction of Fe ejected by a star of initial mass m.

    The Woosley & Weaver (1995) Fe yield interpolated at m is divided by m.
    Only masses strictly inside (min(mww), max(mww)) contribute; anything
    else, including the two endpoints, yields 0.
    """
    in_table = min(mww) < m < max(mww)
    return fFe(m)/m if in_table else 0.

def IMF(m, mmin=0.15, mmax=100., slope=None):
    """Evaluate a mass-normalized power-law (Salpeter 1955) IMF at mass m.

    Returns phi(m) proportional to m**slope, normalized so that the integral
    of m * phi(m) over [mmin, mmax] equals 1. phi(m) dm is therefore the
    number of stars born with mass in [m, m + dm] per unit mass of stars
    formed.

    Parameters
    ----------
    m : float or numpy.ndarray
        Initial stellar mass (Msun).
    mmin, mmax : float
        Mass limits used for the normalization. The defaults (0.15, 100)
        are the previously hard-coded limits.
    slope : float, optional
        Power-law index of dN/dm. Defaults to the module-level ``alpha``,
        looked up at call time.
    """
    a = alpha if slope is None else slope
    p = a + 2
    if p == 0:
        # m * m**-2 = 1/m integrates to a logarithm, not a power law.
        return m**a / math.log(mmax / mmin)
    return m**a * -p / (mmin**p - mmax**p)

def mej(ti, at, aSFR):
    """Return the rate at which dying stars return gas to the ISM at time ti.

    Computes the integral over initial mass m, from the turn-off mass at ti up
    to the most massive star in the lifetime table, of
        (m - mrem(m)) * SFR(ti - tau(m)) * IMF(m) dm,
    where tau(m) is the lifetime of a star of mass m.

    Parameters
    ----------
    ti : float
        Current time (same units as the lifetimes and ``at``).
    at, aSFR : sequence of float
        Past times and the star formation rate at those times.

    Returns
    -------
    Ejection rate, in the same units as aSFR. It is 0 if no star has left the
    main sequence yet or fewer than two SFR samples exist.

    Stellar lifetimes and masses are read from the file t_m.txt
    (column 0: log10 lifetime, column 1: initial mass).
    """
    logt, m = np.loadtxt('t_m.txt', usecols=(0, 1), unpack=True)
    t = 10**logt
    if ti < min(t) or len(at) < 2:
        # No stars have evolved off the MS yet, or too little SFR history
        # to interpolate.
        return 0.

    fm = ip.interp1d(t, m)                           # lifetime -> mass
    ft = ip.interp1d(np.flipud(m), np.flipud(t))     # mass -> lifetime
    fSFR = ip.interp1d(at, aSFR)
    tlast = max(at)

    # Lower limit of the integral: the turn-off mass. Beyond the longest
    # tabulated lifetime every tabulated star has evolved, so start at the
    # lightest one instead of letting interp1d raise an out-of-range error.
    mm = fm(ti) if ti <= max(t) else min(m)
    mU = max(m)                                      # upper limit
    ee = 0.
    while mm < mU:
        # Logarithmic step; the last bin is truncated at the upper limit so
        # the sum does not integrate past mU.
        dm = min(mm / 25., mU - mm)
        # SFR when a star of mass mm that dies now was born
        tSFR = min(max(ti - ft(mm), 0), tlast)
        ee = ee + (mm - mrem(mm)) * fSFR(tSFR) * IMF(mm) * dm
        mm = mm + dm

    return ee


def mejZ_SNIa(ti, at, aSFR, mZ, mSNIamax=8.0, fSNIa=0.02, delay=1e8):
    """Return the rate at which Type Ia SNe return an element at time ti.

    A fraction fSNIa of stars with initial mass below mSNIamax explode as
    Type Ia SNe, ``delay`` after the endpoint of their normal stellar
    evolution. Each explosion ejects mass mZ of the element. Computes the
    integral, from the (delay-shifted) turn-off mass at ti up to mSNIamax, of
        fSNIa * mZ * SFR(ti - tau(m) - delay) * IMF(m) dm.

    Parameters
    ----------
    ti : float
        Current time (same units as the lifetimes and ``at``).
    at, aSFR : sequence of float
        Past times and the star formation rate at those times.
    mZ : float
        Mass of the element produced by one SN explosion.
    mSNIamax : float
        Maximum mass of stars that become Type Ia SNe (default 8.0).
    fSNIa : float
        Fraction of those stars that actually become Type Ia SNe (default 0.02).
    delay : float
        Time delay after the endpoint of normal stellar evolution (default 1e8).

    Returns
    -------
    Element return rate, or 0 before the first SN Ia can occur or when fewer
    than two SFR samples exist.

    Stellar lifetimes and masses are read from the file t_m.txt.
    """
    logt, m = np.loadtxt('t_m.txt', usecols=(0, 1), unpack=True)
    t = 10**logt + delay     # Remember to add the delay
    w = np.where(m < mSNIamax)

    if ti < min(t[w]) or len(at) < 2:
        return 0.

    fm = ip.interp1d(t, m)                           # delayed lifetime -> mass
    ft = ip.interp1d(np.flipud(m), np.flipud(t))     # mass -> delayed lifetime
    fSFR = ip.interp1d(at, aSFR)
    tlast = max(at)

    # Beyond the longest tabulated (delayed) lifetime, start at the lightest
    # tabulated star instead of letting interp1d raise.
    mm = fm(ti) if ti <= max(t) else min(m)
    ee = 0.
    while mm < mSNIamax:
        # Logarithmic step; the last bin is truncated at mSNIamax.
        dm = min(mm / 25., mSNIamax - mm)
        tSFR = min(max(ti - ft(mm), 0), tlast)
        ee = ee + fSNIa * mZ * fSFR(tSFR) * IMF(mm) * dm
        mm = mm + dm

    return ee


def mejZ(ti, at, aSFR, aZ, fqZ):
    """Return the rate at which dying stars return element Z at time ti.

    Computes the integral over initial mass m, from the turn-off mass at ti up
    to the most massive star in the lifetime table, of
        [(m - mrem(m)) * Z(tb) + m * fqZ(m)] * SFR(tb) * IMF(m) dm,
    where tb = ti - tau(m) is the birth time of a star of mass m dying at ti.
    The first term returns the element the star was born with; every dying
    star contributes to it. The second term is the newly produced yield,
    which is non-zero only where fqZ is (e.g. the Type II SN mass range).

    Parameters
    ----------
    ti : float
        Current time (same units as the lifetimes and ``at``).
    at, aSFR : sequence of float
        Past times and the star formation rate at those times.
    aZ : sequence of float
        Mass fraction of element Z in the gas at times ``at``.
    fqZ : callable
        fqZ(m) returns the mass fraction of element Z returned by a star with
        initial mass m.

    Returns
    -------
    Element return rate, or 0 if no star has left the main sequence yet or
    fewer than two history samples exist.

    Stellar lifetimes and masses are read from the file t_m.txt.
    """
    logt, m = np.loadtxt('t_m.txt', usecols=(0, 1), unpack=True)
    t = 10**logt
    if ti < min(t) or len(at) < 2:
        return 0.

    fm = ip.interp1d(t, m)                           # lifetime -> mass
    ft = ip.interp1d(np.flipud(m), np.flipud(t))     # mass -> lifetime
    fSFR = ip.interp1d(at, aSFR)
    fZ = ip.interp1d(at, aZ)
    tlast = max(at)

    # Beyond the longest tabulated lifetime, start at the lightest tabulated
    # star instead of letting interp1d raise.
    mm = fm(ti) if ti <= max(t) else min(m)
    mU = max(m)
    ee = 0.
    while mm < mU:
        # Logarithmic step; the last bin is truncated at the upper limit.
        dm = min(mm / 50., mU - mm)
        tSFR = min(max(ti - ft(mm), 0), tlast)
        ee = ee + ((mm - mrem(mm))*fZ(tSFR) + mm*fqZ(mm)) * fSFR(tSFR) * IMF(mm) * dm
        mm = mm + dm

    return ee


def gce_model(tmax=1.0e10, Mg0=10., Ms0=0., verbose=True):
    """Run the chemical evolution model and return its history.

    This largely follows "Nucleosynthesis and Chemical Evolution of Galaxies"
    by Pagel (2007), Section 7.4 (pp 243-244), also MBW Section 10.4.
    The gas mass, the stellar mass, and the gas-phase 16O and Fe masses are
    advanced with explicit Euler steps of dt = 1e6 yr + 1% of the current age
    until t >= tmax. Gas accretion and winds appear in the equations, but both
    rates are currently 0.

    Parameters
    ----------
    tmax : float
        End time of the integration (yr).
    Mg0 : float
        Initial gas mass (surface density; Msun pc^-2).
    Ms0 : float
        Initial stellar mass (surface density).
    verbose : bool
        If True (default), print one line per time step.

    Returns
    -------
    dict of numpy.ndarray
        Keys 't', 'Mg', 'Ms', 'SFR', 'Z16O', 'ZFe': time, gas mass, stellar
        mass, star formation rate, and gas-phase 16O and Fe mass fractions,
        all evaluated at the start of each step.
    """
    t = 0.
    Mg = Mg0       # Gas mass
    Ms = Ms0       # Stellar mass
    M16O = 0.      # Mass in 16O in gas phase
    MFe  = 0.      # Mass in Fe in gas phase
    Z16Oacc = 0.   # Composition of accreted gas
    ZFeacc = 0.

    # Arrays to store the enrichment history and related variables
    at, aMg, aMs, aSFR, aZ16O, aZFe = [], [], [], [], [], []

    while t < tmax:

        dt = 1e6 + t/1e2    # Time steps: minimum 1e6 yr, then 1% of age

        SFR = calcsfr(Mg)
        et  = mej(t, at, aSFR)

        # Current gas-phase mass fractions (0 if the gas is exhausted, to
        # avoid dividing by zero).
        Z16O = M16O/Mg if Mg > 0 else 0.
        ZFe = MFe/Mg if Mg > 0 else 0.

        e16O = mejZ(t, at, aSFR, aZ16O, fq16O)  # O from Type II SNe
        eFeI = mejZ_SNIa(t, at, aSFR, 0.61)     # Fe from Type Ia SNe (0.61 = Fe mass per SN)
        eFeII = mejZ(t, at, aSFR, aZFe, fqFe)   # Fe from Type II SNe
        eFe = eFeI + eFeII

        # Gas accretion
        racc = 0.        # No accretion

        # Gas outflow via winds
        Z16Oej = Z16O     # Assume that ejected gas has same composition
                          # as current gas composition
        ZFeej = ZFe
        rwind = 0.        # No Winds

        # Record the state at time t before advancing it, so every history
        # array lines up with `at`. mejZ() interpolates aZ16O/aZFe as the gas
        # metallicity at each star's birth time.
        at.append(t)
        aSFR.append(SFR)
        aMg.append(Mg)
        aMs.append(Ms)
        aZ16O.append(Z16O)
        aZFe.append(ZFe)

        # Update variables
        dMg = (racc - rwind + et - SFR) * dt
        dMs = (SFR - et) * dt
        dM16O = (e16O - Z16O*SFR + Z16Oacc*racc - Z16Oej*rwind) * dt
        dMFe  = (eFe  -  ZFe*SFR +  ZFeacc*racc -  ZFeej*rwind) * dt
        Mg = Mg + dMg
        Ms = Ms + dMs
        M16O = M16O + dM16O
        MFe  = MFe  + dMFe

        if verbose:
            # Log line: Mg and Ms here are the end-of-step values.
            print("%6.3e %11.5e %11.5e %6.3e %6.3e %6.3e %9.6e %6.3e %6.3e %9.6e" % (t, Mg, Ms, SFR, et, e16O, Z16O, eFeI, eFeII, ZFe))

        t = t + dt

    return {
        't': np.array(at),
        'Mg': np.array(aMg),
        'Ms': np.array(aMs),
        'SFR': np.array(aSFR),
        'Z16O': np.array(aZ16O),
        'ZFe': np.array(aZFe),
    }

# Run the model when this file is executed as a script; importing the module
# (e.g. to reuse mej/mejZ/IMF) no longer triggers the full integration.

if __name__ == "__main__":
    gce_model()
