
import numpy
from numpy.linalg import eig
import math 
from dolfin import *

# Select "sfc" as the form compiler used when the forms below are assembled.
dolfin_set("form compiler", "sfc")


class Solution(Function):
    """Unit radial vector field centred at (0.5, 0.5).

    Used below to build the start vector for the vector unknown x.  The field
    (p - c) / |p - c| is undefined at the centre c itself; at that point the
    value is set to the zero vector instead of raising ZeroDivisionError
    (the script warns that an even N puts this singularity on a mesh node).
    """

    def eval(self, values, xx):
        ax = xx[0] - 0.5
        ay = xx[1] - 0.5
        r = math.sqrt(ax**2 + ay**2)

        if r == 0.0:
            # Exactly at the singular centre: no well-defined direction.
            values[0] = 0.0
            values[1] = 0.0
            return

        values[0] = ax/r
        values[1] = ay/r

class DirichletBoundary(SubDomain):
    """Sub-domain that accepts every point flagged as lying on the boundary."""

    def inside(self, x, on_boundary):
        # Normalise whatever truthy/falsy flag is passed to a plain bool.
        return True if on_boundary else False

# Geometry: an N x N unit-square mesh.  Vertices sit at multiples of 1/N, so
# for even N the centre (0.5, 0.5) -- where the Solution field is singular --
# is a mesh vertex.
N = 5
if not N % 2:
    print("The solution will have a singularity in a nodal point.")

mesh = UnitSquare(N, N)

# Function spaces: degree-1 continuous Lagrange vectors (X) and scalars (Y),
# combined into the mixed space XY.
X = VectorFunctionSpace(mesh, 'CG', 1)
Y = FunctionSpace(mesh, 'CG', 1)
XY = X + Y

# Coefficient (u), test (vv) and trial (uu) functions, all on the mixed space.
# They are created in the same order as before: Function, Test, Trial.
u, vv, uu = Function(XY), TestFunction(XY), TrialFunction(XY)

# Form coefficients.  The same functional L(x, y) is differentiated twice:
# through split(u) on the mixed space (do_split=True), and with respect to a
# tuple of separate Functions (x, y) (do_split=False).  The residual norms and
# Jacobian spectra of both variants are collected and compared at the end.
norms = []
eigs = []

# Start vector: x0 is the projection of the radial field onto X, y0 is zero.
# Neither depends on do_split, so build them once, outside the loop.
dirichlet_function = Solution(X)
x0 = project(dirichlet_function, X)
y0 = Function(Y)
y0.vector().zero()
xn = x0.vector().size()
yn = y0.vector().size()

for do_split in (True, False):
    if do_split:
        x, y = split(u)
    else:
        x = Function(X)
        y = Function(Y)

    # Forms
    L = inner(grad(x), grad(x))*dx + dot(x, x)*y*dx

    if do_split:
        F = derivative(L, u, vv)
        J = derivative(F, u, uu)
    else:
        F = derivative(L, (x, y), vv)
        J = derivative(F, (x, y), uu)

    # Load the start vector into whichever coefficient(s) L depends on.
    if do_split:
        # NOTE(review): assumes the mixed space numbers all X dofs before all
        # Y dofs -- confirm for the DOLFIN version in use.
        uarr = u.vector().array()
        uarr[:xn] = x0.vector().array()
        uarr[xn:] = y0.vector().array()
        u.vector().set(uarr)
    else:
        x.vector().assign(x0.vector())
        y.vector().assign(y0.vector())

    b = assemble(F)
    A = assemble(J)
    # Single-expression prints behave identically under Python 2 and 3.
    print("")
    print("do_split = %s" % (do_split,))
    print("xn, yn = %s %s" % (xn, yn))
    bnorm = b.norm()
    print("Norm of b  %s" % (bnorm,))
    norms.append(bnorm)
    # Only eigenvalues are needed, so skip computing eigenvectors.
    # numpy.sort also orders complex values (lexicographically), in case
    # round-off makes the spectrum complex.  Sorted largest first.
    eigA = numpy.sort(numpy.linalg.eigvals(A.array()))[::-1]
    eigs.append(eigA)
    print("Eigenvalues of A  %s , ... %s" % (eigA[:5], eigA[-5:]))

# Compare the two ways of taking the derivative; both Jacobians live on XY,
# so the sorted spectra have the same length.
print("Norm difference: %s" % (norms[0] - norms[1],))
print("Eig difference:  %s" % (eigs[0][:5] - eigs[1][:5],))
print("Eig difference:  %s" % (eigs[0][-5:] - eigs[1][-5:],))
