Changeset 5323


Ignore:
Timestamp:
Aug 26, 2022 4:17:15 PM (16 months ago)
Author:
vondreele
Message:

improvements to cluster analysis

Location:
trunk
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/GSASIIdataGUI.py

    r5322 r5323  
    18071807                pth, '',extList[extOrd[0]]+extList[extOrd[1]]+'All files (*.*)|*.*', wx.FD_OPEN)
    18081808            if os.path.exists(lastIparmfile):
    1809                 dlg.SetFilename(lastIparmfile)
     1809                dlg.SetFilename(os.path.split(lastIparmfile)[-1])
    18101810            if dlg.ShowModal() == wx.ID_OK:
    18111811                instfile = dlg.GetPath()
     
    57115711            Id = self.GPXtree.AppendItem(self.root,text='Cluster Analysis')
    57125712            ClustDict = {'Files':[],'Method':'correlation','Limits':[0.,100.],'DataMatrix':[],'plots':'All',
    5713                 'LinkMethod':'average','Opt Order':False,'ConDistMat':[],'NumClust':2,'codes':None}
     5713                'LinkMethod':'average','Opt Order':False,'ConDistMat':[],'NumClust':2,'codes':None,'Scikit':'K-Means'}
    57145714            self.GPXtree.SetItemPyData(Id,ClustDict)
    57155715        else:
  • trunk/GSASIIplot.py

    r5322 r5323  
    1157811578    '''
    1157911579    import scipy.cluster.hierarchy as SCH
    11580     from mpl_toolkits.axes_grid1.inset_locator import inset_axes   
     11580    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
     11581   
     11582    global SetPick
     11583    SetPick = True
    1158111584    def OnMotion(event):
    11582        
     11585        global SetPick
    1158311586        if event.xdata and event.ydata:       
    1158411587            G2frame.G2plotNB.status.SetStatusText('x=%.3f y=%.3f'%(event.xdata,event.ydata),1)
     11588            SetPick = True
    1158511589           
    1158611590    def OnPick(event):
    11587         line = event.artist
    11588         ind = int(line.get_label().split('tion')[1])
    11589         text = 'Data selected: %s'%(CLuDict['Files'][ind])
    11590         G2frame.G2plotNB.status.SetStatusText(text,1)
    11591         print(text)
     11591        global SetPick
     11592        if SetPick:
     11593            line = event.artist
     11594            ind = int(line.get_label().split('tion')[1])
     11595            text = 'PCA Data selected: (%d) %s'%(ind,CLuDict['Files'][ind])
     11596            G2frame.G2plotNB.status.SetStatusText(text,1)
     11597            SetPick = False
     11598            print(text)
    1159211599           
    1159311600    Colors = ['xkcd:blue','xkcd:red','xkcd:green','xkcd:cyan',
     
    1162611633    elif CLuDict['plots'] == '3D PCA':
    1162711634        if CLuDict['codes'] is not None:
    11628             Plot.scatter(XYZ[0],XYZ[1],XYZ[2],color=[Colors[code] for code in CLuDict['codes']],picker=True)
     11635            for ixyz,xyz in enumerate(XYZ.T):
     11636                Plot.scatter(xyz[0],xyz[1],xyz[2],color=Colors[CLuDict['codes'][ixyz]],picker=True)
    1162911637        else:
    11630             Plot.scatter(XYZ[0],XYZ[1],XYZ[2],color=Colors[0],picker=True)
     11638            for ixyz,xyz in enumerate(XYZ.T):
     11639                Plot.scatter(xyz[0],xyz[1],xyz[2],color=Colors[0],picker=True)
    1163111640        Plot.set_xlabel('PCA axis-1',fontsize=12)
    1163211641        Plot.set_ylabel('PCA axis-2',fontsize=12)
     
    1165411663            for ixyz,xyz in enumerate(XYZ.T):
    1165511664                ax2.scatter(xyz[0],xyz[1],xyz[2],color=Colors[CLuDict['codes'][ixyz]],picker=True)
    11656 #            ax2.scatter(XYZ[0],XYZ[1],XYZ[2],color=[Colors[code] for code in CLuDict['codes']],picker=True)
    1165711665        else:
    1165811666            for ixyz,xyz in enumerate(XYZ.T):
    1165911667                ax2.scatter(xyz[0],xyz[1],xyz[2],color=Colors[0],picker=True)
    11660 #            ax2.scatter(XYZ[0],XYZ[1],XYZ[2],color=Colors[0],picker=True)
    1166111668        ax2.set_xlabel('PCA axis-1',fontsize=12)
    1166211669        ax2.set_ylabel('PCA axis-2',fontsize=12)
  • trunk/GSASIIseqGUI.py

    r5322 r5323  
    15641564###############################################################################################################
    15651565
    1566 def UpdateClusterAnalysis(G2frame,ClusData):
     1566def UpdateClusterAnalysis(G2frame,ClusData,shoNum=-1):
    15671567    import scipy.spatial.distance as SSD
    15681568    import scipy.cluster.hierarchy as SCH
     
    17811781        def OnCompute(event):
    17821782            whitMat = SCV.whiten(ClusData['DataMatrix'])
    1783             codebook,dist = SCV.kmeans(whitMat,ClusData['NumClust'])
     1783            codebook,dist = SCV.kmeans2(whitMat,ClusData['NumClust'])   #use K-means++
    17841784            ClusData['codes'],ClusData['dists'] = SCV.vq(whitMat,codebook)
    17851785            wx.CallAfter(UpdateClusterAnalysis,G2frame,ClusData)
    1786                        
    17871786       
    17881787        kmeanssizer = wx.BoxSizer(wx.HORIZONTAL)
     
    18011800        ClusData['plots'] = plotsel.GetValue()
    18021801        G2plt.PlotClusterXYZ(G2frame,YM,XYZ,ClusData,PlotName=ClusData['Method'],Title=ClusData['Method'])
     1802       
     1803    def ScikitSizer():
     1804       
     1805        def OnClusMethod(event):
     1806            ClusData['Scikit'] = clusMethod.GetValue()
     1807            OnCompute(event)
     1808       
     1809        def OnClusNum(event):
     1810            ClusData['NumClust'] = int(numclust.GetValue())
     1811            OnCompute(event)
     1812           
     1813        def OnCompute(event):
     1814            whitMat = SCV.whiten(ClusData['DataMatrix'])
     1815            if ClusData['Scikit'] == 'K-Means':
     1816                result = SKC.KMeans(n_clusters=ClusData['NumClust'],algorithm='elkan').fit(whitMat)
     1817                print('K-Means sum squared dist. to means %.2f'%result.inertia_)
     1818            elif ClusData['Scikit'] == 'Spectral clustering':
     1819                result = SKC.SpectralClustering(n_clusters=ClusData['NumClust']).fit(whitMat)
     1820            elif ClusData['Scikit'] == 'Mean-shift':
     1821                result = SKC.MeanShift().fit(whitMat)
     1822                print('Number of Mean-shift clusters found: %d'%(np.max(result.labels_)+1))
     1823            elif ClusData['Scikit'] == 'Affinity propagation':
     1824                result = SKC.AffinityPropagation(affinity='precomputed').fit(SSD.squareform(ClusData['ConDistMat']))
     1825                print('Number of Affinity propagation clusters found: %d'%(np.max(result.labels_)+1))
     1826            elif ClusData['Scikit'] == 'Agglomerative clustering':
     1827                result = SKC.AgglomerativeClustering(n_clusters=ClusData['NumClust'],
     1828                    affinity='precomputed',linkage='average').fit(SSD.squareform(ClusData['ConDistMat']))
     1829           
     1830            ClusData['codes'] = result.labels_
     1831            wx.CallAfter(UpdateClusterAnalysis,G2frame,ClusData)
     1832                               
     1833        scikitSizer = wx.BoxSizer(wx.VERTICAL)
     1834        scikitSizer.Add(wx.StaticText(G2frame.dataWindow,label=SKLearnCite))
     1835        choice = ['K-Means','Affinity propagation','Mean-shift','Spectral clustering','Agglomerative clustering']
     1836        clusSizer = wx.BoxSizer(wx.HORIZONTAL)
     1837        clusSizer.Add(wx.StaticText(G2frame.dataWindow,label='Select clusering method: '),0,WACV)
     1838        clusMethod = wx.ComboBox(G2frame.dataWindow,choices=choice,style=wx.CB_READONLY|wx.CB_DROPDOWN)
     1839        clusMethod.SetValue(ClusData['Scikit'])
     1840        clusMethod.Bind(wx.EVT_COMBOBOX,OnClusMethod)
     1841        clusSizer.Add(clusMethod,0,WACV)
     1842        if ClusData['Scikit'] in ['K-Means','Spectral clustering','Agglomerative clustering']:
     1843            nchoice = [str(i) for i in range(2,16)]
     1844            clusSizer.Add(wx.StaticText(G2frame.dataWindow,label=' Select number of clusters (2-15): '),0,WACV)
     1845            numclust = wx.ComboBox(G2frame.dataWindow,choices=nchoice,style=wx.CB_READONLY|wx.CB_DROPDOWN)
     1846            numclust.SetValue(str(ClusData['NumClust']))
     1847            numclust.Bind(wx.EVT_COMBOBOX,OnClusNum)
     1848            clusSizer.Add(numclust,0,WACV)
     1849        compute = wx.Button(G2frame.dataWindow,label='Compute')
     1850        compute.Bind(wx.EVT_BUTTON,OnCompute)
     1851        clusSizer.Add(compute)
     1852        scikitSizer.Add(clusSizer)
     1853        useTxt = '%s used the whitened data matrix'%ClusData['Scikit']
     1854        if ClusData['Scikit'] in ['Spectral clustering','Agglomerative clustering']:
     1855            useTxt = '%s used %s for distance method'%(ClusData['Scikit'],ClusData['Method'])
     1856        print(useTxt)
     1857        scikitSizer.Add(wx.StaticText(G2frame.dataWindow,label=useTxt))
     1858        return scikitSizer
     1859   
     1860    def memberSizer():
     1861       
     1862        def OnClusNum(event):
     1863            shoNum = int(numclust.GetValue())
     1864            wx.CallAfter(UpdateClusterAnalysis,G2frame,ClusData,shoNum)
     1865           
     1866        NClust = np.max(ClusData['codes'])
     1867        memSizer = wx.BoxSizer(wx.VERTICAL)
     1868        memSizer.Add(wx.StaticText(G2frame.dataWindow,label='Cluster populations:'))       
     1869        for i in range(NClust+1):
     1870            nPop= len(ClusData['codes'])-np.count_nonzero(ClusData['codes']-i)
     1871            memSizer.Add(wx.StaticText(G2frame.dataWindow,label='Cluster #%d has %d members'%(i,nPop)))       
     1872        headSizer = wx.BoxSizer(wx.HORIZONTAL)
     1873        headSizer.Add(wx.StaticText(G2frame.dataWindow,label='Select cluster to list members: '),0,WACV)       
     1874        choice = [str(i) for i in range(NClust+1)]
     1875        numclust = wx.ComboBox(G2frame.dataWindow,choices=choice,value=str(shoNum),style=wx.CB_READONLY|wx.CB_DROPDOWN)
     1876        numclust.Bind(wx.EVT_COMBOBOX,OnClusNum)
     1877        headSizer.Add(numclust,0,WACV)
     1878        memSizer.Add(headSizer)       
     1879        if shoNum >= 0:
     1880            memSizer.Add(wx.StaticText(G2frame.dataWindow,label='Members of cluster %d:'%shoNum))
     1881            text = ''
     1882            for i,item in enumerate(ClusData['Files']):
     1883                if ClusData['codes'][i] == shoNum:
     1884                    text += '(%d) %s\n'%(i,item)
     1885            memSizer.Add(wx.StaticText(G2frame.dataWindow,label=text))       
     1886        return memSizer
    18031887           
    18041888    #patch
    18051889    ClusData['plots'] = ClusData.get('plots','All')
     1890    ClusData['Scikit'] = ClusData.get('Scikit','K-Means')
     1891    #end patch
    18061892    G2frame.dataWindow.ClearData()
    18071893    bigSizer = wx.BoxSizer(wx.HORIZONTAL)
     
    18101896    subSizer = wx.BoxSizer(wx.HORIZONTAL)
    18111897    subSizer.Add((-1,-1),1,wx.EXPAND)
    1812     subSizer.Add(wx.StaticText(G2frame.dataWindow,label='Cluster Analysis: '),0,WACV)   
     1898    subSizer.Add(wx.StaticText(G2frame.dataWindow,label='Scipy Cluster Analysis: '),0,WACV)   
    18131899    subSizer.Add((-1,-1),1,wx.EXPAND)
    18141900    mainSizer.Add(subSizer,0,wx.EXPAND)
     
    18301916            mainSizer.Add(MethodSizer())
    18311917            if len(ClusData['ConDistMat']):
    1832                 Y = ClusData['ConDistMat']
    1833                 YM = SSD.squareform(Y)
     1918                YM = SSD.squareform(ClusData['ConDistMat'])
    18341919                U,s,VT = nl.svd(YM) #s are the Eigenvalues
    18351920                ClusData['PCA'] = s
    18361921                s[3:] = 0.
    18371922                S = np.diag(s)
    1838                 XYZ = np.dot(U,np.dot(S,VT))
    1839                 G2plt.PlotClusterXYZ(G2frame,XYZ,XYZ[:3,:],ClusData,PlotName=ClusData['Method'],Title=ClusData['Method'])
     1923                XYZ = np.dot(S,VT)
     1924                G2plt.PlotClusterXYZ(G2frame,YM,XYZ[:3,:],ClusData,PlotName=ClusData['Method'],Title=ClusData['Method'])
    18401925                G2G.HorizontalLine(mainSizer,G2frame.dataWindow)
    18411926                mainSizer.Add(wx.StaticText(G2frame.dataWindow,label='Hierarchical Cluster Analysis:'))
     
    18451930                mainSizer.Add(wx.StaticText(G2frame.dataWindow,label='K-means Cluster Analysis:'))
    18461931                mainSizer.Add(kmeanSizer())
    1847                 if ClusData['codes'] is not None:
     1932                if 'dists' in ClusData:
    18481933                    kmeansres = wx.BoxSizer(wx.HORIZONTAL)
    18491934                    kmeansres.Add(wx.StaticText(G2frame.dataWindow,label='K-means ave. dist = %.2f'%np.mean(ClusData['dists'])))
    18501935                    mainSizer.Add(kmeansres)
     1936            if ClusData['codes'] is not None:
     1937                G2G.HorizontalLine(mainSizer,G2frame.dataWindow)
     1938                mainSizer.Add(memberSizer())
    18511939            G2G.HorizontalLine(mainSizer,G2frame.dataWindow)
    18521940            plotSizer = wx.BoxSizer(wx.HORIZONTAL)
     
    18621950            mainSizer.Add(plotSizer)
    18631951           
    1864         if ClusData['SKLearn']:
    1865             G2G.HorizontalLine(mainSizer,G2frame.dataWindow)
    1866             mainSizer.Add(wx.StaticText(G2frame.dataWindow,label=SKLearnCite))
    1867            
    1868            
    1869            
    1870            
    1871    
    1872    
    1873    
     1952            if ClusData['SKLearn'] and len(ClusData['ConDistMat']):
     1953                G2G.HorizontalLine(mainSizer,G2frame.dataWindow)
     1954                subSizer = wx.BoxSizer(wx.HORIZONTAL)
     1955                subSizer.Add((-1,-1),1,wx.EXPAND)
     1956                subSizer.Add(wx.StaticText(G2frame.dataWindow,label='Scikit-Learn Cluster Analysis: '),0,WACV)   
     1957                subSizer.Add((-1,-1),1,wx.EXPAND)
     1958                mainSizer.Add(subSizer,0,wx.EXPAND)
     1959                mainSizer.Add(ScikitSizer())
     1960           
     1961               
    18741962    bigSizer.Add(mainSizer)
    18751963       
Note: See TracChangeset for help on using the changeset viewer.