import numpy as np
import pylab
import matplotlib.pyplot as pl
import matplotlib.animation as animation
import sys
import math

#This code solves the restricted 3-body problem as in Toomre & Toomre (e.g. 1972, ApJ, 178, 623).
#It considers gravitational encounters between two galaxies. The bulk (halo) of each galaxy is represented by a central point mass.

#Each of the two point masses arrives at the scene of the encounter surrounded by its own flat annular disc of 120 non-interacting test particles, arranged in concentric rings. For details see section II of Toomre & Toomre 1972.

# Point-particle container shared by halos and stars.
class Particle(object):
    """A single point particle (a galaxy halo or a disc test star).

    Attributes:
        mass: particle mass, stored as a float (set_ICs passes kilograms).
        pos:  position vector, a new float ndarray copied from the input.
        vel:  velocity vector, a new float ndarray copied from the input.
        spin: scalar stored as a float; set_ICs scales a halo's disc
              circular velocity by its halo's spin.
    """

    def __init__(self, mass, pos, vel, spin):
        self.mass = float(mass)
        # np.array copies, so later in-place updates never touch the caller's data.
        self.pos, self.vel = (np.array(vec, dtype=float) for vec in (pos, vel))
        self.spin = float(spin)


#Define constants and units (use SI units for calculations)
G = 6.67e-11         # gravitational constant [N m^2/kg^2]
kpc_to_m = 3.086e19  # conversion factor between kpc and m
yr_to_s = 3.15569e7  # conversion factor between yr and s
Msun = 1.989e30      # conversion factor between Msun and kg


#Define main simulation parameters
nb_stars_per_ring = [12, 18, 24, 30, 36] # number of stars in each ring, innermost first
nb_rings = len(nb_stars_per_ring)        # number of rings in the stellar disc (derived so it cannot disagree with the list above)
ring_delta_r = 10 * kpc_to_m             # radial spacing of rings [m]; ring k (0-based) starts at radius (k+1)*ring_delta_r
delta_t = 4e6 * yr_to_s                  # size of time steps [s]
nb_steps = 400                           # total number of time steps of the simulation
soft_length = 5e-4 * kpc_to_m            # gravitational softening length [m]


      
# Softened gravitational force exerted ON particle1 BY particle2.
def grav_force(particle1, particle2):
    """Return the softened gravitational force acting on ``particle1`` due to ``particle2``.

    The vector points from particle1 towards particle2 (attraction), so it is
    equal and opposite to ``grav_force(particle2, particle1)``. The Plummer
    softening term ``soft_length**2`` keeps the force finite when the two
    particles nearly coincide.

    Args:
        particle1: particle the force acts on (needs ``pos`` and ``mass``).
        particle2: particle exerting the force (needs ``pos`` and ``mass``).

    Returns:
        Force vector in SI units (newtons), same shape as ``pos``.
    """
    separation = particle2.pos - particle1.pos
    # Only the squared distance is needed; no sqrt followed by squaring.
    dist_sq = np.sum(separation**2)
    return (G * particle1.mass * particle2.mass * separation
            / (dist_sq + soft_length**2)**1.5)


def circ_vel(r, Mass):
    """Return the circular orbital speed at radius ``r`` around a point mass ``Mass``.

    Keplerian relation v = sqrt(G * M / r); with SI inputs (m, kg) the
    result is in m/s. NumPy arrays are handled element-wise.
    """
    gm = G * Mass
    return np.sqrt(gm / r)


def rotate(xyzp, euler):
    """Rotate a vector from a galaxy's own frame into the simulation frame.

    ``euler`` holds three angles in degrees: (theta, psi, phi) =
    (nutation, precession, rotation). The matrix applied is the z-x-z
    Euler rotation R = Rz(psi) . Rx(theta) . Rz(phi).

    Args:
        xyzp: components (x', y', z') in the rotated frame; each component
              may itself be an array, in which case the result stacks them.
        euler: sequence of the three angles in degrees.

    Returns:
        np.array([x, y, z]) in the unrotated frame.
    """
    theta, psi, phi = [euler[k] * math.pi / 180. for k in range(3)]
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    cos_p, sin_p = math.cos(psi), math.sin(psi)
    cos_f, sin_f = math.cos(phi), math.sin(phi)

    # Rows of the rotation matrix (Bronshtein et al., Section 3.5.3.2).
    matrix_rows = (
        (cos_p*cos_f - cos_t*sin_p*sin_f, -cos_p*sin_f - cos_t*sin_p*cos_f, sin_t*sin_p),
        (sin_p*cos_f + cos_t*cos_p*sin_f, -sin_p*sin_f + cos_t*cos_p*cos_f, -sin_t*cos_p),
        (sin_t*sin_f, sin_t*cos_f, cos_t),
    )

    xp, yp, zp = xyzp[0], xyzp[1], xyzp[2]
    return np.array([a*xp + b*yp + c*zp for a, b, c in matrix_rows])

# Specify initial conditions for halos and stars.
# Defaults are tuples (not lists) so no mutable default object is shared
# between calls.
def set_ICs(mass1=1e11, mass2=1e11, pos1=(-50, -50, 0), pos2=(50, 50, 0),
            vel1=(0, 1e5, 0), vel2=(0, -1e5, 0), euler1=(0, 0, 0),
            euler2=(0, 0, 0), spin1=1, spin2=1):
    """Create the two halos and the ring discs of test-particle stars.

    euler1 and euler2 specify the Euler angles (degrees) for the coordinate
    systems of galaxy 1 and galaxy 2 (nutation, precession, rotation); see
    Bronshtein et al., "Handbook of Mathematics", Section 3.5.3.2.

    Args:
        mass1, mass2: halo masses in solar masses.
        pos1, pos2: halo positions in kpc.
        vel1, vel2: halo velocities in m/s (used as given, no unit conversion).
        euler1, euler2: disc orientation angles in degrees, passed to rotate().
        spin1, spin2: multiplier on each disc's circular speed; a negative
            value reverses the disc's sense of rotation.

    Returns:
        (halo_array, star_array): the list of the two halo Particles, and a
        list whose i-th entry is the list of star Particles around halo i.
    """
    halo_array = [
        Particle(mass1 * Msun, np.array(pos1) * kpc_to_m, np.array(vel1), spin1),
        Particle(mass2 * Msun, np.array(pos2) * kpc_to_m, np.array(vel2), spin2),
    ]
    star_array = [_disc_stars(halo, euler)
                  for halo, euler in zip(halo_array, (euler1, euler2))]
    return halo_array, star_array


def _disc_stars(halo, euler):
    """Return the star Particles of the ring disc centred on ``halo``.

    Ring k (0-based) has radius (k+1)*ring_delta_r and nb_stars_per_ring[k]
    evenly spaced stars on circular orbits around the halo's point mass. The
    flat disc is tilted by ``euler`` and moves with the halo.
    """
    stars = []
    for ring in range(nb_rings):
        radius_ring = ring_delta_r * (ring + 1)
        v_circ = circ_vel(radius_ring, halo.mass)
        n_stars = nb_stars_per_ring[ring]
        for star in range(n_stars):
            angle_star = 2 * np.pi * star / n_stars
            radial = np.array([np.cos(angle_star), np.sin(angle_star), 0])
            tangential = np.array([-np.sin(angle_star), np.cos(angle_star), 0])
            star_pos = rotate(radial * radius_ring, euler) + halo.pos
            star_vel = rotate(tangential * v_circ * halo.spin, euler) + halo.vel
            # The star mass does not matter for the restricted 3-body problem:
            # run_sim divides the force by it again, so it cancels.
            stars.append(Particle(1e7 * Msun, star_pos, star_vel, 1))
    return stars



def halos_xy(halo_array):
    """Return an (N, 2) array with the (x, y) position of each halo.

    The z component is dropped; units are those of ``pos`` (metres here,
    the caller divides by kpc_to_m for plotting).
    """
    pairs = [(halo.pos[0], halo.pos[1]) for halo in halo_array]
    # reshape keeps a (0, 2) shape when halo_array is empty.
    return np.array(pairs).reshape(-1, 2)


def stars_xy(star_array, halo):
    """Return an (N, 2) array of (x, y) positions of the stars of halo ``halo``.

    ``star_array[halo]`` is the list of star particles belonging to that
    halo; the z component is dropped.
    """
    members = star_array[halo]
    xy = np.empty((len(members), 2))
    for row, member in enumerate(members):
        xy[row, 0] = member.pos[0]
        xy[row, 1] = member.pos[1]
    return xy


#Integrate forward in time
def run_sim(t_step):
    """Advance the system by one time step ``delta_t`` and refresh the plot.

    Used as the FuncAnimation frame callback. ``t_step`` is the 0-based
    frame index and is only used for the time label. Reads and mutates the
    module-level ``halo_array`` and ``star_array``, and updates the artists
    ``halo_plot``, ``star_plot0``, ``star_plot1`` and ``time_count``.

    Returns:
        The updated artists (required if the animation is run with blit=True).
    """
    # Evaluate every acceleration from the positions at the START of the step
    # before moving anything. Advancing one halo and then computing the force
    # on the other from that already-moved position makes the mutual forces
    # unequal (momentum is not conserved), and the result depends on ordering.
    halo_acc = []
    for i, halo in enumerate(halo_array):
        force = np.zeros_like(halo.pos)
        for j, other in enumerate(halo_array):
            if j != i:
                force += grav_force(halo, other)
        halo_acc.append(force / halo.mass)

    # Stars are test particles: they feel every halo but exert no force.
    star_acc = [[sum(grav_force(star, halo) for halo in halo_array) / star.mass
                 for star in stars]
                for stars in star_array]

    # Semi-implicit (symplectic) Euler: kick the velocity, then drift the position.
    for halo, acc in zip(halo_array, halo_acc):
        halo.vel += acc * delta_t
        halo.pos += halo.vel * delta_t
    for stars, accs in zip(star_array, star_acc):
        for star, acc in zip(stars, accs):
            star.vel += acc * delta_t
            star.pos += star.vel * delta_t

    #Plot new time step
    halo_plot.set_offsets(halos_xy(halo_array)/kpc_to_m)
    star_plot0.set_offsets(stars_xy(star_array, 0)/kpc_to_m)
    star_plot1.set_offsets(stars_xy(star_array, 1)/kpc_to_m)
    time_count.set_text('t='+ str(float(t_step+1) * delta_t / yr_to_s/1e6)+' Myr')
    return halo_plot, star_plot0, star_plot1, time_count

        

#################################
#################################

#Set initial conditions
halo_array, star_array = set_ICs()
# print(...) with a single argument behaves the same on Python 2 and 3;
# the bare print statement is a SyntaxError on Python 3.
print('Toomre simulation started...')

########
#Initialise plot
pl.clf()
fig = pl.figure(1)
pl.axis([-170, 170, -170, 170])

#Plot halos
halo_pos = halos_xy(halo_array)
halo_plot = pl.scatter(halo_pos[:,0]/kpc_to_m, halo_pos[:,1]/kpc_to_m, marker='o', s=[300]*len(halo_array), c=['k']*len(halo_array))

#Plot stars
star_pos0 = stars_xy(star_array, 0)
star_pos1 = stars_xy(star_array, 1)
star_plot0 = pl.scatter(star_pos0[:,0]/kpc_to_m, star_pos0[:,1]/kpc_to_m, marker='*', s=[50]*len(star_pos0), c=['b' for s in star_pos0])
star_plot1 = pl.scatter(star_pos1[:,0]/kpc_to_m, star_pos1[:,1]/kpc_to_m, marker='*', s=[50]*len(star_pos1), c=['r' for s in star_pos1])

pl.xlabel('X [kpc]', fontsize=16)
pl.ylabel('Y [kpc]', fontsize=16)
time_count = pl.text(70,140,'t= 0 Myr', fontsize=16)
#########


def _init_frame():
    """Initial-frame callback for FuncAnimation.

    Without an init_func, FuncAnimation draws its initial frame by calling
    the frame function once more before the frame loop. run_sim advances the
    simulation state, so that extra call would add one hidden time step and
    make every time label lag the physics by delta_t. Returning the artists
    unchanged avoids the extra step.
    """
    return halo_plot, star_plot0, star_plot1, time_count


#Create animation and save
anim = animation.FuncAnimation(fig, run_sim, frames=nb_steps, init_func=_init_frame, blit=False, repeat=False)

anim.save('toomre.mp4', writer='ffmpeg', fps=30)
print('Done.')
