source: trunk/testDeriv.py

Last change on this file was 5509, checked in by vondreele, 4 weeks ago

more fixes to spin RB GUI & math.
Implement corrections to CalcPDF as suggested by Leighanne Gallington.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Date Author Revision URL Id
File size: 13.0 KB
Line 
1# -*- coding: utf-8 -*-
2#testDeriv.py
3'''
4*testDeriv: Check derivative computation*
5=========================================
6
7Use this to check derivatives used in structure least squares
8refinement against numerical values computed in this script.
9
10To use set ``DEBUG=True`` in GSASIIstrMain.py (line 40, as of version
112546); run the least squares - zero cycles is sufficient.  Do the "Save
12Results"; this will write the file testDeriv.dat in the local
13directory.
14
15Then run this program to see plots of derivatives for all
16parameters refined in the last least squares.  Shown will be numerical
17derivatives generated over all observations (including penalty terms)
18and the corresponding analytical ones produced in the least
19squares. They should match. Profiling is also done for function
20calculation & for the 1st selected derivative (rest should be the same).
21'''
22
23import sys
24import os
25import platform
26import copy
27if '2' in platform.python_version_tuple()[0]:
28    import cPickle
29    import StringIO
30else:
31    import pickle as cPickle
32    import io as StringIO
33import cProfile,pstats
34import wx
35import numpy as np
36import GSASIIpath
37GSASIIpath.SetBinaryPath()
38import GSASIIstrMath as G2stMth
39import GSASIItestplot as plot
40import GSASIImapvars as G2mv
41try:  # fails on doc build
42    import pytexture as ptx
43    ptx.pyqlmninit()            #initialize fortran arrays for spherical harmonics
44except ImportError:
45    pass
46
47try:
48    NewId = wx.NewIdRef
49except AttributeError:
50    NewId = wx.NewId
51[wxID_FILEEXIT, wxID_FILEOPEN, wxID_MAKEPLOTS, wxID_CLEARSEL,wxID_SELECTALL,
52] = [NewId() for _init_coll_File_Items in range(5)]
53
54def FileDlgFixExt(dlg,file):            #this is needed to fix a problem in linux wx.FileDialog
55    ext = dlg.GetWildcard().split('|')[2*dlg.GetFilterIndex()+1].strip('*')
56    if ext not in file:
57        file += ext
58    return file
59   
60class testDeriv(wx.Frame):
61
62    def _init_ctrls(self, parent):
63        wx.Frame.__init__(self, name='testDeriv', parent=parent,
64            size=wx.Size(800, 250),style=wx.DEFAULT_FRAME_STYLE, title='Test Jacobian Derivatives')
65        self.testDerivMenu = wx.MenuBar()
66        self.File = wx.Menu(title='')
67        self.File.Append(wxID_FILEOPEN,'Open testDeriv file\tCtrl+O','Open testDeriv')
68        self.File.Append(wxID_MAKEPLOTS,'Make plots\tCtrl+P','Make derivative plots')
69        self.File.Append(wxID_SELECTALL,'Select all\tCtrl+S')
70        self.File.Append(wxID_CLEARSEL,'Clear selections\tCtrl+C')
71        self.File.Append(wxID_FILEEXIT,'Exit\tALT+F4','Exit from testDeriv')
72        self.Bind(wx.EVT_MENU,self.OnTestRead, id=wxID_FILEOPEN)
73        self.Bind(wx.EVT_MENU,self.OnMakePlots,id=wxID_MAKEPLOTS)
74        self.Bind(wx.EVT_MENU,self.ClearSelect,id=wxID_CLEARSEL)
75        self.Bind(wx.EVT_MENU,self.SelectAll,id=wxID_SELECTALL)
76        self.Bind(wx.EVT_MENU,self.OnFileExit, id=wxID_FILEEXIT)
77        self.testDerivMenu.Append(menu=self.File, title='File')
78        self.SetMenuBar(self.testDerivMenu)
79        self.testDerivPanel = wx.ScrolledWindow(self)
80        self.plotNB = plot.PlotNotebook()
81        self.testFile = ''
82        arg = sys.argv
83        if len(arg) > 1 and arg[1]:
84            try:
85                self.testFile = os.path.splitext(arg[1])[0]+u'.testDeriv'
86            except:
87                self.testFile = os.path.splitext(arg[1])[0]+'.testDeriv'
88            self.TestRead()
89            self.UpdateControls(None)
90       
91    def __init__(self, parent):
92        self._init_ctrls(parent)
93        self.Bind(wx.EVT_CLOSE, self.ExitMain)   
94        self.dirname = ''
95        self.testfile = []
96        self.dataFrame = None
97        self.timingOn = False
98
99    def ExitMain(self, event):
100        sys.exit()
101       
102    def OnFileExit(self,event):
103        if self.dataFrame:
104            self.dataFrame.Clear() 
105            self.dataFrame.Destroy()
106        self.Close()
107       
108    def SelectAll(self,event):
109        self.use = [True for name in self.names]
110        for i,name in enumerate(self.names):
111            if 'Back' in name:
112                self.use[i] = False
113        self.UpdateControls(event)
114       
115    def ClearSelect(self,event):
116        self.use = [False for i in range(len(self.names))]
117        self.UpdateControls(event)
118
119    def OnTestRead(self,event):
120        dlg = wx.FileDialog(self, 'Open *.testDeriv file',defaultFile='*.testDeriv',
121            wildcard='*.testDeriv')
122        if self.dirname:
123            dlg.SetDirectory(self.dirname)
124        try:
125            if dlg.ShowModal() == wx.ID_OK:
126                self.dirname = dlg.GetDirectory()
127                self.testFile = dlg.GetPath()
128                self.TestRead()
129                self.UpdateControls(event)
130        finally:
131            dlg.Destroy()
132           
133    def TestRead(self):
134        file = open(self.testFile,'rb')
135        self.values = cPickle.load(file,encoding='Latin-1')
136        self.HistoPhases = cPickle.load(file,encoding='Latin-1')
137        (self.constrDict,self.fixedList,self.depVarList) = cPickle.load(file,encoding='Latin-1')
138        self.parmDict = cPickle.load(file,encoding='Latin-1')
139        self.varylist = cPickle.load(file,encoding='Latin-1')
140        self.calcControls = cPickle.load(file,encoding='Latin-1')
141        self.pawleyLookup = cPickle.load(file,encoding='Latin-1')
142        self.names = self.varylist+self.depVarList
143        self.use = [False for i in range(len(self.names))]
144        self.delt = [max(abs(self.parmDict[name])*0.0001,1e-6) for name in self.names]
145        for iname,name in enumerate(self.names):
146            if name.split(':')[-1] in ['Shift','DisplaceX','DisplaceY',]:
147                self.delt[iname] = 0.1
148        file.close()
149        G2mv.InitVars()
150        msg = G2mv.EvaluateMultipliers(self.constrDict,self.parmDict)
151        if msg:
152            print('Unable to interpret multiplier(s): '+msg)
153            raise Exception
154        G2mv.GenerateConstraints(self.varylist,self.constrDict,self.fixedList,self.parmDict)
155        print(G2mv.VarRemapShow(self.varylist))
156        print('Dependent Vary List:',self.depVarList)
157        G2mv.Map2Dict(self.parmDict,copy.copy(self.varylist))   # compute independent params, N.B. changes varylist
158        G2mv.Dict2Map(self.parmDict) # imposes constraints on dependent values
159
160    def UpdateControls(self,event):
161        def OnItemCk(event):
162            Obj = event.GetEventObject()
163            item = ObjInd[Obj.GetId()]
164            self.use[item] = Obj.GetValue()
165           
166        def OnDelValue(event):
167            event.Skip()
168            Obj = event.GetEventObject()
169            item = ObjInd[Obj.GetId()]
170            try:
171                value = float(Obj.GetValue())
172            except ValueError:
173                value = self.delt[item]
174            self.delt[item] = value
175            Obj.SetValue('%g'%(value))
176       
177        if self.testDerivPanel.GetSizer():
178            self.testDerivPanel.GetSizer().Clear(True)
179        ObjInd = {}
180        use = self.use
181        delt = self.delt
182        topSizer = wx.BoxSizer(wx.VERTICAL)
183        self.timingVal = wx.CheckBox(self.testDerivPanel,label='Show Execution Profiling')
184        topSizer.Add(self.timingVal,0)
185        topSizer.Add((-1,10))
186        mainSizer = wx.FlexGridSizer(0,8,5,5)
187        for id,[ck,name,d] in enumerate(zip(use,self.names,delt)):
188            useVal = wx.CheckBox(self.testDerivPanel,label=name)
189            useVal.SetValue(ck)
190            ObjInd[useVal.GetId()] = id
191            useVal.Bind(wx.EVT_CHECKBOX, OnItemCk)
192            mainSizer.Add(useVal,0)
193            delVal = wx.TextCtrl(self.testDerivPanel,wx.ID_ANY,'%g'%(d),style=wx.TE_PROCESS_ENTER)
194            ObjInd[delVal.GetId()] = id
195            delVal.Bind(wx.EVT_TEXT_ENTER,OnDelValue)
196            delVal.Bind(wx.EVT_KILL_FOCUS,OnDelValue)
197            mainSizer.Add(delVal,0)
198        topSizer.Add(mainSizer,0)
199        self.testDerivPanel.SetSizer(topSizer)   
200        Size = topSizer.GetMinSize()
201        self.testDerivPanel.SetScrollbars(10,10,int(Size[0]/10-4),int(Size[1]/10-1))
202        Size[1] = max(200,Size[1])
203        Size[0] += 20
204        self.SetSize(Size)
205
206    def OnMakePlots(self,event):
207       
208        def test1():
209            fplot = self.plotNB.add('function test').gca()
210            if self.timingOn:
211                pr = cProfile.Profile()
212                pr.enable()
213            M = G2stMth.errRefine(self.values,self.HistoPhases,
214                self.parmDict,self.varylist,self.calcControls,
215                self.pawleyLookup,None)
216            if self.timingOn:
217                pr.disable()
218                s = StringIO.StringIO()
219                sortby = 'tottime'
220                ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
221                print('Profiler of function calculation; top 50% of routines:')
222                ps.print_stats("GSASII",.5)
223                print(s.getvalue())
224            fplot.plot(M,'r',label='M')
225            fplot.legend(loc='best')
226           
227        def test2(name,delt,doProfile):
228            Title = 'derivatives test for '+name
229            ind = self.names.index(name)
230            hplot = self.plotNB.add(Title).gca()
231            if doProfile and self.timingOn:
232                pr = cProfile.Profile()
233                pr.enable()
234            #regenerate minimization fxn
235            G2stMth.errRefine(self.values,self.HistoPhases,
236                self.parmDict,self.varylist,self.calcControls,
237                self.pawleyLookup,None)
238            dMdV = G2stMth.dervRefine(self.values,self.HistoPhases,self.parmDict,
239                self.names,self.calcControls,self.pawleyLookup,None)
240            if doProfile and self.timingOn:
241                pr.disable()
242                s = StringIO.StringIO()
243                sortby = 'tottime'
244                ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats(sortby)
245                ps.print_stats("GSASII",.5)
246                print('Profiler of '+name+' derivative calculation; top 50% of routines:')
247                print(s.getvalue())
248            M2 = dMdV[ind]
249            hplot.plot(M2,'b',label='analytic deriv')
250            mmin = np.min(dMdV[ind])
251            mmax = np.max(dMdV[ind])
252            if name in self.varylist:
253                ind = self.varylist.index(name)
254                orig = copy.copy(self.parmDict)  # save parmDict before changes
255                self.parmDict[name] = self.values[ind] =  self.values[ind] - delt
256                G2mv.Dict2Map(self.parmDict)
257                first = True
258                for i in self.parmDict:
259                    if orig[i] != self.parmDict[i] and i != name:
260                        if first:
261                            print('Propagated changes from this shift')
262                            print(name,orig[name],self.parmDict[name],orig[name]-self.parmDict[name])
263                            print('are:')
264                            first = False
265                        print(i,orig[i],self.parmDict[i],orig[i]-self.parmDict[i])
266                M0 = G2stMth.errRefine(self.values,self.HistoPhases,self.parmDict,
267                    self.names,self.calcControls,self.pawleyLookup,None)
268                self.parmDict[name] = self.values[ind] =  self.values[ind] + 2.*delt
269                G2mv.Dict2Map(self.parmDict)
270                M1 = G2stMth.errRefine(self.values,self.HistoPhases,self.parmDict,
271                    self.names,self.calcControls,self.pawleyLookup,None)
272                self.parmDict[name] = self.values[ind] =  self.values[ind] - delt
273                G2mv.Dict2Map(self.parmDict)
274            elif name in self.depVarList:   #in depVarList
275                if 'dA' in name:
276                    name = name.replace('dA','A')
277                    #delt *= -1  # why???
278                self.parmDict[name] -= delt
279                G2mv.Dict2Map(self.parmDict)
280                M0 = G2stMth.errRefine(self.values,self.HistoPhases,self.parmDict,
281                        self.names,self.calcControls,self.pawleyLookup,None)
282                self.parmDict[name] += 2.*delt
283                G2mv.Dict2Map(self.parmDict)
284                M1 = G2stMth.errRefine(self.values,self.HistoPhases,self.parmDict,
285                        self.names,self.calcControls,self.pawleyLookup,None)
286                self.parmDict[name] -= delt   
287                G2mv.Dict2Map(self.parmDict)
288            Mn = (M1-M0)/(2.*abs(delt))
289            print('parameter:',name,self.parmDict[name],delt,mmin,mmax,np.sum(M0),np.sum(M1),np.sum(Mn))
290            hplot.plot(Mn,'r',label='numeric deriv')
291            hplot.legend(loc='best')           
292           
293        while self.plotNB.nb.GetPageCount():
294            self.plotNB.nb.DeletePage(0)
295           
296        test1()
297        self.timingOn = self.timingVal.GetValue()
298
299        doProfile = True
300        for use,name,delt in zip(self.use,self.names,self.delt):
301            if use:
302                test2(name,delt,doProfile)
303                doProfile = False
304       
305        self.plotNB.Show()
306       
307def main():
308    'Starts main application to compute and plot derivatives'
309    application = wx.App(0)
310    application.main = testDeriv(None)
311    application.main.Show()
312    application.SetTopWindow(application.main)
313    application.MainLoop()
314   
315if __name__ == '__main__':
316    GSASIIpath.InvokeDebugOpts()
317    main()
Note: See TracBrowser for help on using the repository browser.