source: trunk/GSASIImath.py @ 457

Last change on this file since 457 was 457, checked in by vondreele, 10 years ago

begin distance angle calcs
move gpxfile routines from GSASIIstruct.py to GSASIIIO.py
move getVCov & ValEsd to GSASIImath.py
add some text to help/gsasII.html

File size: 5.7 KB
Line 
1#GSASIImath - major mathematics routines
2########### SVN repository information ###################
3# $Date: 2012-01-13 11:48:53 -0600 (Fri, 13 Jan 2012) $
4# $Author: vondreele $
5# $Revision: 451 $
6# $URL: https://subversion.xor.aps.anl.gov/pyGSAS/trunk/GSASIImath.py $
7# $Id: GSASIImath.py 451 2012-01-13 17:48:53Z vondreele $
8########### SVN repository information ###################
9import sys
10import os
11import os.path as ospath
12import numpy as np
13import numpy.linalg as nl
14import cPickle
15import time
16import math
17import GSASIIpath
18import scipy.optimize as so
19import scipy.linalg as sl
20
def sind(x):
    """Return the sine of an angle given in degrees."""
    return np.sin(x*np.pi/180.)

def cosd(x):
    """Return the cosine of an angle given in degrees."""
    return np.cos(x*np.pi/180.)

def tand(x):
    """Return the tangent of an angle given in degrees."""
    return np.tan(x*np.pi/180.)

def asind(x):
    """Return the inverse sine of x, in degrees."""
    return 180.*np.arcsin(x)/np.pi

def atan2d(y,x):
    """Return the two-argument inverse tangent of y/x, in degrees."""
    return 180.*np.arctan2(y,x)/np.pi
26
def HessianLSQ(func,x0,Hess,args=(),ftol=1.49012e-8,xtol=1.49012e-8, maxcyc=0):
    """
    Minimize the sum of squares of a set of equations.

    ::

                    Nobs
        x = arg min(sum(func(y)**2,axis=0))
                    y=0

    Steps are computed from the vector and matrix supplied by ``Hess``.
    Amat is scaled to unit diagonal, its diagonal is multiplied by
    (1+lam), and the step dx solves that damped system against Yvec.
    lam starts at 0.001. It is multiplied by 10 when a trial step
    increases the sum of squares (the step is then retried). It is
    divided by 10 when a step is accepted.

    Parameters
    ----------
    func : callable
        Called as ``func(x, *args)``; returns the residuals whose sum of
        squares is minimized.
    x0 : array_like
        The starting estimate, length N. It is copied, so the caller's
        array is not modified.
    Hess : callable
        Called as ``Hess(x, *args)``; must return ``(Yvec, Amat)``, where
        Amat is a symmetric NxN array and Yvec a length-N vector. A zero on
        the diagonal of Amat is reported as a hard singularity. The
        solution dx of the (damped) system ``Amat.dx = Yvec`` is *added*
        to x, so Yvec must point downhill. Presumably Yvec is ``-J.T.func(x)``
        with ``Amat = J.T.J`` for the Jacobian J of func; confirm against
        callers. The returned arrays are not modified.
    args : tuple
        Extra arguments passed to func and Hess. A non-tuple value is
        wrapped in a 1-tuple.
    ftol : float
        Refinement stops when the relative decrease of the sum of squares
        in one cycle, (chisq0-chisq1)/chisq0, is below ftol. It also stops
        when the sum of squares at the start of a cycle is exactly zero.
    xtol : float
        Not used by this implementation; kept for interface compatibility.
    maxcyc : int
        Cycles run while the cycle counter (starting at 0) is <= maxcyc.
        At most maxcyc+1 cycles are therefore done (one for the default 0).
        A negative value does no refinement; only the covariance at x0 is
        evaluated.

    Returns
    -------
    list
        ``[x, cov, info]``, where:

        x : ndarray
            The solution, or the last accepted estimate if refinement
            stopped early.
        cov : ndarray or None
            The inverse of the Amat returned by Hess at x. It is ``None``
            if a singular matrix was encountered.
        info : dict
            - 'num cyc' : value of the cycle counter on return
            - 'fvec' : func evaluated at x
            - 'nfev' : number of calls made to func
            - 'lamMax' : largest lam seen at the start of a cycle
            - 'psing' : indices of parameters found singular ([] on success)
    """
    x0 = np.array(x0, ndmin=1)      # copy: the caller's array is never modified
    n = len(x0)
    if not isinstance(args, tuple):
        args = (args,)

    icycle = 0
    One = np.ones((n,n))
    lam = 0.001
    lamMax = lam
    nfev = 0

    def result(x, cov, psing):
        # package the return value from the loop state current at call time
        return [x,cov,{'num cyc':icycle,'fvec':M,'nfev':nfev,'lamMax':lamMax,'psing':psing}]

    while icycle <= maxcyc:
        lamMax = max(lamMax,lam)
        M = func(x0,*args)
        nfev += 1
        chisq0 = np.sum(M**2)
        Yvec,Amat = Hess(x0,*args)
        Adiag = np.sqrt(np.diag(Amat))
        if 0.0 in Adiag:                #hard singularity in matrix
            return result(x0,None,list(np.where(Adiag == 0.)[0]))
        # scale to unit diagonal; build new arrays so Hess's outputs stay intact
        Anorm = np.outer(Adiag,Adiag)
        Yvec = Yvec/Adiag
        Amat = Amat/Anorm
        while True:
            # damping: multiply the diagonal by (1+lam), leave off-diagonals alone
            Amatlam = Amat*(One+np.eye(n)*lam)
            try:
                Xvec = nl.solve(Amatlam,Yvec)
            except nl.LinAlgError:
                return result(x0,None,_getPsing(Amatlam))
            Xvec /= Adiag               # undo the unit-diagonal scaling
            M2 = func(x0+Xvec,*args)
            nfev += 1
            chisq1 = np.sum(M2**2)
            if chisq1 > chisq0:
                lam *= 10.              # step made things worse: damp harder, retry
            else:
                x0 += Xvec              # accept the step
                lam /= 10.
                break
        # the chisq0 == 0 test avoids a 0/0 division for an exact fit
        if chisq0 == 0. or (chisq0-chisq1)/chisq0 < ftol:
            break
        icycle += 1
    M = func(x0,*args)
    nfev += 1
    Yvec,Amat = Hess(x0,*args)
    try:
        Bmat = nl.inv(Amat)
    except nl.LinAlgError:
        return result(x0,None,_getPsing(Amat))
    return result(x0,Bmat,[])

def _getPsing(mat, tol=1.e-14):
    """Return the indices of (near-)singular columns of a square matrix.

    A QR decomposition of ``mat`` is taken, and every index whose diagonal
    element of R has magnitude below ``tol`` is reported. The magnitude is
    used because diagonal elements of R may legitimately be negative.

    Parameters
    ----------
    mat : array_like
        Square matrix to examine.
    tol : float
        Threshold on |R[i,i]|.

    Returns
    -------
    list
        Indices of the singular columns (parameters); may be empty.
    """
    rdiag = np.abs(np.diag(nl.qr(mat)[1]))
    return list(np.where(rdiag < tol)[0])
134   
def getVCov(varyNames,varyList,covMatrix):
    """Extract the covariance matrix for a chosen subset of parameters.

    Parameters
    ----------
    varyNames : list
        Names of the wanted parameters, in the order the output should use.
    varyList : list
        Names labelling the rows/columns of covMatrix; must support
        ``.index(name)``.
    covMatrix : 2-D array_like
        Full covariance matrix in varyList order. Anything that supports
        ``covMatrix[i][j]`` works.

    Returns
    -------
    ndarray
        A len(varyNames) x len(varyNames) matrix. Rows and columns for names
        that are not in varyList are 0.0.
    """
    # Look up each name's position in varyList once (None when absent),
    # rather than calling varyList.index() for every matrix element.
    pos = []
    for name in varyNames:
        try:
            pos.append(varyList.index(name))
        except ValueError:
            pos.append(None)
    vcov = np.zeros((len(varyNames),len(varyNames)))
    for i1,p1 in enumerate(pos):
        if p1 is None:
            continue
        for i2,p2 in enumerate(pos):
            if p2 is not None:
                vcov[i1,i2] = covMatrix[p1][p2]
    return vcov
144   
def ValEsd(value,esd=0,nTZ=False):                  #NOT complete - don't use
    """Format a value, optionally with its esd, as a string.

    - esd > 0: returns "value(esd)". The number of decimals is
      int(1.545-log10(esd)), clamped at 0. The esd is written as an
      integer in units of the last decimal shown. For esd below about 3.5
      that integer is between 4 and 35. For larger esd, no decimals are
      shown and the esd is rounded to an integer.
      Examples: ValEsd(1.2346,0.003) -> '1.2346(30)',
      ValEsd(1.2346,0.033) -> '1.235(33)', ValEsd(12345.6,500.) -> '12346(500)'.
    - esd < 0: returns the value alone, with the number of decimals that
      esd=abs(esd) would give. For example, esd=-0.01 gives 3 decimals:
      ValEsd(1.5,-0.01) -> '1.500'.
    - esd == 0: returns the value formatted with '%f' (6 decimals).

    nTZ=True strips trailing zeros from the esd <= 0 outputs.
    """
    def ndec(e):
        # decimals needed to show abs(e) as an integer of at most 35; never negative
        return max(0,int(1.545-math.log10(abs(e))))
    if esd > 0:
        nd = ndec(esd)
        # the esd integer uses the same decimal count as the value, so the two agree
        return '%.*f(%d)'%(nd,value,int(round(esd*10**nd)))
    if esd < 0:
        text = '%.*f'%(ndec(esd),value)
    else:
        text = '%f'%(value)
    if nTZ:
        text = text.rstrip('0')
    return text
166
167   
Note: See TracBrowser for help on using the repository browser.