#!/usr/bin/env python

#Copyright 2006-2008 John C. Vernaleo
#Unfortunately, I am not comfortable putting unobscured email addresses
#on the web, but these shouldn't be too hard to figure out.
#
#		(my_first_name)@netpurgatory.com
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program; if not, write to the Free Software
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

#This is just a bunch of plotting and other generic functions
#that I need in my scripts, but are too small to be in a file
#alone
#jcv
#05/04/2006
#08/15/2006 new array functions

import re
import os
import sys
import socket

try:
    from Numeric import *
except ImportError:
    try:
        from numarray import *
    except ImportError:
        print "No array package present"
        sys.exit()

def set_paths(host=None):
    """Return the base directory for research files on the current machine.

    The base path is chosen by looking for known substrings in the host
    name.  Patterns are checked in the order astro.umd.edu, ganon, oldg4,
    and a later match overrides an earlier one.  Unknown hosts get "./".

    host -- host name to inspect; defaults to socket.gethostname().
            Passing it explicitly is mainly useful for testing.
    """
    if host is None:
        host = socket.gethostname()
    base = "./"
    # Plain substring tests: the dots in "astro.umd.edu" are meant
    # literally (as a regex they matched any character).
    if "astro.umd.edu" in host:
        base = "/home/vernaleo/research/"
    if "ganon" in host:
        base = "/home2/john/"
    if "oldg4" in host:
        base = "/Users/vernaleo/Desktop/research/"
    return base
    

def qplot(vector,title="",printing=0,path="./",eps=0,color=1):
    """Quick-look line plot of a single vector against its index.

    The values are streamed to a persistent gnuplot window.  When
    *printing* is true, hardcopy() is called first so that gnuplot writes
    the plot to a file in *path* named after *title* (EPS when *eps* is
    true, otherwise PNG; *color* selects color EPS).
    """
    gp = os.popen('gnuplot -persist', 'w')
    if printing:
        hardcopy(title, path, gp, eps, color)
    gp.write('set title "' + title + '"\n')
    gp.write("plot '-' with lines\n")
    # One value per line, then the 'e' that terminates inline '-' data.
    for value in vector:
        gp.write("%f\n" % (value,))
    gp.write("e\n")
    gp.close()

def qplot2(vector2,vector1,title="",printing=0,eps=0,path="./",color=1):
    """Quick-look x-y line plot: *vector1* on the x axis, *vector2* on y.

    Note the argument order: the y values come first.  Points are paired
    by index over the length of *vector1*, so *vector2* must be at least
    that long.  *printing*, *eps*, *path* and *color* are forwarded to
    hardcopy() (beware: here *eps* comes before *path*, unlike qplot()).
    """
    gp = os.popen('gnuplot -persist', 'w')
    if printing:
        hardcopy(title, path, gp, eps, color)
    gp.write('set title "' + title + '"\n')
    gp.write("plot '-' using 1:2 with lines\n")
    for idx, xval in enumerate(vector1):
        gp.write("%f %f\n" % (xval, vector2[idx]))
    gp.write("e\n")
    gp.close()

def overplot(vector1,vector2,*vector3):
    """Overplot several curves that share one set of x values.

    vector1  -- x values; also fixes how many points are sent per curve
    vector2  -- y values of the first curve
    *vector3 -- y values of any further curves

    All curves go into a single gnuplot 'plot' command, one inline '-'
    data block each, drawn with lines in a persistent window.
    """
    gp = os.popen('gnuplot -persist', 'w')
    extra = len(vector3)
    gp.write("plot '-' using 1:2 with lines")
    gp.write(",'-' using 1:2 with lines" * extra)
    gp.write("\n")
    for ycurve in (vector2,) + vector3:
        for idx, xval in enumerate(vector1):
            gp.write("%f %f\n" % (xval, ycurve[idx]))
        gp.write("e\n")
    gp.close()

def plot(name,path,xname,yname,legend,printing,x,*y):
    """Labelled line-and-point plot of any number of curves.

    name     -- file stem used by hardcopy() when printing
    path     -- directory prefix for the hardcopy file
    xname    -- x-axis label
    yname    -- y-axis label
    legend   -- sequence of curve titles (one per curve), or a false value
                to suppress the key
    printing -- if true, hardcopy() sends the output to a color EPS file
    x        -- x values shared by every curve
    *y       -- one sequence of y values per curve

    With a legend, the third curve is drawn with point type 6.
    """
    gp = os.popen('gnuplot -persist', 'w')
    if printing:
        hardcopy(name, path, gp, 1, 1)  # always EPS, in color
    gp.write("set style data linespoints\n")
    gp.write("set pointsize 2\n")
    if legend:
        key_cmd = "set key left box\n"
    else:
        key_cmd = "unset key\n"
    gp.write(key_cmd)
    gp.write("set xlabel '" + xname + "'\n")
    gp.write("set ylabel '" + yname + "'\n")
    if legend:
        gp.write("plot '-' using 1:2 title '" + legend[0] + "'")
    else:
        gp.write("plot '-' using 1:2 ")
    for curve in range(1, len(y)):
        if not legend:
            gp.write(",'-' using 1:2")
            continue
        gp.write(",'-' using 1:2 title '" + legend[curve] + "'")
        if curve == 2:
            gp.write(" with linespoints pointtype 6")
    gp.write("\n")
    for ycurve in y:
        pairloop(x, ycurve, gp)
    gp.close()

def plotline(name,path,xname,yname,legend,printing,x,*y):
    """Labelled plot of any number of curves drawn with plain lines.

    name     -- file stem used by hardcopy() when printing
    path     -- directory prefix for the hardcopy file
    xname    -- x-axis label
    yname    -- y-axis label
    legend   -- sequence of curve titles (one per curve), or a false value
                to suppress the key
    printing -- if true, hardcopy() sends the output to a color EPS file
    x        -- x values shared by every curve
    *y       -- one sequence of y values per curve

    With a legend, the third curve is drawn with line type 6.
    """
    eps=1
    color=1
    pp=os.popen('gnuplot -persist','w')
    if printing:
        hardcopy(name,path,pp,eps,color)
    pp.write("set style data lines\n")
    if legend:
        pp.write("set key left box\n")
    else:
        pp.write("unset key\n")
    pp.write("set xlabel '"+xname+"'\n")
    pp.write("set ylabel '"+yname+"'\n")
    if legend:
        pp.write("plot '-' using 1:2 title '"+legend[0]+"'")
    else:
        pp.write("plot '-' using 1:2 ")
    for i in range(1,len(y)):
        if legend:
            pp.write(",'-' using 1:2 title '"+legend[i]+"'")
            if i == 2:
                # 'type' is not a gnuplot plot-style keyword; the option is
                # 'linetype' (abbreviated 'lt'), the line counterpart of the
                # 'pointtype 6' used by plot().
                pp.write(" with lines linetype 6")
        else:
            pp.write(",'-' using 1:2")
    pp.write("\n")
    for i in range(0,len(y)):
        pairloop(x,y[i],pp)
    pp.close()

def image(value,ax1,ax2,title="",printing=0,path="./",eps=0,color=1):
    """Color map (gnuplot pm3d) of a 2-D array on explicit axes.

    value -- 2-D array indexed value[k, m], where k runs along ax2 and
             m runs along ax1
    ax1   -- x coordinates
    ax2   -- y coordinates
    printing, path, eps, color -- forwarded to hardcopy() when printing

    Example usage:
    image(data.d[0,:,:],data.x1,data.x2,title)
    """
    gp = os.popen('gnuplot -persist','w')
    if printing:
        hardcopy(title,path,gp,eps,color)
    gp.write('set pm3d map\n')
    gp.write('set size square\n')
    gp.write('set palette defined ( 0 "black", 0.1 "green", 1 "blue", 2 "red", 3 "orange" )\n')
    # Alternative palette used for red/blue images:
    #   'set palette rgbformulae 33,13,10\n'
    x_lo, y_lo, x_hi, y_hi = amin(ax1), amin(ax2), amax(ax1), amax(ax2)
    gp.write('set xrange[%s:%s]\n' % (x_lo, x_hi))
    gp.write('set yrange[%s:%s]\n' % (y_lo, y_hi))
    gp.write('set title "' + title + '"\n')
    gp.write("splot '-' title''\n")
    # One scan per x value; the blank line separates scans for pm3d.
    for col, xval in enumerate(ax1):
        for row, yval in enumerate(ax2):
            gp.write("%f %f %f\n" % (xval, yval, value[row, col]))
        gp.write("\n")
    gp.write("e\n")
    gp.close()

def qimage(value,title="",printing=0,path="./",eps=0,color=1):
    """Produce an image of a 2-D array whose axis values are unknown.

    Both axes are taken to be evenly spaced integer indices (arange):
    x counts the columns of *value* and y counts its rows.
    Example usage:
    qimage(data.d[0,:,:],title)
    """
    ncols = len(value[0,:])
    nrows = len(value[:,0])
    xs = arange(ncols)
    ys = arange(nrows)
    gp = os.popen('gnuplot -persist','w')
    if printing:
        hardcopy(title,path,gp,eps,color)
    gp.write('set pm3d map\n')
    gp.write('set size square\n')
    gp.write('set palette rgbformulae 33,13,10\n')
    x_lo, y_lo, x_hi, y_hi = amin(xs), amin(ys), amax(xs), amax(ys)
    gp.write('set xrange[%s:%s]\n' % (x_lo, x_hi))
    gp.write('set yrange[%s:%s]\n' % (y_lo, y_hi))
    gp.write('set title "' + title + '"\n')
    gp.write("splot '-'\n")
    # One scan per column; the blank line separates scans for pm3d.
    for col, xval in enumerate(xs):
        for row, yval in enumerate(ys):
            gp.write("%f %f %f\n" % (xval, yval, value[row, col]))
        gp.write("\n")
    gp.write("e\n")
    gp.close()

def limage(value,ax1,ax2,title="",printing=0,path="./",eps=0,color=1,shrink=2):
    """Fast, lower-quality color map that plots only every *shrink*-th point.

    value  -- 2-D array indexed value[k, m], where k runs along ax2 and
              m runs along ax1
    ax1    -- x coordinates
    ax2    -- y coordinates
    shrink -- sampling stride along both axes; the default 2 keeps every
              other point.  Must be a positive integer.
    printing, path, eps, color -- forwarded to hardcopy() when printing

    The plot ranges still span the full (unsampled) axes.

    Example usage:
    limage(data.d[0,:,:],data.x1,data.x2,title)
    """
    if shrink < 1:
        # Checked before gnuplot is started so no stray window is left.
        raise ValueError("shrink must be a positive integer")
    pp=os.popen('gnuplot -persist','w')
    if printing:
        hardcopy(title,path,pp,eps,color)
    pp.write('set pm3d map\n')
    pp.write('set size square\n')
    pp.write('set palette rgbformulae 33,13,10\n')
    xmin=amin(ax1)
    ymin=amin(ax2)
    xmax=amax(ax1)
    ymax=amax(ax2)
    pp.write('set xrange['+str(xmin)+':'+str(xmax)+']\n')
    pp.write('set yrange['+str(ymin)+':'+str(ymax)+']\n')
    pp.write('set title "'+title+'"\n')
    pp.write("splot '-' title ''\n")

    # ii indexes ax1 (second axis of value), jj indexes ax2 (first axis).
    for ii in range(0,len(ax1),shrink):
        for jj in range(0,len(ax2),shrink):
            pp.write("%f %f %f\n" % (ax1[ii], ax2[jj], value[jj, ii]))
        pp.write("\n")

    pp.write("e\n")
    pp.close()


def contours(value,ax1,ax2,title="",printing=0,path="./",eps=0,color=1):
    """Contour plot (20 levels, viewed from directly above) of a 2-D array.

    value -- 2-D array indexed value[k, m], where k runs along ax2 and
             m runs along ax1
    ax1   -- x coordinates
    ax2   -- y coordinates
    printing, path, eps, color -- forwarded to hardcopy() when printing
    """
    pp=os.popen('gnuplot -persist','w')
    if printing:
        hardcopy(title,path,pp,eps,color)
    pp.write('set title "'+title+'"\n')
    # Current gnuplot syntax, as used by plot()/plotline(): 'set style data'
    # and 'unset surface' replace the deprecated 'set data style' and
    # 'set nosurface'.
    pp.write('set style data lines\n')
    pp.write('set contour\n')
    pp.write('set cntrparam levels 20\n')
    pp.write('unset surface\n')
    pp.write('set view 0,0\n')
    pp.write("splot '-'\n")

    # One scan per x value; the blank line separates scans.
    for m, xval in enumerate(ax1):
        for k, yval in enumerate(ax2):
            pp.write("%f %f %f\n" % (xval, yval, value[k, m]))
        pp.write("\n")

    pp.write("e\n")
    pp.close()

def list_max(list):
    """Return the largest item of a sequence without modifying it.

    Runs in O(n) and works on any sized sequence (list, tuple, ...).
    Raises IndexError if the sequence is empty.
    """
    # The parameter name shadows the builtin ``list``; it is kept for API
    # compatibility.  The builtin ``max`` is not affected.
    if len(list) == 0:
        raise IndexError("list_max() arg is an empty sequence")
    return max(list)

def list_min(list):
    """Return the smallest item of a sequence without modifying it.

    Runs in O(n) and works on any sized sequence (list, tuple, ...).
    Raises IndexError if the sequence is empty.
    """
    # The parameter name shadows the builtin ``list``; it is kept for API
    # compatibility.  The builtin ``min`` is not affected.
    if len(list) == 0:
        raise IndexError("list_min() arg is an empty sequence")
    return min(list)

def amax(array):
    """Return the largest element of an N-dimensional array.

    A 1-D array is passed straight to max().  For higher ranks the maximum
    of each sub-array along the leading axis is found recursively, and the
    largest of those is returned.  (Calling max() directly on a
    multi-dimensional array compares whole rows with '>', which does not
    give the global maximum; numpy raises ValueError.)
    """
    if len(array.shape) <= 1:
        return max(array)
    return max([amax(sub) for sub in array])

def amin(array):
    """Return the smallest element of an N-dimensional array.

    A 1-D array is passed straight to min().  For higher ranks the minimum
    of each sub-array along the leading axis is found recursively, and the
    smallest of those is returned.  (Calling min() directly on a
    multi-dimensional array compares whole rows with '<', which does not
    give the global minimum; numpy raises ValueError.)
    """
    if len(array.shape) <= 1:
        return min(array)
    return min([amin(sub) for sub in array])

def aminmax(array):
    """Return (smallest, largest) element of an N-dimensional array.

    A 1-D array is passed straight to min() and max().  For higher ranks
    each sub-array along the leading axis is reduced recursively, and the
    extremes of those partial results are returned.  (min()/max() applied
    directly to a multi-dimensional array compare whole rows, which does
    not give the global extremes.)
    """
    if len(array.shape) <= 1:
        return min(array), max(array)
    pairs = [aminmax(sub) for sub in array]
    return min([p[0] for p in pairs]), max([p[1] for p in pairs])

def asum(array):
    """Return the sum of every element of an N-dimensional array.

    sum() is applied once per dimension of *array*; each pass adds up
    the entries along the leading axis until a single total remains.
    """
    total = sum(array)
    passes_left = len(array.shape) - 1
    while passes_left > 0:
        total = sum(total)
        passes_left -= 1
    return total

def fix_num(current):
    """Return *current* as a string left-padded with zeros to three digits.

    Values <= 9 get two leading zeros, values above 9 and below 100 get
    one, and values >= 100 are returned unpadded:
    5 -> '005', 42 -> '042', 123 -> '123'.
    Intended for non-negative integer frame numbers.
    """
    # Decide the padding from the number before converting it to a string.
    # Comparing a str with an int only "works" through Python 2's
    # mixed-type ordering and raises TypeError in Python 3.
    if current <= 9:
        pad = '00'
    elif current < 100:
        pad = '0'
    else:
        pad = ''
    return pad + str(current)

def fix_names(current,file,type,rbase,wbase):
    """Build the input data-file name and output PNG name for one frame.

    current -- frame number, zero-padded by fix_num()
    file    -- data-file stem, appended to *rbase*
    type    -- picture stem, appended to *wbase*
    rbase   -- prefix (typically a directory) for input data files
    wbase   -- prefix (typically a directory) for output pictures

    Prints the data-file name and returns (data_file, picture_file).
    """
    frame = fix_num(current)
    data_name = rbase + file + frame
    pic_name = wbase + type + frame + '.png'
    print(data_name)
    return data_name, pic_name

def chkprint(argv):
    """Parse and strip the printing flags from an argument list.

    Recognized flags (every occurrence is removed from *argv*):
      -p    enable printing (hardcopy)
      -pb   enable printing in black and white (color=0)
      -png  enable printing as PNG instead of EPS (eps=0)

    Flags are only looked for when *argv* has at least two entries.
    *argv* is modified in place, so sys.argv can be passed directly and
    later length checks need not worry about these flags.  It is also
    returned.

    Returns (printing, color, eps, argv); the defaults are 0, 1, 1.
    """
    printing = 0
    color = 1
    eps = 1
    if len(argv) >= 2:
        remaining = []
        for item in argv:
            if item == "-p":
                printing = 1
            elif item == "-pb":
                printing = 1
                color = 0
            elif item == "-png":
                printing = 1
                eps = 0
            else:
                remaining.append(item)
        # Rebuild the list in place rather than popping while iterating,
        # which skipped the element following every removed flag.
        argv[:] = remaining
    return (printing, color, eps, argv)

def hardcopy(title,path="./",handle="pp",eps=0,color=1,lines="solid",font="26"):
    """Send gnuplot commands on *handle* that redirect output to a file.

    title  -- plot title; '/' and whitespace are replaced by '-' to form
              the file name
    path   -- prefix (typically a directory) for the output file
    handle -- writable gnuplot pipe (anything with .write()).  The string
              default "pp" is only a placeholder kept for compatibility;
              callers must pass a real handle.
    eps    -- true: enhanced PostScript EPS written to <title>.eps;
              false: cropped PNG written to <title>.png
    color  -- EPS only: true for color, false for monochrome
    lines  -- EPS only: line-style keyword for the terminal
    font   -- EPS only: font size for the terminal
    """
    # Raw strings: '\s' and '\/' are invalid escape sequences in ordinary
    # string literals (DeprecationWarning / SyntaxWarning in Python 3).
    noslash=re.compile(r'/')
    nospace=re.compile(r'\s')
    title=nospace.sub('-',noslash.sub('-',title))
    if eps:
        handle.write("set output '"+path+title+".eps'\n")
        if color:
            ifc="color"
        else:
            ifc=""
        handle.write("set terminal postscript eps enhanced "+ifc+" "+lines+" enhanced lw 2 "+str(font)+"\n")
    else:
        handle.write("set output '"+path+title+".png'\n")
        handle.write("set terminal png crop\n")

def binsearch(value,array):
    """Return the index of the element of *array* closest to *value*.

    *array* must be sorted in ascending order.  Values below the first
    element map to index 0 and values above the last to len(array)-1.
    When *value* is exactly halfway between two neighbours, the lower
    index is returned (the same choice linsearch makes).
    Runs in O(log n).  Raises IndexError for an empty array.
    """
    import bisect
    n = len(array)
    if n == 0:
        raise IndexError("binsearch() arg is an empty sequence")
    # First position whose element is >= value.
    pos = bisect.bisect_left(array, value)
    if pos == 0:
        return 0
    if pos == n:
        return n - 1
    # value lies between array[pos-1] and array[pos]; pick the nearer one.
    if value - array[pos - 1] <= array[pos] - value:
        return pos - 1
    return pos

def linsearch(value,array):
    """Return the index of the element closest to *value*, scanning linearly.

    Walks forward while the distance to *value* keeps strictly shrinking
    and stops at the first element whose successor is no closer.  For a
    monotonic array this is the closest element (ties go to the lower
    index).  Slow, but needs no assumptions beyond that.
    """
    n = len(array)
    prev_gap = abs(array[0] - value)
    idx = 1
    while idx < n:
        gap = abs(array[idx] - value)
        if gap >= prev_gap:
            break
        prev_gap = gap
        idx += 1
    return idx - 1

def mirrorx(a1):
    """Return a Float copy of *a1* with the order along the first axis reversed.

    The input is not modified; the result always has Float type.
    """
    a2 = zeros(a1.shape, Float)
    # 'n' instead of 'max' so the builtin is not shadowed; consistent
    # space indentation (the old tab-indented line is a TabError in
    # Python 3).
    n = a1.shape[0]
    for i in range(n):
        a2[n - 1 - i] = a1[i]
    return a2

def mirrory(a1):
    """Return a Float copy of *a1* with the order along the second axis reversed.

    Unlike mirrorx, this indexes a1[:, i], so *a1* must have at least two
    dimensions (it was written for 2-D arrays).  The input is not modified.
    """
    a2 = zeros(a1.shape, Float)
    # 'ncols' instead of 'max' so the builtin is not shadowed; consistent
    # space indentation (the old tab-indented line is a TabError in
    # Python 3).
    ncols = a1.shape[1]
    for i in range(ncols):
        a2[:, ncols - 1 - i] = a1[:, i]
    return a2

#MLab.rot90 is better
#def rot90(a1):
#    """rotates an array by 90 degrees counter clockwise
#    """
#    a2=zeros((a1.shape[1],a1.shape[0]),Float)
#    max=a1.shape[1]
#    for i in range(0,max):
#	a2[:,i]=a1[i]
#    return(a2)

def _ascale_clip(a, lo, hi):
    """Clip *a* in place to [lo, hi], recursing over the leading axis."""
    if len(a.shape) > 1:
        for i in range(a.shape[0]):
            # Basic integer indexing of a Numeric/numpy-style array gives a
            # view, so clipping the sub-array updates *a* itself.
            _ascale_clip(a[i], lo, hi)
        return
    for i in range(a.shape[0]):
        # Same order as before: clamp to hi first, then to lo.
        if a[i] >= hi:
            a[i] = hi
        if a[i] <= lo:
            a[i] = lo

def ascale(a,min,max):
    """
    Clip the values of an array to [min, max] and divide by max.

    Every element >= max is set to max, then every element <= min is set
    to min.  This clipping modifies *a* IN PLACE.  The function then
    returns a new array a/max, so when max > 0 and min <= max the result
    lies in [min/max, 1].  Works on arrays with any number of dimensions.
    """
    # The parameter names shadow the builtins min/max; kept for API
    # compatibility.
    _ascale_clip(a, min, max)
    return a / max

def pairloop(data1,data2,handle,jump=1):
    """Write (data1[i], data2[i]) pairs to *handle* as gnuplot inline data.

    data1  -- x values; their length sets how many rows can be written
    data2  -- y values, indexed in step with data1
    handle -- writable object, e.g. a gnuplot pipe
    jump   -- stride: write every jump-th pair (default 1 writes all).
              Previously this argument was silently overwritten with 1.

    Each row is "x y" formatted with %f; the block ends with the "e" line
    that terminates gnuplot's '-' data.
    """
    for i in range(0, len(data1), jump):
        handle.write("%f %f\n" % (data1[i], data2[i]))
    handle.write("e\n")
