# To maximize python3/python2 compatibility
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import

__copyright = """
PYCIFRW License Agreement (Python License, Version 2)
-----------------------------------------------------

1. This LICENSE AGREEMENT is between the Australian Nuclear Science
and Technology Organisation ("ANSTO"), and the Individual or
Organization ("Licensee") accessing and otherwise using this software
("PyCIFRW") in source or binary form and its associated documentation.

2. Subject to the terms and conditions of this License Agreement,
ANSTO hereby grants Licensee a nonexclusive, royalty-free, world-wide
license to reproduce, analyze, test, perform and/or display publicly,
prepare derivative works, distribute, and otherwise use PyCIFRW alone
or in any derivative version, provided, however, that this License
Agreement and ANSTO's notice of copyright, i.e., "Copyright (c)
2001-2014 ANSTO; All Rights Reserved" are retained in PyCIFRW alone or
in any derivative version prepared by Licensee.

3. In the event Licensee prepares a derivative work that is based on
or incorporates PyCIFRW or any part thereof, and wants to make the
derivative work available to others as provided herein, then Licensee
hereby agrees to include in any such work a brief summary of the
changes made to PyCIFRW.

4. ANSTO is making PyCIFRW available to Licensee on an "AS IS"
basis. ANSTO MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR
IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, ANSTO MAKES NO AND
DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS
FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYCIFRW WILL NOT
INFRINGE ANY THIRD PARTY RIGHTS.

5. ANSTO SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYCIFRW
FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A
RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYCIFRW, OR ANY
DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF.

6. This License Agreement will automatically terminate upon a material
breach of its terms and conditions.

7. Nothing in this License Agreement shall be deemed to create any
relationship of agency, partnership, or joint venture between ANSTO
and Licensee. This License Agreement does not grant permission to use
ANSTO trademarks or trade name in a trademark sense to endorse or
promote products or services of Licensee, or any third party.

8. By copying, installing or otherwise using PyCIFRW, Licensee agrees
to be bound by the terms and conditions of this License Agreement.

"""


# Python 2,3 compatibility
try:
    from urllib import urlopen  # for arbitrary opening
    from urlparse import urlparse, urlunparse
except ImportError:
    from urllib.request import urlopen
    from urllib.parse import urlparse,urlunparse
import re,os
import copy
import textwrap

try:
    from StringIO import StringIO  #not cStringIO as we cannot subclass
except ImportError:
    from io import StringIO

if isinstance(u"abc",str):  #Python 3
    unicode = str

try:
    import numpy
    have_numpy = True
except ImportError:
    have_numpy = False

class StarList(list):
    def __getitem__(self,args):
        if isinstance(args,(int,slice)):
            return super(StarList,self).__getitem__(args)
        elif isinstance(args,tuple) and len(args)>1:  #extended comma notation
            return super(StarList,self).__getitem__(args[0]).__getitem__(args[1:])
        else:
            return super(StarList,self).__getitem__(args[0])

    def __str__(self):
        return "SL("+super(StarList,self).__str__() + ")"

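# Illustrative sketch (added commentary, not original code): the extended comma
# notation above lets nested StarLists be indexed like a multi-dimensional array.
# For a hypothetical sl = StarList([StarList([1,2]), StarList([3,4])]), the
# expression sl[1,0] drills down as sl[1][0] and returns 3, while sl[1] behaves
# like an ordinary list index.
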
class StarDict(dict):
    pass


class LoopBlock(object):
    def __init__(self,parent_block,dataname):
        self.loop_no = parent_block.FindLoop(dataname)
        if self.loop_no < 0:
            raise KeyError('%s is not in a loop structure' % dataname)
        self.parent_block = parent_block

    def keys(self):
        return self.parent_block.loops[self.loop_no]

    def values(self):
        return [self.parent_block[a] for a in self.keys()]

    #Avoid iterator even though that is Python3-esque
    def items(self):
        return list(zip(self.keys(),self.values()))

    def __getitem__(self,dataname):
        if isinstance(dataname,int):  #a packet request
            return self.GetPacket(dataname)
        if dataname in self.keys():
            return self.parent_block[dataname]
        else:
            raise KeyError('%s not in loop block' % dataname)

    def __setitem__(self,dataname,value):
        self.parent_block[dataname] = value
        self.parent_block.AddLoopName(self.keys()[0],dataname)

    def __contains__(self,key):
        return key in self.parent_block.loops[self.loop_no]

    def has_key(self,key):
        return key in self

    def __iter__(self):
        packet_list = zip(*self.values())
        names = self.keys()
        for p in packet_list:
            r = StarPacket(p)
            for n in range(len(names)):
                setattr(r,names[n].lower(),r[n])
            yield r

    # for compatibility
    def __getattr__(self,attname):
        return getattr(self.parent_block,attname)

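    # Illustrative sketch (added commentary, not original code): a LoopBlock is a
    # lightweight view onto one loop of its parent StarBlock, and iterating it
    # yields StarPacket rows.  For a hypothetical block `cb` with a loop over
    # _atom_site_label and _atom_site_occupancy, something like
    #     for row in cb.GetLoop('_atom_site_label'):
    #         print(row._atom_site_label, row._atom_site_occupancy)
    # walks the packets; attribute names are the lower-cased datanames.
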
    def load_iter(self,coords=[]):
        count = 0  #to create packet index
        while not self.popout:
            # ok, we have a new packet: append a list to our subloops
            for aloop in self.loops:
                aloop.new_enclosing_packet()
            for iname in self.item_order:
                if isinstance(iname,LoopBlock):  #into a nested loop
                    for subitems in iname.load_iter(coords=coords+[count]):
                        # print 'Yielding %s' % `subitems`
                        yield subitems
                    # print 'End of internal loop'
                else:
                    if self.dimension == 0:
                        # print 'Yielding %s' % `self[iname]`
                        yield self,self[iname]
                    else:
                        backval = self.block[iname]
                        for i in range(len(coords)):
                            # print 'backval, coords: %s, %s' % (`backval`,`coords`)
                            backval = backval[coords[i]]
                        yield self,backval
            count = count + 1  # count packets
        self.popout = False  # reinitialise
        # print 'Finished iterating'
        yield self,'###Blank###'  #this value should never be used

    # an experimental fast iterator for level-1 loops (ie CIF)
    def fast_load_iter(self):
        targets = list(map(lambda a:self.block[a],self.item_order))
        while targets:
            for target in targets:
                yield self,target

    # Add another list of the required shape to take into account a new outer packet
    def new_enclosing_packet(self):
        if self.dimension > 1:  #otherwise have a top-level list
            for iname in self.keys():  #includes lower levels
                target_list = self[iname]
                for i in range(3,self.dimension):  #dim 2 upwards are lists of lists of...
                    target_list = target_list[-1]
                target_list.append([])
                # print '%s now %s' % (iname,`self[iname]`)

    def recursive_iter(self,dict_so_far={},coord=[]):
        # print "Recursive iter: coord %s, keys %s, dim %d" % (`coord`,`self.block.keys()`,self.dimension)
        my_length = 0
        top_items = list(self.block.items())
        top_values = list(self.block.values())  #same order as items
        drill_values = list(self.block.values())
        for dimup in range(0,self.dimension):  #look higher in the tree
            if len(drill_values)>0:  #this block has values
                drill_values=drill_values[0]  #drill in
            else:
                raise StarError("Malformed loop packet %s" % repr( top_items[0] ))
        my_length = len(drill_values[0])  #length of 'string' entry
        if self.dimension == 0:  #top level
            for aloop in self.loops:
                for apacket in aloop.recursive_iter():
                    # print "Recursive yielding %s" % repr( dict(top_items + apacket.items()) )
                    prep_yield = StarPacket(top_values+apacket.values())  #straight list
                    for name,value in top_items + apacket.items():
                        setattr(prep_yield,name,value)
                    yield prep_yield
        else:  #in some loop
            for i in range(my_length):
                kvpairs = list(map(lambda a:(a,self.coord_to_group(a,coord)[i]),self.block.keys()))
                kvvals = list(map(lambda a:a[1],kvpairs))  #just values
                # print "Recursive kvpairs at %d: %s" % (i,repr( kvpairs ))
                if self.loops:
                    for aloop in self.loops:
                        for apacket in aloop.recursive_iter(coord=coord+[i]):
                            # print "Recursive yielding %s" % repr( dict(kvpairs + apacket.items()) )
                            prep_yield = StarPacket(kvvals+apacket.values())
                            for name,value in kvpairs + apacket.items():
                                setattr(prep_yield,name,value)
                            yield prep_yield
                else:  # we're at the bottom of the tree
                    # print "Recursive yielding %s" % repr( dict(kvpairs) )
                    prep_yield = StarPacket(kvvals)
                    for name,value in kvpairs:
                        setattr(prep_yield,name,value)
                    yield prep_yield

    # small function to use the coordinates.
    def coord_to_group(self,dataname,coords):
        if not isinstance(dataname,unicode):
            return dataname  # flag inner loop processing
        newm = self[dataname]  # newm must be a list or tuple
        for c in coords:
            # print "Coord_to_group: %s ->" % (repr( newm )),
            newm = newm[c]
            # print repr( newm )
        return newm

    def flat_iterator(self):
        my_length = 0
        top_keys = list(self.block.keys())
        if len(top_keys)>0:
            my_length = len(self.block[top_keys[0]])
        for pack_no in range(my_length):
            yield(self.collapse(pack_no))


    def RemoveItem(self,itemname):
        """Remove `itemname` from the block."""
        # first check any loops
        loop_no = self.FindLoop(itemname)
        testkey = itemname.lower()
        if testkey in self:
            del self.block[testkey]
            del self.true_case[testkey]
            # now remove from loop
            if loop_no >= 0:
                self.loops[loop_no].remove(testkey)
                if len(self.loops[loop_no])==0:
                    del self.loops[loop_no]
                    self.item_order.remove(loop_no)
            else:  #will appear in order list
                self.item_order.remove(testkey)

    def RemoveLoopItem(self,itemname):
        """*Deprecated*. Use `RemoveItem` instead"""
        self.RemoveItem(itemname)

    def GetLoop(self,keyname):
        """Return a `StarFile.LoopBlock` object constructed from the loop containing `keyname`.
        `keyname` is only significant as a way to specify the loop."""
        return LoopBlock(self,keyname)

    def GetPacket(self,index):
        thispack = StarPacket([])
        for myitem in self.parent_block.loops[self.loop_no]:
            thispack.append(self[myitem][index])
            setattr(thispack,myitem,thispack[-1])
        return thispack

    def AddPacket(self,packet):
        for myitem in self.parent_block.loops[self.loop_no]:
            old_values = self.parent_block[myitem]
            old_values.append(packet.__getattribute__(myitem))
            self.parent_block[myitem] = old_values

    def GetItemOrder(self):
        """Return a list of datanames in this `LoopBlock` in the order that they will be
        printed"""
        return self.parent_block.loops[self.loop_no][:]

    def ChangeItemOrder(self,itemname,newpos):
        """Change the position at which `itemname` appears when printing out to `newpos`."""
        self.parent_block.loops[self.loop_no].remove(itemname.lower())
        self.parent_block.loops[self.loop_no].insert(newpos,itemname.lower())

    def GetItemPosition(self,itemname):
        """A utility function to get the numerical order in the printout
        of `itemname`. An item has coordinate `(loop_no,pos)` with
        the top level having a `loop_no` of -1. If an integer is passed to
        the routine then it will return the position of the loop
        referenced by that number."""
        import string
        if isinstance(itemname,int):
            # return loop position
            return (-1, self.item_order.index(itemname))
        if not itemname in self:
            raise ValueError('No such dataname %s' % itemname)
        testname = itemname.lower()
        if testname in self.item_order:
            return (-1,self.item_order.index(testname))
        loop_no = self.FindLoop(testname)
        loop_pos = self.loops[loop_no].index(testname)
        return loop_no,loop_pos

    def GetLoopNames(self,keyname):
        """Return all datanames appearing together with `keyname`"""
        loop_no = self.FindLoop(keyname)
        if loop_no >= 0:
            return self.loops[loop_no]
        else:
            raise KeyError('%s is not in any loop' % keyname)

    def AddToLoop(self,dataname,loopdata):
        """*Deprecated*. Use `AddItem` followed by calls to `AddLoopName`.

        Add multiple columns to the loop containing `dataname`. `loopdata` is a
        collection of (key,value) pairs, where `key` is the new dataname and `value`
        is a list of values for that dataname"""
        # check lengths
        thisloop = self.FindLoop(dataname)
        loop_len = len(self[dataname])
        bad_vals = [a for a in loopdata.items() if len(a[1])!=loop_len]
        if len(bad_vals)>0:
            raise StarLengthError("Number of values for looped datanames %s not equal to %d" \
                % (repr( bad_vals ),loop_len))
        self.update(loopdata)
        self.loops[thisloop]+=loopdata.keys()


class StarBlock(object):
    def __init__(self,data = (), maxoutlength=2048, wraplength=80, overwrite=True,
                 characterset='ascii',maxnamelength=-1):
        self.block = {}  #the actual data storage (lower case keys)
        self.loops = {}  #each loop is indexed by a number and contains a list of datanames
        self.item_order = []  #lower case, loops referenced by integer
        self.formatting_hints = {}
        self.true_case = {}  #transform lower case to supplied case
        self.provide_value = False  #prefer string version always
        self.dictionary = None  #DDLm dictionary
        self.popout = False  #used during load iteration
        self.curitem = -1  #used during iteration
        self.cache_vals = True  #store all calculated values
        self.maxoutlength = maxoutlength
        self.setmaxnamelength(maxnamelength)  #to enforce CIF limit of 75 characters
        self.set_characterset(characterset)  #to check input names
        self.wraplength = wraplength
        self.overwrite = overwrite
        self.string_delimiters = ["'",'"',"\n;"]  #universal CIF set
        self.list_delimiter = " "  #CIF2 default
        self.wrapper = textwrap.TextWrapper()
        if isinstance(data,(tuple,list)):
            for item in data:
                self.AddLoopItem(item)
        elif isinstance(data,StarBlock):
            self.block = data.block.copy()
            self.item_order = data.item_order[:]
            self.true_case = data.true_case.copy()
            # loops as well
            self.loops = data.loops.copy()

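    # Illustrative sketch (added commentary, not original code): typical use of a
    # StarBlock is to assign scalar items and then gather related columns into a
    # loop, e.g. (datanames are hypothetical):
    #     sb = StarBlock()
    #     sb['_cell_length_a'] = '5.64'
    #     sb['_atom_site_label'] = ['Na1','Cl1']
    #     sb['_atom_site_occupancy'] = [1.0, 1.0]
    #     sb.CreateLoop(['_atom_site_label','_atom_site_occupancy'])
    # Assignment goes through __setitem__/AddItem below, which lower-cases the
    # dataname for storage and remembers the original case for printing.
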
    def setmaxnamelength(self,maxlength):
        """Set the maximum allowable dataname length (-1 for no check)"""
        self.maxnamelength = maxlength
        if maxlength > 0:
            bad_names = [a for a in self.keys() if len(a)>self.maxnamelength]
            if len(bad_names)>0:
                raise StarError('Datanames too long: ' + repr( bad_names ))

    def set_characterset(self,characterset):
        """Set the characterset for checking datanames: may be `ascii` or `unicode`"""
        import sys
        self.characterset = characterset
        if characterset == 'ascii':
            self.char_check = re.compile("[][ \n\r\t!%&\(\)*+,./:<=>?@0-9A-Za-z\\\\^`{}\|~\"#$';_-]+",re.M)
        elif characterset == 'unicode':
            if sys.maxunicode < 1114111:
                self.char_check = re.compile(u"[][ \n\r\t!%&\(\)*+,./:<=>?@0-9A-Za-z\\\\^`{}\|~\"#$';_\u00A0-\uD7FF\uE000-\uFDCF\uFDF0-\uFFFD-]+",re.M)
            else:
                self.char_check = re.compile(u"[][ \n\r\t!%&\(\)*+,./:<=>?@0-9A-Za-z\\\\^`{}\|~\"#$';_\u00A0-\uD7FF\uE000-\uFDCF\uFDF0-\uFFFD\U00010000-\U0010FFFD-]+",re.M)

    def __str__(self):
        return self.printsection()

    def __setitem__(self,key,value):
        if key == "saves":
            raise StarError("""Setting the saves key is deprecated. Add the save block to
            an enclosing block collection (e.g. CIF or STAR file) with this block as child""")
        self.AddItem(key,value)

    def __getitem__(self,key):
        if key == "saves":
            raise StarError("""The saves key is deprecated. Access the save block from
            the enclosing block collection (e.g. CIF or STAR file object)""")
        try:
            rawitem,is_value = self.GetFullItemValue(key)
        except KeyError:
            if self.dictionary:
                # send the dictionary the required key and a pointer to us
                try:
                    new_value = self.dictionary.derive_item(key,self,store_value=self.cache_vals,allow_defaults=False)
                except StarDerivationFailure:  #try now with defaults included
                    try:
                        new_value = self.dictionary.derive_item(key,self,store_value=self.cache_vals,allow_defaults=True)
                    except StarDerivationFailure as s:
                        print("In StarBlock.__getitem__, " + repr(s))
                        raise KeyError('No such item: %s' % key)
                print('Set %s to derived value %s' % (key, repr(new_value)))
                return new_value
            else:
                raise KeyError('No such item: %s' % key)
        # we now have an item, we can try to convert it to a number if that is appropriate
        # note numpy values are never stored but are converted to lists
        if not self.dictionary or not key in self.dictionary: return rawitem
        print('%s: is_value %s provide_value %s value %s' % (key,repr( is_value ),repr( self.provide_value ),repr( rawitem )))
        if is_value:
            if self.provide_value: return rawitem
            else:
                print('Turning %s into string' % repr( rawitem ))
                return self.convert_to_string(key)
        else:  # a string
            if self.provide_value and ((not isinstance(rawitem,list) and rawitem != '?' and rawitem != ".") or \
               (isinstance(rawitem,list) and '?' not in rawitem and '.' not in rawitem)):
                return self.dictionary.change_type(key,rawitem)
            elif self.provide_value:  # catch the question marks
                do_calculate = False
                if isinstance(rawitem,(list,tuple)):
                    known = [a for a in rawitem if a != '?']
                    if len(known) == 0:  #all questions
                        do_calculate = True
                elif rawitem == '?':
                    do_calculate = True
                if do_calculate:
                    # remove old value
                    del self[key]
                    try:
                        new_value = self.dictionary.derive_item(key,self,store_value=True,allow_defaults=False)
                    except StarDerivationFailure as s:
                        try:
                            new_value = self.dictionary.derive_item(key,self,store_value=True,allow_defaults=True)
                        except StarDerivationFailure as s:
                            print("Could not turn %s into a value: %s" % (key,repr(s)))
                            return rawitem
                        else:
                            print('Set %s to derived value %s' % (key, repr( new_value )))
                            return new_value
            return rawitem  #can't do anything

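    # Illustrative note (added commentary, not original code): when a DDLm
    # dictionary has been attached, item access above can return either the raw
    # string form or a typed/derived value.  A hedged sketch of the toggle:
    #     sb.provide_value = True   # __getitem__ now prefers converted/derived values
    #     sb.provide_value = False  # back to the as-read string representation
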
    def __delitem__(self,key):
        self.RemoveItem(key)

    def __len__(self):
        blen = len(self.block)
        return blen

    def __nonzero__(self):
        if self.__len__() > 0: return 1
        return 0

    # keys returns all internal keys
    def keys(self):
        return list(self.block.keys())  #always lower case

    def values(self):
        return [self[a] for a in self.keys()]

    def items(self):
        return list(zip(self.keys(),self.values()))

    def __contains__(self,key):
        if isinstance(key,(unicode,str)) and key.lower() in self.keys():
            return True
        return False

    def has_key(self,key):
        return key in self

    def has_key_or_alias(self,key):
        """Check if a dataname or alias is available in the block"""
        initial_test = key in self
        if initial_test: return True
        elif self.dictionary:
            aliases = [k for k in self.dictionary.alias_table.get(key,[]) if self.has_key(k)]
            if len(aliases)>0:
                return True
        return False

    def get(self,key,default=None):
        if key in self:
            retval = self.__getitem__(key)
        else:
            retval = default
        return retval

    def clear(self):
        self.block = {}
        self.loops = {}
        self.item_order = []
        self.true_case = {}

    # doesn't appear to work
    def copy(self):
        newcopy = StarBlock()
        newcopy.block = self.block.copy()
        newcopy.loops = []
        newcopy.item_order = self.item_order[:]
        newcopy.true_case = self.true_case.copy()
        newcopy.loops = self.loops.copy()
        # return self.copy.im_class(newcopy)  #catch inheritance
        return newcopy

    def update(self,adict):
        for key in adict.keys():
            self.AddItem(key,adict[key])

    def GetItemPosition(self,itemname):
        """A utility function to get the numerical order in the printout
        of `itemname`. An item has coordinate `(loop_no,pos)` with
        the top level having a `loop_no` of -1. If an integer is passed to
        the routine then it will return the position of the loop
        referenced by that number."""
        import string
        if isinstance(itemname,int):
            # return loop position
            return (-1, self.item_order.index(itemname))
        if not itemname in self:
            raise ValueError('No such dataname %s' % itemname)
        testname = itemname.lower()
        if testname in self.item_order:
            return (-1,self.item_order.index(testname))
        loop_no = self.FindLoop(testname)
        loop_pos = self.loops[loop_no].index(testname)
        return loop_no,loop_pos

    def ChangeItemOrder(self,itemname,newpos):
        """Move the printout order of `itemname` to `newpos`. If `itemname` is
        in a loop, `newpos` refers to the order within the loop."""
        if isinstance(itemname,(unicode,str)):
            true_name = itemname.lower()
        else:
            true_name = itemname
        loopno = self.FindLoop(true_name)
        if loopno < 0:  #top level
            self.item_order.remove(true_name)
            self.item_order.insert(newpos,true_name)
        else:
            self.loops[loopno].remove(true_name)
            self.loops[loopno].insert(newpos,true_name)

    def GetItemOrder(self):
        """Return a list of datanames in the order in which they will be printed. Loops are
        referred to by numerical index"""
        return self.item_order[:]

    def AddItem(self,key,value,precheck=False):
        """Add dataname `key` to block with value `value`. `value` may be
        a single value, a list or a tuple. If `precheck` is False (the default),
        all values will be checked and converted to unicode strings as necessary. If
        `precheck` is True, this checking is bypassed. No checking is necessary
        when values are read from a CIF file as they are already in correct form."""
        if not isinstance(key,(unicode,str)):
            raise TypeError('Star datanames are strings only (got %s)' % repr( key ))
        key = unicode(key)  #everything is unicode internally
        if not precheck:
            self.check_data_name(key,self.maxnamelength)  # make sure no nasty characters
        # check for overwriting
        if key in self:
            if not self.overwrite:
                raise StarError( 'Attempt to insert duplicate item name %s' % key)
        if not precheck:  #need to sanitise
            regval,empty_val = self.regularise_data(value)
            pure_string = check_stringiness(regval)
            self.check_item_value(regval)
        else:
            regval,empty_val = value,None
            pure_string = True
        # update ancillary information first
        lower_key = key.lower()
        if not lower_key in self and self.FindLoop(lower_key)<0:  #need to add to order
            self.item_order.append(lower_key)
        # always remove from our case table in case the case is different
        try:
            del self.true_case[lower_key]
        except KeyError:
            pass
        self.true_case[lower_key] = key
        if pure_string:
            self.block.update({lower_key:[regval,empty_val]})
        else:
            self.block.update({lower_key:[empty_val,regval]})

    def AddLoopItem(self,incomingdata,precheck=False,maxlength=-1):
        """*Deprecated*. Use `AddItem` followed by `CreateLoop` if
        necessary."""
        # print "Received data %s" % `incomingdata`
        # we accept tuples, strings, lists and dicts!!
        # Direct insertion: we have a string-valued key, with an array
        # of values -> single-item into our loop
        if isinstance(incomingdata[0],(tuple,list)):
            # a whole loop
            keyvallist = zip(incomingdata[0],incomingdata[1])
            for key,value in keyvallist:
                self.AddItem(key,value)
            self.CreateLoop(incomingdata[0])
        elif not isinstance(incomingdata[0],(unicode,str)):
            raise TypeError('Star datanames are strings only (got %s)' % repr( incomingdata[0] ))
        else:
            self.AddItem(incomingdata[0],incomingdata[1])

    def check_data_name(self,dataname,maxlength=-1):
        if maxlength > 0:
            self.check_name_length(dataname,maxlength)
        if dataname[0]!='_':
            raise StarError( 'Dataname ' + dataname + ' does not begin with _')
        if self.characterset=='ascii':
            if len ([a for a in dataname if ord(a) < 33 or ord(a) > 126]) > 0:
                raise StarError( 'Dataname ' + dataname + ' contains forbidden characters')
        else:
            # print 'Checking %s for unicode characterset conformance' % dataname
            if len ([a for a in dataname if ord(a) < 33]) > 0:
                raise StarError( 'Dataname ' + dataname + ' contains forbidden characters (below code point 33)')
            if len ([a for a in dataname if ord(a) > 126 and ord(a) < 160]) > 0:
                raise StarError( 'Dataname ' + dataname + ' contains forbidden characters (between code point 127-159)')
            if len ([a for a in dataname if ord(a) > 0xD7FF and ord(a) < 0xE000]) > 0:
                raise StarError( 'Dataname ' + dataname + ' contains unsupported characters (between U+D800 and U+E000)')
            if len ([a for a in dataname if ord(a) > 0xFDCF and ord(a) < 0xFDF0]) > 0:
                raise StarError( 'Dataname ' + dataname + ' contains unsupported characters (between U+FDD0 and U+FDEF)')
            if len ([a for a in dataname if ord(a) == 0xFFFE or ord(a) == 0xFFFF]) > 0:
                raise StarError( 'Dataname ' + dataname + ' contains unsupported characters (U+FFFE and/or U+FFFF)')
            if len ([a for a in dataname if ord(a) > 0x10000 and (ord(a) & 0xFFFE) == 0xFFFE]) > 0:
                print('%s fails' % dataname)
                for a in dataname: print('%x' % ord(a),end="")
                print()
                raise StarError( u'Dataname ' + dataname + u' contains unsupported characters (U+xFFFE and/or U+xFFFF)')

    def check_name_length(self,dataname,maxlength):
        if len(dataname)>maxlength:
            raise StarError( 'Dataname %s exceeds maximum length %d' % (dataname,maxlength))
        return

    def check_item_value(self,item):
        test_item = item
        if not isinstance(item,(list,dict,tuple)):
            test_item = [item]  #single item list
        def check_one (it):
            if isinstance(it,unicode):
                if it=='': return
                me = self.char_check.match(it)
                if not me:
                    print("Fail value check: %s" % it)
                    raise StarError('Bad character in %s' % it)
                else:
                    if me.span() != (0,len(it)):
                        print("Fail value check, match only %d-%d in string %s" % (me.span()[0],me.span()[1],repr( it )))
                        raise StarError('Data item "' + repr( it ) + u'"... contains forbidden characters')
        [check_one(a) for a in test_item]

    def regularise_data(self,dataitem):
        """Place dataitem into a list if necessary"""
        from numbers import Number
        if isinstance(dataitem,str):
            return unicode(dataitem),None
        if isinstance(dataitem,(Number,unicode,StarList,StarDict)):
            return dataitem,None  #assume StarList/StarDict contain unicode if necessary
        if isinstance(dataitem,(tuple,list)):
            v,s = zip(*list([self.regularise_data(a) for a in dataitem]))
            return list(v),list(s)
            #return dataitem,[None]*len(dataitem)
        # so try to make into a list
        try:
            regval = list(dataitem)
        except TypeError as value:
            raise StarError( str(dataitem) + ' is wrong type for data value\n' )
        v,s = zip(*list([self.regularise_data(a) for a in regval]))
        return list(v),list(s)

    def RemoveItem(self,itemname):
        """Remove `itemname` from the block."""
        # first check any loops
        loop_no = self.FindLoop(itemname)
        testkey = itemname.lower()
        if testkey in self:
            del self.block[testkey]
            del self.true_case[testkey]
            # now remove from loop
            if loop_no >= 0:
                self.loops[loop_no].remove(testkey)
                if len(self.loops[loop_no])==0:
                    del self.loops[loop_no]
                    self.item_order.remove(loop_no)
            else:  #will appear in order list
                self.item_order.remove(testkey)

    def RemoveLoopItem(self,itemname):
        """*Deprecated*. Use `RemoveItem` instead"""
        self.RemoveItem(itemname)

    def GetItemValue(self,itemname):
        """Return value of `itemname`. If `itemname` is looped, a list
        of all values will be returned."""
        return self.GetFullItemValue(itemname)[0]

    def GetFullItemValue(self,itemname):
        """Return the value associated with `itemname`, and a boolean flagging whether
        (True) or not (False) it is in a form suitable for calculation. False is
        always returned for strings and `StarList` objects."""
        try:
            s,v = self.block[itemname.lower()]
        except KeyError:
            raise KeyError('Itemname %s not in datablock' % itemname)
        # prefer string value unless all are None
        # are we a looped value?
        if not isinstance(s,(tuple,list)) or isinstance(s,StarList):
            if not_none(s):
                return s,False  #a string value
            else:
                return v,not isinstance(v,StarList)  #a StarList is not calculation-ready
        elif not_none(s):
            return s,False  #a list of string values
        else:
            if len(v)>0:
                return v,not isinstance(v[0],StarList)
            return v,True

    def CreateLoop(self,datanames,order=-1,length_check=True):
        """Create a loop in the datablock. `datanames` is a list of datanames that
        together form a loop. If length_check is True, they should have been initialised in the block
        to have the same number of elements (possibly 0). If `order` is given,
        the loop will appear at this position in the block when printing
        out. A loop counts as a single position."""

        if length_check:
            # check lengths: these datanames should exist
            listed_values = [a for a in datanames if isinstance(self[a],list) and not isinstance(self[a],StarList)]
            if len(listed_values) == len(datanames):
                len_set = set([len(self[a]) for a in datanames])
                if len(len_set)>1:
                    raise ValueError('Request to loop datanames %s with different lengths: %s' % (repr( datanames ),repr( len_set )))
            elif len(listed_values) != 0:
                raise ValueError('Request to loop datanames where some are single values and some are not')
        # store as lower case
        lc_datanames = [d.lower() for d in datanames]
        # remove these datanames from all other loops
        [self.loops[a].remove(b) for a in self.loops for b in lc_datanames if b in self.loops[a]]
        # remove empty loops
        empty_loops = [a for a in self.loops.keys() if len(self.loops[a])==0]
        for a in empty_loops:
            self.item_order.remove(a)
            del self.loops[a]
        if len(self.loops)>0:
            loopno = max(self.loops.keys()) + 1
        else:
            loopno = 1
        self.loops[loopno] = list(lc_datanames)
        if order >= 0:
            self.item_order.insert(order,loopno)
        else:
            self.item_order.append(loopno)
        # remove these datanames from item ordering
        self.item_order = [a for a in self.item_order if a not in lc_datanames]

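    # Illustrative note (added commentary, not original code): after CreateLoop the
    # looped datanames leave item_order and are tracked under a numeric loop index
    # instead, so for the hypothetical block built in the sketch near __init__:
    #     sb.FindLoop('_atom_site_label')      # -> 1 (the first loop created)
    #     sb.GetLoopNames('_atom_site_label')  # -> ['_atom_site_label', '_atom_site_occupancy']
    # Both helpers are defined further down in this class.
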
    def AddLoopName(self,oldname, newname):
        """Add `newname` to the loop containing `oldname`. If it is already in the new loop, no
        error is raised. If `newname` is in a different loop, it is removed from that loop.
        The number of values associated with `newname` must match the number of values associated
        with all other columns of the new loop or a `ValueError` will be raised."""
        lower_newname = newname.lower()
        loop_no = self.FindLoop(oldname)
        if loop_no < 0:
            raise KeyError('%s not in loop' % oldname)
        if lower_newname in self.loops[loop_no]:
            return
        # check length
        old_provides = self.provide_value
        self.provide_value = False
        loop_len = len(self[oldname])
        self.provide_value = old_provides
        if len(self[newname]) != loop_len:
            raise ValueError('Mismatch of loop column lengths for %s: should be %d' % (newname,loop_len))
        # remove from any other loops
        [self.loops[a].remove(lower_newname) for a in self.loops if lower_newname in self.loops[a]]
        # and add to this loop
        self.loops[loop_no].append(lower_newname)
        # remove from item_order if present
        try:
            self.item_order.remove(lower_newname)
        except ValueError:
            pass

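    # Illustrative sketch (added commentary, not original code): AddLoopName is the
    # non-deprecated way to grow an existing loop by one column.  Continuing the
    # hypothetical block from the earlier sketches:
    #     sb['_atom_site_fract_x'] = [0.0, 0.5]   # same length as the existing loop
    #     sb.AddLoopName('_atom_site_label','_atom_site_fract_x')
    # A ValueError is raised if the new column length does not match the loop.
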
    def FindLoop(self,keyname):
        """Find the loop that contains `keyname` and return its numerical index or
        -1 if not present. The numerical index can be used to refer to the loop in
        other routines."""
        loop_no = [a for a in self.loops.keys() if keyname.lower() in self.loops[a]]
        if len(loop_no)>0:
            return loop_no[0]
        else:
            return -1

    def GetLoop(self,keyname):
        """Return a `StarFile.LoopBlock` object constructed from the loop containing `keyname`.
        `keyname` is only significant as a way to specify the loop."""
        return LoopBlock(self,keyname)

    def GetLoopNames(self,keyname):
        """Return all datanames appearing together with `keyname`"""
        loop_no = self.FindLoop(keyname)
        if loop_no >= 0:
            return self.loops[loop_no]
        else:
            raise KeyError('%s is not in any loop' % keyname)

    def AddToLoop(self,dataname,loopdata):
        """*Deprecated*. Use `AddItem` followed by calls to `AddLoopName`.

        Add multiple columns to the loop containing `dataname`. `loopdata` is a
        collection of (key,value) pairs, where `key` is the new dataname and `value`
        is a list of values for that dataname"""
        # check lengths
        thisloop = self.FindLoop(dataname)
        loop_len = len(self[dataname])
        bad_vals = [a for a in loopdata.items() if len(a[1])!=loop_len]
        if len(bad_vals)>0:
            raise StarLengthError("Number of values for looped datanames %s not equal to %d" \
                % (repr( bad_vals ),loop_len))
        self.update(loopdata)
        self.loops[thisloop]+=loopdata.keys()

    def RemoveKeyedPacket(self,keyname,keyvalue):
        """Remove the packet for which dataname `keyname` takes
        value `keyvalue`. Only the first such occurrence is
        removed."""
        packet_coord = list(self[keyname]).index(keyvalue)
        loopnames = self.GetLoopNames(keyname)
        for dataname in loopnames:
            self.block[dataname][0] = list(self.block[dataname][0])
            del self.block[dataname][0][packet_coord]
            self.block[dataname][1] = list(self.block[dataname][1])
            del self.block[dataname][1][packet_coord]

    def GetKeyedPacket(self,keyname,keyvalue,no_case=False):
        """Return the loop packet (a `StarPacket` object) where `keyname` has value
        `keyvalue`. Ignore case in `keyvalue` if `no_case` is True. `ValueError`
        is raised if no packet is found or more than one packet is found."""
        my_loop = self.GetLoop(keyname)
        #print("Looking for %s in %s" % (keyvalue, my_loop.parent_block))
        #print('Packet check on:' + keyname)
        #[print(repr(getattr(a,keyname))) for a in my_loop]
        if no_case:
            one_pack= [a for a in my_loop if getattr(a,keyname).lower()==keyvalue.lower()]
        else:
            one_pack= [a for a in my_loop if getattr(a,keyname)==keyvalue]
        if len(one_pack)!=1:
            raise ValueError("Bad packet key %s = %s: returned %d packets" % (keyname,keyvalue,len(one_pack)))
        # print("Keyed packet: %s" % one_pack[0])
        return one_pack[0]

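    # Illustrative sketch (added commentary, not original code): GetKeyedPacket gives
    # row-wise access keyed on a column value, e.g. for the hypothetical loop above:
    #     pkt = sb.GetKeyedPacket('_atom_site_label','Na1')
    #     pkt._atom_site_occupancy   # value from the same packet
    # ValueError is raised when zero or several packets match.
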
    def GetCompoundKeyedPacket(self,keydict):
        """Return the loop packet (a `StarPacket` object) where the `{key:(value,caseless)}` pairs
        in `keydict` take the appropriate values. Ignore case for a given `key` if `caseless` is
        True. `ValueError` is raised if no packet is found or more than one packet is found."""
        #print "Looking for %s in %s" % (keyvalue, self.parent_block[keyname])
        keynames = list(keydict.keys())
        my_loop = self.GetLoop(keynames[0])
        for one_key in keynames:
            keyval,no_case = keydict[one_key]
            if no_case:
                my_loop = list([a for a in my_loop if str(getattr(a,one_key)).lower()==str(keyval).lower()])
            else:
                my_loop = list([a for a in my_loop if getattr(a,one_key)==keyval])
        if len(my_loop)!=1:
            raise ValueError("Bad packet keys %s: returned %d packets" % (repr(keydict),len(my_loop)))
        print("Compound keyed packet: %s" % my_loop[0])
        return my_loop[0]

    def GetKeyedSemanticPacket(self,keyvalue,cat_id):
        """Return a complete packet for category `cat_id` where the
        category key for the category equals `keyvalue`. This routine
        will understand any joined loops, so if separate loops in the
        datafile belong to the
        same category hierarchy (e.g. `_atom_site` and `_atom_site_aniso`),
        the returned `StarPacket` object will contain datanames from
        both categories."""
        target_keys = self.dictionary.cat_key_table[cat_id]
        target_keys = [k[0] for k in target_keys]  #one only in each list
        p = StarPacket()
        # set case-sensitivity flag
        lcase = False
        if self.dictionary[target_keys[0]]['_type.contents'] in ['Code','Tag','Name']:
            lcase = True
        for cat_key in target_keys:
            try:
                extra_packet = self.GetKeyedPacket(cat_key,keyvalue,no_case=lcase)
            except KeyError:  #missing key
                try:
                    test_key = self[cat_key]  #generate key if possible
                    print('Test key is %s' % repr( test_key ))
                    if test_key is not None and\
                       not (isinstance(test_key,list) and (None in test_key or len(test_key)==0)):
                        print('Getting packet for key %s' % repr( keyvalue ))
                        extra_packet = self.GetKeyedPacket(cat_key,keyvalue,no_case=lcase)
                except:  #cannot be generated
                    continue
            except ValueError:  #none/more than one, assume none
                continue
                #extra_packet = self.dictionary.generate_default_packet(cat_id,cat_key,keyvalue)
            p.merge_packet(extra_packet)
        # the following attributes used to calculate missing values
        for keyname in target_keys:
            if hasattr(p,keyname):
                p.key = [keyname]
                break
        if not hasattr(p,"key"):
            raise ValueError("No key found for %s, packet is %s" % (cat_id,str(p)))
        p.cif_dictionary = self.dictionary
        p.fulldata = self
        return p

    def GetMultiKeyedSemanticPacket(self,keydict,cat_id):
        """Return a complete packet for category `cat_id` where the keyvalues are
        provided as a dictionary of key:(value,caseless) pairs
        This routine
        will understand any joined loops, so if separate loops in the
        datafile belong to the
        same category hierarchy (e.g. `_atom_site` and `_atom_site_aniso`),
        the returned `StarPacket` object will contain datanames from
        the requested category and any children."""
        #if len(keyvalues)==1:  #simplification
        #    return self.GetKeyedSemanticPacket(keydict[1][0],cat_id)
        target_keys = self.dictionary.cat_key_table[cat_id]
        # update the dictionary passed to us with all equivalents, for
        # simplicity.
        parallel_keys = list(zip(*target_keys))  #transpose
        print('Parallel keys:' + repr(parallel_keys))
        print('Keydict:' + repr(keydict))
        start_keys = list(keydict.keys())
        for one_name in start_keys:
            key_set = [a for a in parallel_keys if one_name in a]
            for one_key in key_set:
                keydict[one_key] = keydict[one_name]
        # target_keys is a list of lists, each of which is a compound key
        p = StarPacket()
        # a little function to return the dataname for a key
        def find_key(key):
            for one_key in self.dictionary.key_equivs.get(key,[])+[key]:
                if self.has_key(one_key):
                    return one_key
            return None
        for one_set in target_keys:  #loop down the categories
            true_keys = [find_key(k) for k in one_set]
            true_keys = [k for k in true_keys if k is not None]
            if len(true_keys)==len(one_set):
                truekeydict = dict([(t,keydict[k]) for t,k in zip(true_keys,one_set)])
                try:
                    extra_packet = self.GetCompoundKeyedPacket(truekeydict)
                except KeyError:  #one or more are missing
                    continue  #should try harder?
                except ValueError:
                    continue
            else:
                continue
            print('Merging packet for keys ' + repr(one_set))
            p.merge_packet(extra_packet)
        # the following attributes used to calculate missing values
        p.key = true_keys
        p.cif_dictionary = self.dictionary
        p.fulldata = self
        return p


    def set_grammar(self,new_grammar):
        self.string_delimiters = ["'",'"',"\n;",None]
        if new_grammar in ['STAR2','2.0']:
            self.string_delimiters += ['"""',"'''"]
        if new_grammar == '2.0':
            self.list_delimiter = " "
        elif new_grammar == 'STAR2':
            self.list_delimiter = ", "
        elif new_grammar not in ['1.0','1.1']:
            raise StarError('Request to set unknown grammar %s' % new_grammar)

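    # Illustrative note (added commentary, not original code): the grammar chosen
    # here only affects output delimiters.  For example, set_grammar('STAR2') adds
    # the triple-quoted string delimiters and switches the list delimiter to ", ",
    # set_grammar('2.0') keeps whitespace-separated list elements, and '1.0'/'1.1'
    # leave the plain CIF1 delimiter set in place.
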
    def SetOutputLength(self,wraplength=80,maxoutlength=2048):
        """Set the maximum output line length (`maxoutlength`) and the line length to
        wrap at (`wraplength`). The wrap length is a target only and may not always be
        possible."""
        if wraplength > maxoutlength:
            raise StarError("Wrap length (requested %d) must be <= Maximum line length (requested %d)" % (wraplength,maxoutlength))
        self.wraplength = wraplength
        self.maxoutlength = maxoutlength

    def printsection(self,instring='',blockstart="",blockend="",indent=0,finish_at='',start_from=''):
        import string
        self.provide_value = False
        # first make an ordering
        self.create_ordering(finish_at,start_from)  #create self.output_order
        # now do it...
        if not instring:
            outstring = CIFStringIO(target_width=80)  # the returned string
        else:
            outstring = instring
        # print block delimiter
        outstring.write(blockstart,canbreak=True)
        while len(self.output_order)>0:
            #print "Remaining to output " + `self.output_order`
            itemname = self.output_order.pop(0)
            if not isinstance(itemname,int):  #no loop
                item_spec = [i for i in self.formatting_hints if i['dataname'].lower()==itemname.lower()]
                if len(item_spec)>0:
                    item_spec = item_spec[0]
                    col_pos = item_spec.get('column',-1)
                    name_pos = item_spec.get('name_pos',-1)
                else:
                    col_pos = -1
                    item_spec = {}
                    name_pos = -1
                if col_pos < 0: col_pos = 40
                outstring.set_tab(col_pos)
                itemvalue = self[itemname]
                outstring.write(self.true_case[itemname],mustbreak=True,do_tab=False,startcol=name_pos)
                outstring.write(' ',canbreak=True,do_tab=False,delimiter=True)  #space after itemname
                self.format_value(itemvalue,outstring,hints=item_spec)
            else:  # we are asked to print a loop block
                outstring.set_tab(10)  #guess this is OK?
                loop_spec = [i['name_pos'] for i in self.formatting_hints if i["dataname"]=='loop']
                if loop_spec:
                    loop_indent = max(loop_spec[0],0)
                else:
                    loop_indent = indent
                outstring.write('loop_\n',mustbreak=True,do_tab=False,startcol=loop_indent)
                self.format_names(outstring,indent+2,loop_no=itemname)
                self.format_packets(outstring,indent+2,loop_no=itemname)
        else:
            returnstring = outstring.getvalue()
            outstring.close()
            return returnstring

    def format_names(self,outstring,indent=0,loop_no=-1):
        """Print datanames from `loop_no` one per line"""
        temp_order = self.loops[loop_no][:]  #copy
        format_hints = dict([(i['dataname'],i) for i in self.formatting_hints if i['dataname'] in temp_order])
        while len(temp_order)>0:
            itemname = temp_order.pop(0)
            req_indent = format_hints.get(itemname,{}).get('name_pos',indent)
            outstring.write(' ' * req_indent,do_tab=False)
            outstring.write(self.true_case[itemname],do_tab=False)
            outstring.write("\n",do_tab=False)

    def format_packets(self,outstring,indent=0,loop_no=-1):
        import string
        alldata = [self[a] for a in self.loops[loop_no]]
        loopnames = self.loops[loop_no]
        #print 'Alldata: %s' % `alldata`
        packet_data = list(zip(*alldata))
        #print 'Packet data: %s' % `packet_data`
        #create a dictionary for quick lookup of formatting requirements
        format_hints = dict([(i['dataname'],i) for i in self.formatting_hints if i['dataname'] in loopnames])
        for position in range(len(packet_data)):
            if position > 0:
                outstring.write("\n")  #new line each packet except first
            for point in range(len(packet_data[position])):
                datapoint = packet_data[position][point]
                format_hint = format_hints.get(loopnames[point],{})
                packstring = self.format_packet_item(datapoint,indent,outstring,format_hint)
                outstring.write(' ',canbreak=True,do_tab=False,delimiter=True)

    def format_packet_item(self,pack_item,indent,outstring,format_hint):
        # print 'Formatting %s' % `pack_item`
        # temporary check for any non-unicode items
        if isinstance(pack_item,str) and not isinstance(pack_item,unicode):
            raise StarError("Item {0!r} is not unicode".format(pack_item))
        if isinstance(pack_item,unicode):
            delimiter = format_hint.get('delimiter',None)
            startcol = format_hint.get('column',-1)
            outstring.write(self._formatstring(pack_item,delimiter=delimiter),startcol=startcol)
        else:
            self.format_value(pack_item,outstring,hints = format_hint)

1156 | def _formatstring(self,instring,delimiter=None,standard='CIF1',indent=0,hints={}): |
---|
1157 | import string |
---|
1158 | if hints.get("reformat",False) and "\n" in instring: |
---|
1159 | instring = "\n"+self.do_wrapping(instring,hints["reformat_indent"]) |
---|
1160 | allowed_delimiters = set(self.string_delimiters) |
---|
1161 | if len(instring)==0: allowed_delimiters.difference_update([None]) |
---|
1162 | if len(instring) > (self.maxoutlength-2) or '\n' in instring: |
---|
1163 | allowed_delimiters.intersection_update(["\n;","'''",'"""']) |
---|
1164 | if ' ' in instring or '\t' in instring or '\v' in instring or (len(instring)>0 and instring[0] in '_$#;([{') or ',' in instring: |
---|
1165 | allowed_delimiters.difference_update([None]) |
---|
1166 | if len(instring)>3 and (instring[:4].lower()=='data' or instring[:4].lower()=='save'): |
---|
1167 | allowed_delimiters.difference_update([None]) |
---|
1168 | if len(instring)>5 and instring[:6].lower()=='global': |
---|
1169 | allowed_delimiters.difference_update([None]) |
---|
1170 | if '"' in instring: allowed_delimiters.difference_update(['"',None]) |
---|
1171 | if "'" in instring: allowed_delimiters.difference_update(["'",None]) |
---|
1172 | out_delimiter = "\n;" #default (most conservative) |
---|
1173 | if delimiter in allowed_delimiters: |
---|
1174 | out_delimiter = delimiter |
---|
1175 | elif "'" in allowed_delimiters: out_delimiter = "'" |
---|
1176 | elif '"' in allowed_delimiters: out_delimiter = '"' |
---|
1177 | if out_delimiter in ['"',"'",'"""',"'''"]: return out_delimiter + instring + out_delimiter |
---|
1178 | elif out_delimiter is None: return instring |
---|
1179 | # we are left with semicolon strings |
---|
1180 | # use our protocols: |
---|
1181 | maxlinelength = max([len(a) for a in instring.split('\n')]) |
---|
1182 | if maxlinelength > self.maxoutlength: |
---|
1183 | protocol_string = apply_line_folding(instring) |
---|
1184 | else: |
---|
1185 | protocol_string = instring |
---|
1186 | # now check for embedded delimiters |
---|
1187 | if "\n;" in protocol_string: |
---|
1188 | prefix = "CIF:" |
---|
1189 | while prefix in protocol_string: prefix = prefix + ":" |
---|
1190 | protocol_string = apply_line_prefix(protocol_string,prefix+"> ") |
---|
1191 | return "\n;" + protocol_string + "\n;" |
---|
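# A hedged sketch of the delimiter choices the logic above produces; the exact
# result also depends on self.string_delimiters, maxoutlength and any hint:
#   _formatstring('bare')       ->  bare            (no delimiter needed)
#   _formatstring('two words')  ->  'two words'     (whitespace forces quoting)
#   _formatstring("it's")       ->  "it's"          (embedded single quote)
#   _formatstring('a\nb')       ->  a semicolon-delimited text field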
1192 | |
---|
1193 | def format_value(self,itemvalue,stringsink,compound=False,hints={}): |
---|
1194 | """Format a Star data value""" |
---|
1195 | global have_numpy |
---|
1196 | delimiter = hints.get('delimiter',None) |
---|
1197 | startcol = hints.get('column',-1) |
---|
1198 | if isinstance(itemvalue,str) and not isinstance(itemvalue,unicode): #not allowed |
---|
1199 | raise StarError("Non-unicode value {0} found in block".format(itemvalue)) |
---|
1200 | if isinstance(itemvalue,unicode): #need to sanitize |
---|
1201 | stringsink.write(self._formatstring(itemvalue,delimiter=delimiter,hints=hints),canbreak = True,startcol=startcol) |
---|
1202 | elif isinstance(itemvalue,(list)) or (hasattr(itemvalue,'dtype') and hasattr(itemvalue,'__iter__')): #numpy |
---|
1203 | stringsink.set_tab(0) |
---|
1204 | stringsink.write('[',canbreak=True,newindent=True,mustbreak=compound,startcol=startcol) |
---|
1205 | if len(itemvalue)>0: |
---|
1206 | self.format_value(itemvalue[0],stringsink) |
---|
1207 | for listval in itemvalue[1:]: |
---|
1208 | # print 'Formatting %s' % `listval` |
---|
1209 | stringsink.write(self.list_delimiter,do_tab=False) |
---|
1210 | self.format_value(listval,stringsink,compound=True) |
---|
1211 | stringsink.write(']',unindent=True) |
---|
1212 | elif isinstance(itemvalue,dict): |
---|
1213 | stringsink.set_tab(0) |
---|
1214 | stringsink.write('{',newindent=True,mustbreak=compound,startcol=startcol) #start a new line inside |
---|
1215 | items = list(itemvalue.items()) |
---|
1216 | if len(items)>0: |
---|
1217 | stringsink.write("'"+items[0][0]+"'"+':',canbreak=True) |
---|
1218 | self.format_value(items[0][1],stringsink) |
---|
1219 | for key,value in items[1:]: |
---|
1220 | stringsink.write(self.list_delimiter) |
---|
1221 | stringsink.write("'"+key+"'"+":",canbreak=True) |
---|
1222 | self.format_value(value,stringsink) #never break between key and value |
---|
1223 | stringsink.write('}',unindent=True) |
---|
1224 | elif isinstance(itemvalue,(float,int)) or \ |
---|
1225 | (have_numpy and isinstance(itemvalue,(numpy.number))): #TODO - handle uncertainties |
---|
1226 | stringsink.write(str(itemvalue),canbreak=True,startcol=startcol) #numbers |
---|
1227 | else: |
---|
1228 | raise ValueError('Value in unexpected format for output: %s' % repr( itemvalue )) |
---|
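# A hedged sketch of the CIF2-style text produced for compound values, assuming
# self.list_delimiter is a single space (it is set elsewhere to suit the grammar):
#   [1,2,3]      is written as   [1 2 3]
#   {'a':1.5}    is written as   {'a':1.5}
#   nested lists reuse the same bracket syntax, breaking lines where necessary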
1229 | |
---|
1230 | def create_ordering(self,finish_at,start_from): |
---|
1231 | """Create a canonical ordering that includes loops using our formatting hints dictionary""" |
---|
1232 | requested_order = list([i['dataname'] for i in self.formatting_hints if i['dataname']!='loop']) |
---|
1233 | new_order = [] |
---|
1234 | for item in requested_order: |
---|
1235 | if isinstance(item,unicode) and item.lower() in self.item_order: |
---|
1236 | new_order.append(item.lower()) |
---|
1237 | elif item in self: #in a loop somewhere |
---|
1238 | target_loop = self.FindLoop(item) |
---|
1239 | if target_loop not in new_order: |
---|
1240 | new_order.append(target_loop) |
---|
1241 | # adjust loop name order |
---|
1242 | loopnames = self.loops[target_loop] |
---|
1243 | loop_order = [i for i in requested_order if i in loopnames] |
---|
1244 | unordered = [i for i in loopnames if i not in loop_order] |
---|
1245 | self.loops[target_loop] = loop_order + unordered |
---|
1246 | extras = list([i for i in self.item_order if i not in new_order]) |
---|
1247 | self.output_order = new_order + extras |
---|
1248 | # now handle partial output |
---|
1249 | if start_from != '': |
---|
1250 | if start_from in requested_order: |
---|
1251 | sfi = requested_order.index(start_from) |
---|
1252 | loop_order = [self.FindLoop(k) for k in requested_order[sfi:] if self.FindLoop(k)>0] |
---|
1253 | candidates = list([k for k in self.output_order if k in requested_order[sfi:]]) |
---|
1254 | cand_pos = len(new_order) |
---|
1255 | if len(candidates)>0: |
---|
1256 | cand_pos = self.output_order.index(candidates[0]) |
---|
1257 | if len(loop_order)>0: |
---|
1258 | cand_pos = min(cand_pos,self.output_order.index(loop_order[0])) |
---|
1259 | if cand_pos < len(self.output_order): |
---|
1260 | print('Output starts from %s, requested %s' % (self.output_order[cand_pos],start_from)) |
---|
1261 | self.output_order = self.output_order[cand_pos:] |
---|
1262 | else: |
---|
1263 | print('Start is beyond end of output list') |
---|
1264 | self.output_order = [] |
---|
1265 | elif start_from in extras: |
---|
1266 | self.output_order = self.output_order[self.output_order.index(start_from):] |
---|
1267 | else: |
---|
1268 | self.output_order = [] |
---|
1269 | if finish_at != '': |
---|
1270 | if finish_at in requested_order: |
---|
1271 | fai = requested_order.index(finish_at) |
---|
1272 | loop_order = list([self.FindLoop(k) for k in requested_order[fai:] if self.FindLoop(k)>0]) |
---|
1273 | candidates = list([k for k in self.output_order if k in requested_order[fai:]]) |
---|
1274 | cand_pos = len(new_order) |
---|
1275 | if len(candidates)>0: |
---|
1276 | cand_pos = self.output_order.index(candidates[0]) |
---|
1277 | if len(loop_order)>0: |
---|
1278 | cand_pos = min(cand_pos,self.output_order.index(loop_order[0])) |
---|
1279 | if cand_pos < len(self.output_order): |
---|
1280 | print('Output finishes before %s, requested before %s' % (self.output_order[cand_pos],finish_at)) |
---|
1281 | self.output_order = self.output_order[:cand_pos] |
---|
1282 | else: |
---|
1283 | print('All of block output') |
---|
1284 | elif finish_at in extras: |
---|
1285 | self.output_order = self.output_order[:self.output_order.index(finish_at)] |
---|
1286 | #print('Final order: ' + repr(self.output_order)) |
---|
1287 | |
---|
1288 | def convert_to_string(self,dataname): |
---|
1289 | """Convert values held in dataname value fork to string version""" |
---|
1290 | v,is_value = self.GetFullItemValue(dataname) |
---|
1291 | if not is_value: |
---|
1292 | return v |
---|
1293 | if check_stringiness(v): return v #already strings |
---|
1294 | # TODO...something else |
---|
1295 | return v |
---|
1296 | |
---|
1297 | def do_wrapping(self,instring,indent=3): |
---|
1298 | """Wrap the provided string""" |
---|
1299 | if "   " in instring: #a run of spaces means it is already formatted |
---|
1300 | return instring |
---|
1301 | self.wrapper.initial_indent = ' '*indent |
---|
1302 | self.wrapper.subsequent_indent = ' '*indent |
---|
1303 | # remove leading and trailing space |
---|
1304 | instring = instring.strip() |
---|
1305 | # split into paragraphs |
---|
1306 | paras = instring.split("\n\n") |
---|
1307 | wrapped_paras = [self.wrapper.fill(p) for p in paras] |
---|
1308 | return "\n".join(wrapped_paras) |
---|
1309 | |
---|
1310 | |
---|
1311 | def merge(self,new_block,mode="strict",match_att=[],match_function=None, |
---|
1312 | rel_keys = []): |
---|
1313 | if mode == 'strict': |
---|
1314 | for key in new_block.keys(): |
---|
1315 | if key in self and key not in match_att: |
---|
1316 | raise StarError( "Identical keys %s in strict merge mode" % key) |
---|
1317 | elif key not in match_att: #a new dataname |
---|
1318 | self[key] = new_block[key] |
---|
1319 | # we get here if there are no keys in common, so we can now copy |
---|
1320 | # the loops and not worry about overlaps |
---|
1321 | for one_loop in new_block.loops.values(): |
---|
1322 | self.CreateLoop(one_loop) |
---|
1323 | # we have lost case information |
---|
1324 | self.true_case.update(new_block.true_case) |
---|
1325 | elif mode == 'replace': |
---|
1326 | newkeys = list(new_block.keys()) |
---|
1327 | for ma in match_att: |
---|
1328 | try: |
---|
1329 | newkeys.remove(ma) #don't touch the special ones |
---|
1330 | except ValueError: |
---|
1331 | pass |
---|
1332 | for key in new_block.keys(): |
---|
1333 | if isinstance(key,unicode): |
---|
1334 | self[key] = new_block[key] |
---|
1335 | # creating the loop will remove items from other loops |
---|
1336 | for one_loop in new_block.loops.values(): |
---|
1337 | self.CreateLoop(one_loop) |
---|
1338 | # we have lost case information |
---|
1339 | self.true_case.update(new_block.true_case) |
---|
1340 | elif mode == 'overlay': |
---|
1341 | print('Overlay mode, current overwrite is %s' % self.overwrite) |
---|
1342 | raise StarError('Overlay block merge mode not implemented') #code below is unreachable legacy logic |
---|
1343 | save_overwrite = self.overwrite |
---|
1344 | self.overwrite = True |
---|
1345 | for attribute in new_block.keys(): |
---|
1346 | if attribute in match_att: continue #ignore this one |
---|
1347 | new_value = new_block[attribute] |
---|
1348 | #non-looped items |
---|
1349 | if new_block.FindLoop(attribute)<0: #not looped |
---|
1350 | self[attribute] = new_value |
---|
1351 | my_loops = self.loops.values() |
---|
1352 | perfect_overlaps = [a for a in new_block.loops if a in my_loops] |
---|
1353 | for po in perfect_overlaps: |
---|
1354 | loop_keys = [a for a in po if a in rel_keys] #do we have a key? |
---|
1355 | try: |
---|
1356 | newkeypos = map(lambda a:newkeys.index(a),loop_keys) |
---|
1357 | newkeypos = newkeypos[0] #one key per loop for now |
---|
1358 | loop_keys = loop_keys[0] |
---|
1359 | except (ValueError,IndexError): |
---|
1360 | newkeypos = [] |
---|
1361 | overlap_data = map(lambda a:listify(self[a]),overlaps) #old packet data |
---|
1362 | new_data = map(lambda a:new_block[a],overlaps) #new packet data |
---|
1363 | packet_data = transpose(overlap_data) |
---|
1364 | new_p_data = transpose(new_data) |
---|
1365 | # remove any packets for which the keys match between old and new; we |
---|
1366 | # make the arbitrary choice that the old data stays |
---|
1367 | if newkeypos: |
---|
1368 | # get matching values in new list |
---|
1369 | print("Old, new data:\n%s\n%s" % (repr(overlap_data[newkeypos]),repr(new_data[newkeypos]))) |
---|
1370 | key_matches = filter(lambda a:a in overlap_data[newkeypos],new_data[newkeypos]) |
---|
1371 | # filter out any new data with these key values |
---|
1372 | new_p_data = filter(lambda a:a[newkeypos] not in key_matches,new_p_data) |
---|
1373 | if new_p_data: |
---|
1374 | new_data = transpose(new_p_data) |
---|
1375 | else: new_data = [] |
---|
1376 | # wipe out the old data and enter the new stuff |
---|
1377 | byebyeloop = self.GetLoop(overlaps[0]) |
---|
1378 | # print "Removing '%s' with overlaps '%s'" % (`byebyeloop`,`overlaps`) |
---|
1379 | # Note that if, in the original dictionary, overlaps are not |
---|
1380 | # looped, GetLoop will return the block itself. So we check |
---|
1381 | # for this case... |
---|
1382 | if byebyeloop != self: |
---|
1383 | self.remove_loop(byebyeloop) |
---|
1384 | self.AddLoopItem((overlaps,overlap_data)) #adding old packets |
---|
1385 | for pd in new_p_data: #adding new packets |
---|
1386 | if pd not in packet_data: |
---|
1387 | for i in range(len(overlaps)): |
---|
1388 | #don't do this at home; we are appending |
---|
1389 | #to something in place |
---|
1390 | self[overlaps[i]].append(pd[i]) |
---|
1391 | self.overwrite = save_overwrite |
---|
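# Illustrative use of the block-level merge above (block names are hypothetical):
#   block_a.merge(block_b,mode='strict')   # raises StarError on any clashing dataname
#   block_a.merge(block_b,mode='replace')  # values from block_b win on clashes
# 'overlay' mode is accepted but currently raises StarError (not implemented).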
1392 | |
---|
1393 | def assign_dictionary(self,dic): |
---|
1394 | if not dic.diclang=="DDLm": |
---|
1395 | print("Warning: ignoring dictionary %s" % dic.dic_as_cif.my_uri) |
---|
1396 | return |
---|
1397 | self.dictionary = dic |
---|
1398 | |
---|
1399 | def unassign_dictionary(self): |
---|
1400 | """Remove dictionary-dependent behaviour""" |
---|
1401 | self.dictionary = None |
---|
1402 | |
---|
1403 | |
---|
1404 | |
---|
1405 | class StarPacket(list): |
---|
1406 | def merge_packet(self,incoming): |
---|
1407 | """Merge contents of incoming packet with this packet""" |
---|
1408 | new_attrs = [a for a in dir(incoming) if a[0] == '_' and a[1] != "_"] |
---|
1409 | self.extend(incoming) |
---|
1410 | for na in new_attrs: |
---|
1411 | setattr(self,na,getattr(incoming,na)) |
---|
1412 | |
---|
1413 | def __getattr__(self,att_name): |
---|
1414 | """Derive a missing attribute""" |
---|
1415 | if att_name.lower() in self.__dict__: |
---|
1416 | return getattr(self,att_name.lower()) |
---|
1417 | if att_name in ('cif_dictionary','fulldata','key'): |
---|
1418 | raise AttributeError('Programming error: can only assign value of %s' % att_name) |
---|
1419 | d = self.cif_dictionary |
---|
1420 | c = self.fulldata |
---|
1421 | k = self.key |
---|
1422 | assert isinstance(k,list) |
---|
1423 | d.derive_item(att_name,c,store_value=True) |
---|
1424 | # |
---|
1425 | # now pick out the new value |
---|
1426 | # self.key is a list of the key values |
---|
1427 | keydict = dict([(v,(getattr(self,v),True)) for v in k]) |
---|
1428 | full_pack = c.GetCompoundKeyedPacket(keydict) |
---|
1429 | return getattr(full_pack,att_name) |
---|
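# A hedged sketch of how the derivation above is triggered (names are hypothetical):
# the caller must first attach the dictionary, the parent block and the key list,
#   pkt.cif_dictionary = ddlm_dic
#   pkt.fulldata = parent_block
#   pkt.key = ['_mycategory.id']
# after which reading a missing attribute such as pkt.some_derived_name calls
# derive_item() and re-reads the packet via GetCompoundKeyedPacket().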
1430 | |
---|
1431 | class BlockCollection(object): |
---|
1432 | """A container for StarBlock objects. The constructor takes |
---|
1433 | one non-keyword argument `datasource` to set the initial data. If |
---|
1434 | `datasource` is a Python dictionary, the values must be `StarBlock` |
---|
1435 | objects and the keys will be blocknames in the new object. Keyword |
---|
1436 | arguments: |
---|
1437 | |
---|
1438 | standard: |
---|
1439 | `CIF` or `Dic`. `CIF` enforces 75-character blocknames; `Dic` prints |
---|
1440 | block contents before that block's save frames. |
---|
1441 | |
---|
1442 | blocktype: |
---|
1443 | The type of blocks held in this container. Normally `StarBlock` |
---|
1444 | or `CifBlock`. |
---|
1445 | |
---|
1446 | characterset: |
---|
1447 | `ascii` or `unicode`. Blocknames and datanames appearing within |
---|
1448 | blocks are restricted to the appropriate characterset. Note that |
---|
1449 | only characters in the basic multilingual plane are accepted. This |
---|
1450 | restriction will be lifted when PyCIFRW is ported to Python3. |
---|
1451 | |
---|
1452 | scoping: |
---|
1453 | `instance` or `dictionary`: `instance` implies that save frames are |
---|
1454 | hidden from save frames lower in the hierarchy or in sibling |
---|
1455 | hierarchies. `dictionary` makes all save frames visible everywhere |
---|
1456 | within a data block. This setting is only relevant for STAR2 dictionaries and |
---|
1457 | STAR2 data files, as save frames are currently not used in plain CIF data |
---|
1458 | files. |
---|
1459 | |
---|
1460 | """ |
---|
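# A minimal illustrative sketch of using this container (names are hypothetical):
#   bc = BlockCollection(standard='CIF')
#   bc.NewBlock('experiment_1')                     # top-level data block
#   bc.NewBlock('details',parent='experiment_1')    # becomes a save frame on output
#   bc['experiment_1']['_exptl.method'] = 'single-crystal X-ray diffraction'
#   print(bc.WriteOut())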
1461 | def __init__(self,datasource=None,standard='CIF',blocktype = StarBlock, |
---|
1462 | characterset='ascii',scoping='instance',**kwargs): |
---|
1463 | import collections |
---|
1464 | self.dictionary = {} |
---|
1465 | self.standard = standard |
---|
1466 | self.lower_keys = set() # short_cuts |
---|
1467 | self.renamed = {} |
---|
1468 | self.PC = collections.namedtuple('PC',['block_id','parent']) |
---|
1469 | self.child_table = {} |
---|
1470 | self.visible_keys = [] # for efficiency |
---|
1471 | self.block_input_order = [] # to output in same order |
---|
1472 | self.scoping = scoping #will trigger setting of child table |
---|
1473 | self.blocktype = blocktype |
---|
1474 | self.master_template = {} #for outputting |
---|
1475 | self.set_grammar('2.0') |
---|
1476 | self.set_characterset(characterset) |
---|
1477 | if isinstance(datasource,BlockCollection): |
---|
1478 | self.merge_fast(datasource) |
---|
1479 | self.scoping = scoping #reset visibility |
---|
1480 | elif isinstance(datasource,dict): |
---|
1481 | for key,value in datasource.items(): |
---|
1482 | self[key]= value |
---|
1483 | self.header_comment = '' |
---|
1484 | |
---|
1485 | def set_grammar(self,new_grammar): |
---|
1486 | """Set the syntax and grammar for output to `new_grammar`""" |
---|
1487 | if new_grammar not in ['1.1','1.0','2.0','STAR2']: |
---|
1488 | raise StarError('Unrecognised output grammar %s' % new_grammar) |
---|
1489 | self.grammar = new_grammar |
---|
1490 | |
---|
1491 | def set_characterset(self,characterset): |
---|
1492 | """Set the allowed characters for datanames and datablocks: may be `ascii` or `unicode`. If datanames |
---|
1493 | have already been added to any datablocks, they are not checked.""" |
---|
1494 | self.characterset = characterset |
---|
1495 | for one_block in self.lower_keys: |
---|
1496 | self[one_block].set_characterset(characterset) |
---|
1497 | |
---|
1498 | def unlock(self): |
---|
1499 | """Allow overwriting of all blocks in this collection""" |
---|
1500 | for a in self.lower_keys: |
---|
1501 | self[a].overwrite=True |
---|
1502 | |
---|
1503 | def lock(self): |
---|
1504 | """Disallow overwriting for all blocks in this collection""" |
---|
1505 | for a in self.lower_keys: |
---|
1506 | self[a].overwrite = False |
---|
1507 | |
---|
1508 | def __str__(self): |
---|
1509 | return self.WriteOut() |
---|
1510 | |
---|
1511 | def __setitem__(self,key,value): |
---|
1512 | self.NewBlock(key,value,parent=None) |
---|
1513 | |
---|
1514 | def __getitem__(self,key): |
---|
1515 | if isinstance(key,(unicode,str)): |
---|
1516 | lowerkey = key.lower() |
---|
1517 | if lowerkey in self.lower_keys: |
---|
1518 | return self.dictionary[lowerkey] |
---|
1519 | #print 'Visible keys:' + `self.visible_keys` |
---|
1520 | #print 'All keys' + `self.lower_keys` |
---|
1521 | #print 'Child table' + `self.child_table` |
---|
1522 | raise KeyError('No such item %s' % key) |
---|
1523 | |
---|
1524 | # we have to get an ordered list of the current keys, |
---|
1525 | # as we'll have to delete one of them anyway. |
---|
1526 | # Deletion will delete any key regardless of visibility |
---|
1527 | |
---|
1528 | def __delitem__(self,key): |
---|
1529 | dummy = self[key] #raise error if not present |
---|
1530 | lowerkey = key.lower() |
---|
1531 | # get rid of all children recursively as well |
---|
1532 | children = [a[0] for a in self.child_table.items() if a[1].parent == lowerkey] |
---|
1533 | for child in children: |
---|
1534 | del self[child] #recursive call |
---|
1535 | del self.dictionary[lowerkey] |
---|
1536 | del self.child_table[lowerkey] |
---|
1537 | try: |
---|
1538 | self.visible_keys.remove(lowerkey) |
---|
1539 | except ValueError: #list.remove raises ValueError, not KeyError |
---|
1540 | pass |
---|
1541 | self.lower_keys.remove(lowerkey) |
---|
1542 | self.block_input_order.remove(lowerkey) |
---|
1543 | |
---|
1544 | def __len__(self): |
---|
1545 | return len(self.visible_keys) |
---|
1546 | |
---|
1547 | def __contains__(self,item): |
---|
1548 | """Support the 'in' operator""" |
---|
1549 | if not isinstance(item,(unicode,str)): return False |
---|
1550 | if item.lower() in self.visible_keys: |
---|
1551 | return True |
---|
1552 | return False |
---|
1553 | |
---|
1554 | # We iterate over all visible |
---|
1555 | def __iter__(self): |
---|
1556 | for one_block in self.keys(): |
---|
1557 | yield self[one_block] |
---|
1558 | |
---|
1559 | # TODO: handle different case |
---|
1560 | def keys(self): |
---|
1561 | return self.visible_keys |
---|
1562 | |
---|
1563 | # Note that dict.has_key does not exist in Python 3 |
---|
1564 | def has_key(self,key): |
---|
1565 | return key in self |
---|
1566 | |
---|
1567 | def get(self,key,default=None): |
---|
1568 | if key in self: # take account of case |
---|
1569 | return self.__getitem__(key) |
---|
1570 | else: |
---|
1571 | return default |
---|
1572 | |
---|
1573 | def clear(self): |
---|
1574 | self.dictionary.clear() |
---|
1575 | self.lower_keys = set() |
---|
1576 | self.child_table = {} |
---|
1577 | self.visible_keys = [] |
---|
1578 | self.block_input_order = [] |
---|
1579 | |
---|
1580 | def copy(self): |
---|
1581 | newcopy = self.dictionary.copy() #all blocks |
---|
1582 | for k,v in self.dictionary.items(): |
---|
1583 | newcopy[k] = v.copy() |
---|
1584 | newcopy = BlockCollection(newcopy) |
---|
1585 | newcopy.child_table = self.child_table.copy() |
---|
1586 | newcopy.lower_keys = self.lower_keys.copy() |
---|
1587 | newcopy.block_input_order = self.block_input_order.copy() |
---|
1588 | newcopy.characterset = self.characterset |
---|
1589 | newcopy.SetTemplate(self.master_template.copy()) |
---|
1590 | newcopy.scoping = self.scoping #this sets visible keys |
---|
1591 | return newcopy |
---|
1592 | |
---|
1593 | def update(self,adict): |
---|
1594 | for key in adict.keys(): |
---|
1595 | self[key] = adict[key] |
---|
1596 | |
---|
1597 | def items(self): |
---|
1598 | return [(a,self[a]) for a in self.keys()] |
---|
1599 | |
---|
1600 | def first_block(self): |
---|
1601 | """Return the 'first' block. This is not necessarily the first block in the file.""" |
---|
1602 | if self.keys(): |
---|
1603 | return self[self.keys()[0]] |
---|
1604 | |
---|
1605 | def NewBlock(self,blockname,blockcontents=None,fix=True,parent=None): |
---|
1606 | """Add a new block named `blockname` with contents `blockcontents`. If `fix` |
---|
1607 | is True, `blockname` will have spaces and tabs replaced by underscores. `parent` |
---|
1608 | allows a parent block to be set so that block hierarchies can be created. Depending on |
---|
1609 | the output standard, these blocks will be printed out as nested save frames or |
---|
1610 | ignored.""" |
---|
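# Illustrative behaviour (hypothetical names): the lower-case key actually used is
# returned, since spaces/tabs are replaced and a clashing name may gain a '+' suffix:
#   key = bc.NewBlock('My Block')            # stored and returned as 'my_block'
#   bc.NewBlock('site_details',parent=key)   # nested under the block created above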
1611 | if blockcontents is None: |
---|
1612 | blockcontents = StarBlock() |
---|
1613 | if self.standard == "CIF": |
---|
1614 | blockcontents.setmaxnamelength(75) |
---|
1615 | if len(blockname)>75: |
---|
1616 | raise StarError('Blockname %s is longer than 75 characters' % blockname) |
---|
1617 | if fix: |
---|
1618 | newblockname = re.sub('[ \t]','_',blockname) |
---|
1619 | else: newblockname = blockname |
---|
1620 | new_lowerbn = newblockname.lower() |
---|
1621 | if new_lowerbn in self.lower_keys: #already there |
---|
1622 | if self.standard is not None: |
---|
1623 | toplevelnames = [a[0] for a in self.child_table.items() if a[1].parent==None] |
---|
1624 | if parent is None and new_lowerbn not in toplevelnames: #can give a new key to this one |
---|
1625 | while new_lowerbn in self.lower_keys: new_lowerbn = new_lowerbn + '+' |
---|
1626 | elif parent is not None and new_lowerbn in toplevelnames: #can fix a different one |
---|
1627 | replace_name = new_lowerbn |
---|
1628 | while replace_name in self.lower_keys: replace_name = replace_name + '+' |
---|
1629 | self._rekey(new_lowerbn,replace_name) |
---|
1630 | # now continue on to add in the new block |
---|
1631 | if parent.lower() == new_lowerbn: #the new block's requested parent just got renamed!! |
---|
1632 | parent = replace_name |
---|
1633 | else: |
---|
1634 | raise StarError( "Attempt to replace existing block " + blockname) |
---|
1635 | else: |
---|
1636 | del self[new_lowerbn] |
---|
1637 | self.dictionary.update({new_lowerbn:blockcontents}) |
---|
1638 | self.lower_keys.add(new_lowerbn) |
---|
1639 | self.block_input_order.append(new_lowerbn) |
---|
1640 | if parent is None: |
---|
1641 | self.child_table[new_lowerbn]=self.PC(newblockname,None) |
---|
1642 | self.visible_keys.append(new_lowerbn) |
---|
1643 | else: |
---|
1644 | if parent.lower() in self.lower_keys: |
---|
1645 | if self.scoping == 'instance': |
---|
1646 | self.child_table[new_lowerbn]=self.PC(newblockname,parent.lower()) |
---|
1647 | else: |
---|
1648 | self.child_table[new_lowerbn]=self.PC(newblockname,parent.lower()) |
---|
1649 | self.visible_keys.append(new_lowerbn) |
---|
1650 | else: |
---|
1651 | print('Warning: Parent block %s does not exist for child %s' % (parent,newblockname)) |
---|
1652 | self[new_lowerbn].set_grammar(self.grammar) |
---|
1653 | self[new_lowerbn].set_characterset(self.characterset) |
---|
1654 | self[new_lowerbn].formatting_hints = self.master_template |
---|
1655 | return new_lowerbn #in case calling routine wants to know |
---|
1656 | |
---|
1657 | def _rekey(self,oldname,newname,block_id=''): |
---|
1658 | """The block with key [[oldname]] gets [[newname]] as a new key, but the printed name |
---|
1659 | does not change unless [[block_id]] is given. Prefer [[rename]] for a safe version.""" |
---|
1660 | move_block = self[oldname] #old block |
---|
1661 | is_visible = oldname in self.visible_keys |
---|
1662 | move_block_info = self.child_table[oldname] #old info |
---|
1663 | move_block_children = [a for a in self.child_table.items() if a[1].parent==oldname] |
---|
1664 | # now rewrite the necessary bits |
---|
1665 | self.child_table.update(dict([(a[0],self.PC(a[1].block_id,newname)) for a in move_block_children])) |
---|
1666 | oldpos = self.block_input_order.index(oldname) |
---|
1667 | del self[oldname] #do this after updating child table so we don't delete children |
---|
1668 | self.dictionary.update({newname:move_block}) |
---|
1669 | self.lower_keys.add(newname) |
---|
1670 | #print 'Block input order was: ' + `self.block_input_order` |
---|
1671 | self.block_input_order[oldpos:oldpos]=[newname] |
---|
1672 | if block_id == '': |
---|
1673 | self.child_table.update({newname:move_block_info}) |
---|
1674 | else: |
---|
1675 | self.child_table.update({newname:self.PC(block_id,move_block_info.parent)}) |
---|
1676 | if is_visible: self.visible_keys += [newname] |
---|
1677 | |
---|
1678 | def rename(self,oldname,newname): |
---|
1679 | """Rename datablock from [[oldname]] to [[newname]]. Both key and printed name are changed. No |
---|
1680 | conformance checks are conducted.""" |
---|
1681 | realoldname = oldname.lower() |
---|
1682 | realnewname = newname.lower() |
---|
1683 | if realnewname in self.lower_keys: |
---|
1684 | raise StarError('Cannot change blockname %s to %s as %s already present' % (oldname,newname,newname)) |
---|
1685 | if realoldname not in self.lower_keys: |
---|
1686 | raise KeyError('Cannot find old block %s' % realoldname) |
---|
1687 | self._rekey(realoldname,realnewname,block_id=newname) |
---|
1688 | |
---|
1689 | def makebc(self,namelist,scoping='dictionary'): |
---|
1690 | """Make a block collection from a list of block names""" |
---|
1691 | newbc = BlockCollection() |
---|
1692 | block_lower = [n.lower() for n in namelist] |
---|
1693 | proto_child_table = [a for a in self.child_table.items() if a[0] in block_lower] |
---|
1694 | newbc.child_table = dict(proto_child_table) |
---|
1695 | new_top_level = [(a[0],self.PC(a[1].block_id,None)) for a in newbc.child_table.items() if a[1].parent not in block_lower] |
---|
1696 | newbc.child_table.update(dict(new_top_level)) |
---|
1697 | newbc.lower_keys = set([a[0] for a in proto_child_table]) |
---|
1698 | newbc.dictionary = dict((a[0],self.dictionary[a[0]]) for a in proto_child_table) |
---|
1699 | newbc.scoping = scoping |
---|
1700 | newbc.block_input_order = block_lower |
---|
1701 | return newbc |
---|
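# Illustrative call (block names are hypothetical): pull two blocks, together with
# any parent-child links between them, into a stand-alone collection:
#   sub = cf.makebc(['quick_facts','site_details'],scoping='dictionary')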
1702 | |
---|
1703 | |
---|
1704 | def merge_fast(self,new_bc,parent=None): |
---|
1705 | """Do a fast merge. WARNING: this may change one or more of its frame headers in order to |
---|
1706 | remove duplicate frames. Please keep a handle to the block object instead of the text of |
---|
1707 | the header.""" |
---|
1708 | if self.standard is None: |
---|
1709 | mode = 'replace' |
---|
1710 | else: |
---|
1711 | mode = 'strict' |
---|
1712 | overlap_flag = not self.lower_keys.isdisjoint(new_bc.lower_keys) |
---|
1713 | if parent is not None: |
---|
1714 | parent_name = [a[0] for a in self.dictionary.items() if a[1] == parent] |
---|
1715 | if len(parent_name) != 1: |
---|
1716 | raise StarError("Unable to find unique parent block name: have %s" % str(parent_name)) |
---|
1717 | parent_name = parent_name[0] |
---|
1718 | else: |
---|
1719 | parent_name = None #an error will be thrown if we treat as a string |
---|
1720 | if overlap_flag and mode != 'replace': |
---|
1721 | double_keys = self.lower_keys.intersection(new_bc.lower_keys) |
---|
1722 | for dup_key in double_keys: |
---|
1723 | our_parent = self.child_table[dup_key].parent |
---|
1724 | their_parent = new_bc.child_table[dup_key].parent |
---|
1725 | if (our_parent is None and their_parent is not None and parent is None) or\ |
---|
1726 | parent is not None: #rename our block |
---|
1727 | start_key = dup_key |
---|
1728 | while start_key in self.lower_keys: start_key = start_key+'+' |
---|
1729 | self._rekey(dup_key,start_key) |
---|
1730 | if parent_name is not None and parent_name.lower() == dup_key: #we just renamed the prospective parent!! |
---|
1731 | parent_name = start_key |
---|
1732 | elif our_parent is not None and their_parent is None and parent is None: |
---|
1733 | start_key = dup_key |
---|
1734 | while start_key in new_bc.lower_keys: start_key = start_key+'+' |
---|
1735 | new_bc._rekey(dup_key,start_key) |
---|
1736 | else: |
---|
1737 | raise StarError("In strict merge mode:duplicate keys %s" % dup_key) |
---|
1738 | self.dictionary.update(new_bc.dictionary) |
---|
1739 | self.lower_keys.update(new_bc.lower_keys) |
---|
1740 | self.visible_keys += (list(new_bc.lower_keys)) |
---|
1741 | self.block_input_order += new_bc.block_input_order |
---|
1742 | #print('Block input order now:' + repr(self.block_input_order)) |
---|
1743 | self.child_table.update(new_bc.child_table) |
---|
1744 | if parent_name is not None: #redo the child_table entries |
---|
1745 | reparent_list = [(a[0],a[1].block_id) for a in new_bc.child_table.items() if a[1].parent==None] |
---|
1746 | reparent_dict = [(a[0],self.PC(a[1],parent_name.lower())) for a in reparent_list] |
---|
1747 | self.child_table.update(dict(reparent_dict)) |
---|
1748 | |
---|
1749 | def merge(self,new_bc,mode=None,parent=None,single_block=[], |
---|
1750 | idblock="",match_att=[],match_function=None): |
---|
1751 | if mode is None: |
---|
1752 | if self.standard is None: |
---|
1753 | mode = 'replace' |
---|
1754 | else: |
---|
1755 | mode = 'strict' |
---|
1756 | if single_block: |
---|
1757 | self[single_block[0]].merge(new_bc[single_block[1]],mode, |
---|
1758 | match_att=match_att, |
---|
1759 | match_function=match_function) |
---|
1760 | return None |
---|
1761 | base_keys = [a[1].block_id for a in self.child_table.items()] |
---|
1762 | block_to_item = base_keys #default |
---|
1763 | new_keys = [a[1].block_id for a in new_bc.child_table.items()] #get list of incoming blocks |
---|
1764 | if match_att: |
---|
1765 | #make a blockname -> item name map |
---|
1766 | if match_function: |
---|
1767 | block_to_item = [match_function(self[a]) for a in self.keys()] |
---|
1768 | else: |
---|
1769 | block_to_item = [self[a].get(match_att[0],None) for a in self.keys()] |
---|
1770 | #print `block_to_item` |
---|
1771 | for key in new_keys: #run over incoming blocknames |
---|
1772 | if key == idblock: continue #skip dictionary id |
---|
1773 | basekey = key #default value |
---|
1774 | if len(match_att)>0: |
---|
1775 | attval = new_bc[key].get(match_att[0],0) #0 if ignoring matching |
---|
1776 | else: |
---|
1777 | attval = 0 |
---|
1778 | for ii in range(len(block_to_item)): #do this way to get looped names |
---|
1779 | thisatt = block_to_item[ii] #keyname in old block |
---|
1780 | #print "Looking for %s in %s" % (attval,thisatt) |
---|
1781 | if attval == thisatt or \ |
---|
1782 | (isinstance(thisatt,list) and attval in thisatt): |
---|
1783 | basekey = base_keys.pop(ii) |
---|
1784 | block_to_item.remove(thisatt) |
---|
1785 | break |
---|
1786 | if not basekey in self or mode=="replace": |
---|
1787 | new_parent = new_bc.get_parent(key) |
---|
1788 | if parent is not None and new_parent is None: |
---|
1789 | new_parent = parent |
---|
1790 | self.NewBlock(basekey,new_bc[key],parent=new_parent) #add the block |
---|
1791 | else: |
---|
1792 | if mode=="strict": |
---|
1793 | raise StarError( "In strict merge mode: block %s in old and block %s in new files" % (basekey,key)) |
---|
1794 | elif mode=="overlay": |
---|
1795 | # print "Merging block %s with %s" % (basekey,key) |
---|
1796 | self[basekey].merge(new_bc[key],mode,match_att=match_att) |
---|
1797 | else: |
---|
1798 | raise StarError( "Merge called with unknown mode %s" % mode) |
---|
1799 | |
---|
1800 | def checknamelengths(self,target_block,maxlength=-1): |
---|
1801 | if maxlength < 0: |
---|
1802 | return |
---|
1803 | else: |
---|
1804 | toolong = [a for a in target_block.keys() if len(a)>maxlength] |
---|
1805 | outstring = "" |
---|
1806 | if toolong: |
---|
1807 | outstring = "\n".join(toolong) |
---|
1808 | raise StarError( 'Following data names too long:' + outstring) |
---|
1809 | |
---|
1810 | def get_all(self,item_name): |
---|
1811 | raw_values = [self[a].get(item_name) for a in self.keys()] |
---|
1812 | raw_values = [a for a in raw_values if a != None] |
---|
1813 | ret_vals = [] |
---|
1814 | for rv in raw_values: |
---|
1815 | if isinstance(rv,list): |
---|
1816 | for rvv in rv: |
---|
1817 | if rvv not in ret_vals: ret_vals.append(rvv) |
---|
1818 | else: |
---|
1819 | if rv not in ret_vals: ret_vals.append(rv) |
---|
1820 | return ret_vals |
---|
1821 | |
---|
1822 | def __setattr__(self,attr_name,newval): |
---|
1823 | if attr_name == 'scoping': |
---|
1824 | if newval not in ('dictionary','instance'): |
---|
1825 | raise StarError("Star file may only have 'dictionary' or 'instance' scoping, not %s" % newval) |
---|
1826 | if newval == 'dictionary': |
---|
1827 | self.visible_keys = [a for a in self.lower_keys] |
---|
1828 | else: |
---|
1829 | #only top-level datablocks visible |
---|
1830 | self.visible_keys = [a[0] for a in self.child_table.items() if a[1].parent==None] |
---|
1831 | object.__setattr__(self,attr_name,newval) |
---|
1832 | |
---|
1833 | def get_parent(self,blockname): |
---|
1834 | """Return the name of the block enclosing [[blockname]] in canonical form (lower case)""" |
---|
1835 | possibles = (a for a in self.child_table.items() if a[0] == blockname.lower()) |
---|
1836 | try: |
---|
1837 | first = next(possibles) #get first one |
---|
1838 | except StopIteration: |
---|
1839 | raise StarError('no parent for %s' % blockname) |
---|
1840 | try: |
---|
1841 | second = next(possibles) |
---|
1842 | except StopIteration: |
---|
1843 | return first[1].parent |
---|
1844 | raise StarError('More than one parent for %s' % blockname) |
---|
1845 | |
---|
1846 | def get_roots(self): |
---|
1847 | """Get the top-level blocks""" |
---|
1848 | return [a for a in self.child_table.items() if a[1].parent==None] |
---|
1849 | |
---|
1850 | def get_children(self,blockname,include_parent=False,scoping='dictionary'): |
---|
1851 | """Get all children of [[blockname]] as a block collection. If [[include_parent]] is |
---|
1852 | True, the parent block will also be included in the block collection as the root.""" |
---|
1853 | newbc = BlockCollection() |
---|
1854 | block_lower = blockname.lower() |
---|
1855 | proto_child_table = [a for a in self.child_table.items() if self.is_child_of_parent(block_lower,a[1].block_id)] |
---|
1856 | newbc.child_table = dict(proto_child_table) |
---|
1857 | if not include_parent: |
---|
1858 | newbc.child_table.update(dict([(a[0],self.PC(a[1].block_id,None)) for a in proto_child_table if a[1].parent == block_lower])) |
---|
1859 | newbc.lower_keys = set([a[0] for a in proto_child_table]) |
---|
1860 | newbc.dictionary = dict((a[0],self.dictionary[a[0]]) for a in proto_child_table) |
---|
1861 | if include_parent: |
---|
1862 | newbc.child_table.update({block_lower:self.PC(self.child_table[block_lower].block_id,None)}) |
---|
1863 | newbc.lower_keys.add(block_lower) |
---|
1864 | newbc.dictionary.update({block_lower:self.dictionary[block_lower]}) |
---|
1865 | newbc.scoping = scoping |
---|
1866 | return newbc |
---|
1867 | |
---|
1868 | def get_immediate_children(self,parentname): |
---|
1869 | """Get the next level of children of the given block as a list, without nested levels""" |
---|
1870 | child_handles = [a for a in self.child_table.items() if a[1].parent == parentname.lower()] |
---|
1871 | return child_handles |
---|
1872 | |
---|
1873 | # This takes time |
---|
1874 | def get_child_list(self,parentname): |
---|
1875 | """Get a list of all child categories in alphabetical order""" |
---|
1876 | child_handles = [a[0] for a in self.child_table.items() if self.is_child_of_parent(parentname.lower(),a[0])] |
---|
1877 | child_handles.sort() |
---|
1878 | return child_handles |
---|
1879 | |
---|
1880 | def is_child_of_parent(self,parentname,blockname): |
---|
1881 | """Return `True` if `blockname` is a child of `parentname`""" |
---|
1882 | checkname = parentname.lower() |
---|
1883 | more_children = [a[0] for a in self.child_table.items() if a[1].parent == checkname] |
---|
1884 | if blockname.lower() in more_children: |
---|
1885 | return True |
---|
1886 | else: |
---|
1887 | for one_child in more_children: |
---|
1888 | if self.is_child_of_parent(one_child,blockname): return True |
---|
1889 | return False |
---|
1890 | |
---|
1891 | def set_parent(self,parentname,childname): |
---|
1892 | """Set the parent block""" |
---|
1893 | # first check that both blocks exist |
---|
1894 | if parentname.lower() not in self.lower_keys: |
---|
1895 | raise KeyError('Parent block %s does not exist' % parentname) |
---|
1896 | if childname.lower() not in self.lower_keys: |
---|
1897 | raise KeyError('Child block %s does not exist' % childname) |
---|
1898 | old_entry = self.child_table[childname.lower()] |
---|
1899 | self.child_table[childname.lower()]=self.PC(old_entry.block_id, |
---|
1900 | parentname.lower()) |
---|
1901 | self.scoping = self.scoping #reset visibility |
---|
1902 | |
---|
1903 | def SetTemplate(self,template_file): |
---|
1904 | """Use `template_file` as a template for all block output""" |
---|
1905 | self.master_template = process_template(template_file) |
---|
1906 | for b in self.dictionary.values(): |
---|
1907 | b.formatting_hints = self.master_template |
---|
1908 | |
---|
1909 | def WriteOut(self,comment='',wraplength=80,maxoutlength=0,blockorder=None,saves_after=None): |
---|
1910 | """Return the contents of this file as a string, wrapping if possible at `wraplength` |
---|
1911 | characters and restricting maximum line length to `maxoutlength`. Delimiters and |
---|
1912 | save frame nesting are controlled by `self.grammar`. If `blockorder` is |
---|
1913 | provided, blocks are output in this order unless nested save frames have been |
---|
1914 | requested (STAR2). The default block order is the order in which blocks were input. |
---|
1915 | `saves_after` inserts all save frames after the given dataname, |
---|
1916 | which allows less important items to appear later. Useful in conjunction with a |
---|
1917 | template for dictionary files.""" |
---|
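# A hedged usage sketch (file and variable names are hypothetical):
#   cf.set_grammar('2.0')
#   text = cf.WriteOut(comment='#\\#CIF_2.0\n# my header\n',wraplength=70)
#   open('output.cif','w').write(text)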
1918 | if maxoutlength != 0: |
---|
1919 | self.SetOutputLength(maxoutlength) |
---|
1920 | if not comment: |
---|
1921 | comment = self.header_comment |
---|
1922 | outstring = StringIO() |
---|
1923 | if self.grammar == "2.0" and comment[0:10] != r"#\#CIF_2.0": |
---|
1924 | outstring.write(r"#\#CIF_2.0" + "\n") |
---|
1925 | outstring.write(comment) |
---|
1926 | # prepare all blocks |
---|
1927 | for b in self.dictionary.values(): |
---|
1928 | b.set_grammar(self.grammar) |
---|
1929 | b.formatting_hints = self.master_template |
---|
1930 | b.SetOutputLength(wraplength,self.maxoutlength) |
---|
1931 | # loop over top-level |
---|
1932 | # monitor output |
---|
1933 | all_names = list(self.child_table.keys()) #i.e. lower case |
---|
1934 | if blockorder is None: |
---|
1935 | blockorder = self.block_input_order |
---|
1936 | top_block_names = [(a,self.child_table[a].block_id) for a in blockorder if self.child_table[a].parent is None] |
---|
1937 | for blockref,blockname in top_block_names: |
---|
1938 | print('Writing %s, ' % blockname + repr(self[blockref])) |
---|
1939 | outstring.write('\n' + 'data_' +blockname+'\n') |
---|
1940 | all_names.remove(blockref) |
---|
1941 | if self.standard == 'Dic': #put contents before save frames |
---|
1942 | outstring.write(self[blockref].printsection(finish_at='_dictionary_valid.application')) |
---|
1943 | if self.grammar == 'STAR2': #nested save frames |
---|
1944 | child_refs = self.get_immediate_children(blockref) |
---|
1945 | for child_ref,child_info in child_refs: |
---|
1946 | child_name = child_info.block_id |
---|
1947 | outstring.write('\n\n' + 'save_' + child_name + '\n') |
---|
1948 | self.block_to_string_nested(child_ref,child_name,outstring,4) |
---|
1949 | outstring.write('\n' + 'save_'+ '\n') |
---|
1950 | elif self.grammar in ('1.0','1.1','2.0'): #non-nested save frames |
---|
1951 | child_refs = [a for a in blockorder if self.is_child_of_parent(blockref,a)] |
---|
1952 | for child_ref in child_refs: |
---|
1953 | child_name = self.child_table[child_ref].block_id |
---|
1954 | outstring.write('\n\n' + 'save_' + child_name + '\n') |
---|
1955 | outstring.write(str(self[child_ref])) |
---|
1956 | outstring.write('\n\n' + 'save_' + '\n') |
---|
1957 | all_names.remove(child_ref.lower()) |
---|
1958 | else: |
---|
1959 | raise StarError('Grammar %s is not recognised for output' % self.grammar) |
---|
1960 | if self.standard != 'Dic': #put contents after save frames |
---|
1961 | outstring.write(str(self[blockref])) |
---|
1962 | else: |
---|
1963 | outstring.write(self[blockref].printsection(start_from='_dictionary_valid.application')) |
---|
1964 | returnstring = outstring.getvalue() |
---|
1965 | outstring.close() |
---|
1966 | if len(all_names)>0: |
---|
1967 | print('WARNING: following blocks not output: %s' % repr(all_names)) |
---|
1968 | else: |
---|
1969 | print('All blocks output.') |
---|
1970 | return returnstring |
---|
1971 | |
---|
1972 | def block_to_string_nested(self,block_ref,block_id,outstring,indentlevel=0): |
---|
1973 | """Output a complete datablock indexed by [[block_ref]] and named [[block_id]], including children, |
---|
1974 | and syntactically nesting save frames""" |
---|
1975 | child_refs = self.get_immediate_children(block_ref) |
---|
1976 | self[block_ref].set_grammar(self.grammar) |
---|
1977 | if self.standard == 'Dic': |
---|
1978 | outstring.write(str(self[block_ref])) |
---|
1979 | for child_ref,child_info in child_refs: |
---|
1980 | child_name = child_info.block_id |
---|
1981 | outstring.write('\n' + 'save_' + child_name + '\n') |
---|
1982 | self.block_to_string_nested(child_ref,child_name,outstring,indentlevel) |
---|
1983 | outstring.write('\n' + ' '*indentlevel + 'save_' + '\n') |
---|
1984 | if self.standard != 'Dic': |
---|
1985 | outstring.write(str(self[block_ref])) |
---|
1986 | |
---|
1987 | |
---|
1988 | class StarFile(BlockCollection): |
---|
1989 | def __init__(self,datasource=None,maxinlength=-1,maxoutlength=0, |
---|
1990 | scoping='instance',grammar='1.1',scantype='standard', |
---|
1991 | **kwargs): |
---|
1992 | super(StarFile,self).__init__(datasource=datasource,**kwargs) |
---|
1993 | self.my_uri = getattr(datasource,'my_uri','') |
---|
1994 | if maxoutlength == 0: |
---|
1995 | self.maxoutlength = 2048 |
---|
1996 | else: |
---|
1997 | self.maxoutlength = maxoutlength |
---|
1998 | self.scoping = scoping |
---|
1999 | if isinstance(datasource,(unicode,str)) or hasattr(datasource,"read"): |
---|
2000 | ReadStar(datasource,prepared=self,grammar=grammar,scantype=scantype, |
---|
2001 | maxlength = maxinlength) |
---|
2002 | self.header_comment = \ |
---|
2003 | """#\\#STAR |
---|
2004 | ########################################################################## |
---|
2005 | # STAR Format file |
---|
2006 | # Produced by PySTARRW module |
---|
2007 | # |
---|
2008 | # This is a STAR file. STAR is a superset of the CIF file type. For |
---|
2009 | # more information, please refer to International Tables for Crystallography, |
---|
2010 | # Volume G, Chapter 2.1 |
---|
2011 | # |
---|
2012 | ########################################################################## |
---|
2013 | """ |
---|
2014 | def set_uri(self,my_uri): self.my_uri = my_uri |
---|
2015 | |
---|
2016 | |
---|
2017 | import math |
---|
2018 | class CIFStringIO(StringIO): |
---|
2019 | def __init__(self,target_width=80,**kwargs): |
---|
2020 | StringIO.__init__(self,**kwargs) |
---|
2021 | self.currentpos = 0 |
---|
2022 | self.target_width = target_width |
---|
2023 | self.tabwidth = -1 |
---|
2024 | self.indentlist = [0] |
---|
2025 | self.last_char = "" |
---|
2026 | |
---|
2027 | def write(self,outstring,canbreak=False,mustbreak=False,do_tab=True,newindent=False,unindent=False, |
---|
2028 | delimiter=False,startcol=-1): |
---|
2029 | """Write a string with correct linebreak, tabs and indents""" |
---|
2030 | # do we need to break? |
---|
2031 | if delimiter: |
---|
2032 | if len(outstring)>1: |
---|
2033 | raise ValueError('Delimiter %s is longer than one character' % repr( outstring )) |
---|
2034 | output_delimiter = True |
---|
2035 | if mustbreak: #insert a new line and indent |
---|
2036 | temp_string = '\n' + ' ' * self.indentlist[-1] |
---|
2037 | StringIO.write(self,temp_string) |
---|
2038 | self.currentpos = self.indentlist[-1] |
---|
2039 | self.last_char = temp_string[-1] |
---|
2040 | if self.currentpos+len(outstring)>self.target_width: #try to break |
---|
2041 | if not delimiter and outstring[0]!='\n': #ie <cr>; |
---|
2042 | if canbreak: |
---|
2043 | temp_string = '\n' + ' ' * self.indentlist[-1] |
---|
2044 | StringIO.write(self,temp_string) |
---|
2045 | self.currentpos = self.indentlist[-1] |
---|
2046 | self.last_char = temp_string[-1] |
---|
2047 | else: #assume a break will be forced on next value |
---|
2048 | output_delimiter = False #the line break becomes the delimiter |
---|
2049 | #try to match requested column |
---|
2050 | if startcol > 0: |
---|
2051 | if self.currentpos < startcol: |
---|
2052 | StringIO.write(self,(startcol - self.currentpos)* ' ') |
---|
2053 | self.currentpos = startcol |
---|
2054 | self.last_char = ' ' |
---|
2055 | else: |
---|
2056 | print('Could not format %s at column %d as already at %d' % (outstring,startcol,self.currentpos)) |
---|
2057 | startcol = -1 #so that tabbing works as a backup |
---|
2058 | #handle tabs |
---|
2059 | if self.tabwidth >0 and do_tab and startcol < 0: |
---|
2060 | next_stop = ((self.currentpos//self.tabwidth)+1)*self.tabwidth |
---|
2061 | #print 'Currentpos %d: Next tab stop at %d' % (self.currentpos,next_stop) |
---|
2062 | if self.currentpos < next_stop: |
---|
2063 | StringIO.write(self,(next_stop-self.currentpos)*' ') |
---|
2064 | self.currentpos = next_stop |
---|
2065 | self.last_char = ' ' |
---|
2066 | #calculate indentation after tabs and col setting applied |
---|
2067 | if newindent: #indent by current amount |
---|
2068 | if self.indentlist[-1] == 0: #first time |
---|
2069 | self.indentlist.append(self.currentpos) |
---|
2070 | # print 'Indentlist: ' + `self.indentlist` |
---|
2071 | else: |
---|
2072 | self.indentlist.append(self.indentlist[-1]+2) |
---|
2073 | elif unindent: |
---|
2074 | if len(self.indentlist)>1: |
---|
2075 | self.indentlist.pop() |
---|
2076 | else: |
---|
2077 | print('Warning: cannot unindent any further') |
---|
2078 | #check that we still need a delimiter |
---|
2079 | if self.last_char in [' ','\n','\t']: |
---|
2080 | output_delimiter = False |
---|
2081 | #now output the string - every invocation comes through here |
---|
2082 | if (delimiter and output_delimiter) or not delimiter: |
---|
2083 | StringIO.write(self,outstring) |
---|
2084 | last_line_break = outstring.rfind('\n') |
---|
2085 | if last_line_break >=0: |
---|
2086 | self.currentpos = len(outstring)-last_line_break |
---|
2087 | else: |
---|
2088 | self.currentpos = self.currentpos + len(outstring) |
---|
2089 | #remember the last character |
---|
2090 | if len(outstring)>0: |
---|
2091 | self.last_char = outstring[-1] |
---|
2092 | |
---|
2093 | def set_tab(self,tabwidth): |
---|
2094 | """Set the tab stop position""" |
---|
2095 | self.tabwidth = tabwidth |
---|
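# A hedged sketch of CIFStringIO in use (the values are purely illustrative):
#   out = CIFStringIO(target_width=80)
#   out.set_tab(40)                          # values line up at column 40
#   out.write('_cell.length_a',do_tab=False)
#   out.write(' ',canbreak=True,do_tab=False,delimiter=True)
#   out.write('5.64(2)')                     # tabbed out before writing
#   cif_text = out.getvalue()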
2096 | |
---|
2097 | class StarError(Exception): |
---|
2098 | def __init__(self,value): |
---|
2099 | self.value = value |
---|
2100 | def __str__(self): |
---|
2101 | return '\nStar Format error: '+ self.value |
---|
2102 | |
---|
2103 | class StarLengthError(Exception): |
---|
2104 | def __init__(self,value): |
---|
2105 | self.value = value |
---|
2106 | def __str__(self): |
---|
2107 | return '\nStar length error: ' + self.value |
---|
2108 | |
---|
2109 | class StarDerivationError(Exception): |
---|
2110 | def __init__(self,fail_name): |
---|
2111 | self.fail_name = fail_name |
---|
2112 | def __str__(self): |
---|
2113 | return "Derivation of %s failed, None returned" % self.fail_name |
---|
2114 | |
---|
2115 | # |
---|
2116 | # This is subclassed from AttributeError in order to allow hasattr |
---|
2117 | # to work. |
---|
2118 | # |
---|
2119 | class StarDerivationFailure(AttributeError): |
---|
2120 | def __init__(self,fail_name): |
---|
2121 | self.fail_name = fail_name |
---|
2122 | def __str__(self): |
---|
2123 | return "Derivation of %s failed" % self.fail_name |
---|
2124 | |
---|
2125 | def ReadStar(filename,prepared = None, maxlength=-1, |
---|
2126 | scantype='standard',grammar='STAR2',CBF=False): |
---|
2127 | |
---|
2128 | """ Read in a STAR file, returning the contents in the `prepared` object. |
---|
2129 | |
---|
2130 | * `filename` may be a URL, a file |
---|
2131 | path on the local system, or any object with a `read` method. |
---|
2132 | |
---|
2133 | * `prepared` provides a `StarFile` or `CifFile` object that the contents of `filename` |
---|
2134 | will be added to. |
---|
2135 | |
---|
2136 | * `maxlength` is the maximum allowable line length in the input file. This has been set at |
---|
2137 | 2048 characters for CIF but is unlimited (-1) for STAR files. |
---|
2138 | |
---|
2139 | * `grammar` chooses the STAR grammar variant. `1.0` is the original 1992 CIF/STAR grammar and `1.1` |
---|
2140 | is identical except for the exclusion of square brackets as the first characters in |
---|
2141 | undelimited datanames. `2.0` will read files in the CIF2.0 standard, and `STAR2` will |
---|
2142 | read files according to the STAR2 publication. If grammar is `None` or `auto`, autodetection |
---|
2143 | will be attempted in the order `2.0`, `1.1` and `1.0`. This will always succeed for conformant CIF2.0 files. |
---|
2144 | Note that (nested) save frames are read in all grammar variations and then flagged afterwards if |
---|
2145 | they do not match the requested grammar. |
---|
2146 | |
---|
2147 | * `scantype` can be `standard` or `flex`. `standard` provides pure Python parsing at the |
---|
2148 | cost of a factor of 10 or so in speed. `flex` will tokenise the input CIF file using |
---|
2149 | fast C routines. Note that running PyCIFRW in Jython uses native Java regular expressions |
---|
2150 | to provide a speedup regardless of this argument. |
---|
2151 | |
---|
2152 | * `CBF` flags that the input file is in Crystallographic Binary File format. The binary block is |
---|
2153 | excised from the input data stream before parsing and is not available in the returned object. |
---|
2154 | """ |
---|
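# Illustrative use (the file name is hypothetical); ReadStar is normally invoked
# for you by the StarFile/CifFile constructors rather than called directly:
#   sf = ReadStar('my_data.cif',prepared=StarFile(),grammar='auto')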
2155 | |
---|
2156 | import string |
---|
2157 | import codecs |
---|
2158 | # save desired scoping |
---|
2159 | save_scoping = getattr(prepared,'scoping','instance') #prepared may still be None here |
---|
2160 | from . import YappsStarParser_1_1 as Y11 |
---|
2161 | from . import YappsStarParser_1_0 as Y10 |
---|
2162 | from . import YappsStarParser_2_0 as Y20 |
---|
2163 | from . import YappsStarParser_STAR2 as YST |
---|
2164 | if prepared is None: |
---|
2165 | prepared = StarFile() |
---|
2166 | if grammar == "auto" or grammar is None: |
---|
2167 | try_list = [('2.0',Y20),('1.1',Y11),('1.0',Y10)] |
---|
2168 | elif grammar == '1.0': |
---|
2169 | try_list = [('1.0',Y10)] |
---|
2170 | elif grammar == '1.1': |
---|
2171 | try_list = [('1.1',Y11)] |
---|
2172 | elif grammar == '2.0': |
---|
2173 | try_list = [('2.0',Y20)] |
---|
2174 | elif grammar == 'STAR2': |
---|
2175 | try_list = [('STAR2',YST)] |
---|
2176 | else: |
---|
2177 | raise AttributeError('Unknown STAR/CIF grammar requested, %s' % repr( grammar )) |
---|
2178 | if isinstance(filename,(unicode,str)): |
---|
2179 | # create an absolute URL |
---|
2180 | relpath = urlparse(filename) |
---|
2181 | if relpath.scheme == "": |
---|
2182 | if not os.path.isabs(filename): |
---|
2183 | fullpath = os.path.join(os.getcwd(),filename) |
---|
2184 | else: |
---|
2185 | fullpath = filename |
---|
2186 | newrel = list(relpath) |
---|
2187 | newrel[0] = "file" |
---|
2188 | newrel[2] = fullpath |
---|
2189 | my_uri = urlunparse(newrel) |
---|
2190 | else: |
---|
2191 | my_uri = urlunparse(relpath) |
---|
2192 | # print("Full URL is: " + my_uri) |
---|
2193 | filestream = urlopen(my_uri) |
---|
2194 | # text = filestream.read().decode('utf8') |
---|
2195 | text = filestream.read().decode('latin1') |
---|
2196 | filestream.close() |
---|
2197 | else: |
---|
2198 | filestream = filename #already opened for us |
---|
2199 | text = filestream.read() |
---|
2200 | if not isinstance(text,unicode): |
---|
2201 | # text = text.decode('utf8') #CIF is always ascii/utf8 |
---|
2202 | text = text.decode('latin1') #latin1 accepts any byte value; CIF itself should be ascii/utf8 |
---|
2203 | my_uri = "" |
---|
2204 | if not text: # empty file, return the (empty) prepared object |
---|
2205 | prepared.set_uri(my_uri); return prepared |
---|
2206 | # filter out non-ASCII characters in CBF files if required. We assume |
---|
2207 | # that the binary is enclosed in a fixed string that occurs |
---|
2208 | # nowhere else. |
---|
2209 | if CBF: |
---|
2210 | text_bits = text.split("-BINARY-FORMAT-SECTION-") |
---|
2211 | text = text_bits[0] |
---|
2212 | for section in range(2,len(text_bits),2): |
---|
2213 | text = text+" (binary omitted)"+text_bits[section] |
---|
2214 | # we recognise ctrl-Z as end of file |
---|
2215 | endoffile = text.find(chr(26)) |
---|
2216 | if endoffile >= 0: |
---|
2217 | text = text[:endoffile] |
---|
2218 | split = text.split('\n') |
---|
2219 | if maxlength > 0: |
---|
2220 | toolong = [a for a in split if len(a)>maxlength] |
---|
2221 | if toolong: |
---|
2222 | pos = split.index(toolong[0]) |
---|
2223 | raise StarError( 'Line %d contains more than %d characters' % (pos+1,maxlength)) |
---|
2224 | # honour the header string |
---|
2225 | if text[:10] != "#\#CIF_2.0" and ('2.0',Y20) in try_list: |
---|
2226 | try_list.remove(('2.0',Y20),) |
---|
2227 | if not try_list: |
---|
2228 | raise StarError('File %s missing CIF2.0 header' % (filename)) |
---|
2229 | for grammar_name,Y in try_list: |
---|
2230 | if scantype == 'standard' or grammar_name in ['2.0','STAR2']: |
---|
2231 | parser = Y.StarParser(Y.StarParserScanner(text)) |
---|
2232 | else: |
---|
2233 | parser = Y.StarParser(Y.yappsrt.Scanner(None,[],text,scantype='flex')) |
---|
2234 | # handle encoding switch |
---|
2235 | if grammar_name in ['2.0','STAR2']: |
---|
2236 | prepared.set_characterset('unicode') |
---|
2237 | else: |
---|
2238 | prepared.set_characterset('ascii') |
---|
2239 | proto_star = None |
---|
2240 | try: |
---|
2241 | proto_star = getattr(parser,"input")(prepared) |
---|
2242 | except Y.yappsrt.SyntaxError as e: |
---|
2243 | input = parser._scanner.input |
---|
2244 | Y.yappsrt.print_error(input, e, parser._scanner) |
---|
2245 | except Y.yappsrt.NoMoreTokens: |
---|
2246 | print('Could not complete parsing; stopped around here:',file=sys.stderr) |
---|
2247 | print(parser._scanner,file=sys.stderr) |
---|
2248 | except ValueError: |
---|
2249 | print('Unexpected error:') |
---|
2250 | import traceback |
---|
2251 | traceback.print_exc() |
---|
2252 | if proto_star is not None: |
---|
2253 | proto_star.set_grammar(grammar_name) #remember for output |
---|
2254 | break |
---|
2255 | if proto_star is None: |
---|
2256 | errorstring = 'Syntax error in input file: last value parsed was %s' % Y.lastval |
---|
2257 | errorstring = errorstring + '\nParser status: %s' % repr( parser._scanner ) |
---|
2258 | raise StarError( errorstring) |
---|
2259 | # set visibility correctly |
---|
2260 | proto_star.scoping = 'dictionary' |
---|
2261 | proto_star.set_uri(my_uri) |
---|
2262 | proto_star.scoping = save_scoping |
---|
2263 | return proto_star |
---|
2264 | |
---|
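# A minimal usage sketch for the reader: how the grammar fallback above is
# normally driven. The call below assumes ReadStar's first parameter is the
# file name/URL and that 'grammar' is accepted as a keyword, as suggested by
# the local names in the function body; it is not a documented API guarantee.
def _read_star_example(path="mydata.cif"):
    # grammar="auto" tries the 2.0, 1.1 and 1.0 grammars in turn; the 2.0
    # grammar is dropped up front when the #\#CIF_2.0 magic comment is absent.
    return ReadStar(path, grammar="auto")
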
def get_dim(dataitem,current=0,packlen=0):
    """Return (dimension, packet length), where dimension is the nesting depth
    of dataitem and packet length is the length of the innermost list or tuple."""
    zerotypes = [int, float, str]
    if type(dataitem) in zerotypes:
        return current, packlen
    if not dataitem.__class__ == ().__class__ and \
       not dataitem.__class__ == [].__class__:
        return current, packlen
    elif len(dataitem) > 0:
        # print("Get_dim: %d: %s" % (current, repr(dataitem)))
        return get_dim(dataitem[0],current+1,len(dataitem))
    else: return current+1,0

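# Minimal sketch of get_dim's return values; these follow directly from the
# recursion above and can be checked interactively.
def _get_dim_examples():
    assert get_dim(5) == (0, 0)                  # scalar: zero-dimensional
    assert get_dim([1, 2, 3]) == (1, 3)          # flat list: depth 1, innermost length 3
    assert get_dim([[1, 2], [3, 4]]) == (2, 2)   # nested list: depth 2, innermost length 2
    assert get_dim([]) == (1, 0)                 # empty list: depth counted, no packet
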
def apply_line_folding(instring,minwraplength=60,maxwraplength=80):
    """Insert line folding characters into instring between min/max wraplength"""
    # first check that we need to do this
    lines = instring.split('\n')
    line_len = [len(l) for l in lines]
    if max(line_len) < maxwraplength and re.match(r"\\[ \v\t\f]*\n",instring) is None:
        return instring
    outstring = "\\\n"   #header
    for l in lines:
        if len(l) < maxwraplength:
            outstring = outstring + l
            if len(l) > 0 and l[-1] == '\\':  #who'da thunk it? A line ending with a backslash
                outstring = outstring + "\\\n"  # fold immediately so the backslash survives unfolding
            outstring = outstring + "\n"    # put back the split character
        else:
            current_bit = l
            while len(current_bit) > maxwraplength:
                space_pos = re.search('[ \v\f\t]+',current_bit[minwraplength:])
                if space_pos is not None and space_pos.start() < maxwraplength-1:
                    outstring = outstring + current_bit[:minwraplength+space_pos.start()] + "\\\n"
                    current_bit = current_bit[minwraplength+space_pos.start():]
                else:    #just blindly insert
                    outstring = outstring + current_bit[:maxwraplength-1] + "\\\n"
                    current_bit = current_bit[maxwraplength-1:]
            outstring = outstring + current_bit
            if current_bit[-1] == '\\':   #a backslash just happens to be here
                outstring = outstring + "\\\n"
            outstring = outstring + '\n'
    outstring = outstring[:-1]  #remove final newline
    return outstring

def remove_line_folding(instring):
    """Remove line folding from instring"""
    if re.match(r"\\[ \v\t\f]*" + "\n",instring) is not None:
        return re.sub(r"\\[ \v\t\f]*$" + "\n?","",instring,flags=re.M)
    else:
        return instring

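# Sketch of the folding round trip provided by the two functions above. The
# 100-character line is an arbitrary example; 60/80 are the default wrap lengths.
def _line_folding_example():
    original = "a" * 100 + "\nshort line"
    folded = apply_line_folding(original)
    assert folded.startswith("\\\n")     # folded text is flagged by a leading backslash line
    assert max(len(l) for l in folded.split("\n")) <= 80
    assert remove_line_folding(folded) == original
    return folded
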
def apply_line_prefix(instring,prefix):
    """Prefix every line in instring with prefix"""
    if prefix[0] != ";" and "\\" not in prefix:
        header = re.match(r"(\\[ \v\t\f]*" + "\n)",instring)
        if header is not None:
            print('Found line folded string for prefixing...')
            not_header = instring[header.end():]
            outstring = prefix + "\\\\\n" + prefix
        else:
            print('No folding in input string...')
            not_header = instring
            outstring = prefix + "\\\n" + prefix
        outstring = outstring + not_header.replace("\n","\n"+prefix)
        return outstring
    raise StarError("Requested prefix starts with semicolon or contains a backslash: " + prefix)

def remove_line_prefix(instring):
    """Remove prefix from every line if present"""
    prefix_match = re.match(r"(?P<prefix>[^;\\\n][^\n\\]+)(?P<folding>\\{1,2}[ \t\v\f]*\n)",instring)
    if prefix_match is not None:
        prefix_text = prefix_match.group('prefix')
        print('Found prefix %s' % prefix_text)
        prefix_end = prefix_match.end('folding')
        # keep any line folding instructions
        if prefix_match.group('folding')[:2] == '\\\\':  #two backslashes: prefix plus folding
            outstring = instring[prefix_match.end('folding')-1:].replace("\n"+prefix_text,"\n")
            return "\\" + outstring  #keep line folding header on first line
        else:
            outstring = instring[prefix_match.end('folding')-1:].replace("\n"+prefix_text,"\n")
            return outstring[1:]  #drop first line ending, no longer necessary
    else:
        return instring

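# Sketch of the prefixing round trip provided by the two functions above. The
# "CIF> " prefix is an arbitrary example; a prefix must not start with ';' or
# contain a backslash.
def _line_prefix_example():
    original = "first line\nsecond line"
    prefixed = apply_line_prefix(original, "CIF> ")
    assert prefixed.split("\n")[1] == "CIF> first line"   # every line carries the prefix
    assert remove_line_prefix(prefixed) == original
    return prefixed
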
def listify(item):
    """Wrap item in a list if it is a single string, otherwise return it unchanged"""
    if isinstance(item,unicode): return [item]
    else: return item

# Transpose the list of lists passed to us
def transpose(base_list):
    new_lofl = []
    full_length = len(base_list)
    opt_range = range(full_length)
    for i in range(len(base_list[0])):
        new_packet = []
        for j in opt_range:
            new_packet.append(base_list[j][i])
        new_lofl.append(new_packet)
    return new_lofl

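# Minimal sketch: the outer list holds one column per dataname and transpose
# returns row-oriented packets.
def _transpose_example():
    columns = [[1, 2, 3], ['a', 'b', 'c']]
    assert transpose(columns) == [[1, 'a'], [2, 'b'], [3, 'c']]
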
# This routine is optimised to return as quickly as possible
# as it is called a lot.
def not_none(itemlist):
    """Return true only if no values of None are present"""
    if itemlist is None:
        return False
    if not isinstance(itemlist,(tuple,list)):
        return True
    for x in itemlist:
        if not not_none(x): return False
    return True

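# Sketch of not_none's behaviour on nested data.
def _not_none_examples():
    assert not_none('a value') is True             # non-sequence values are accepted
    assert not_none(None) is False
    assert not_none(['a', ['b', None]]) is False   # a None anywhere fails the whole structure
    assert not_none([['a'], ['b', 'c']]) is True
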
def check_stringiness(data):
    """Check that the contents of data are all strings"""
    if not hasattr(data,'dtype'):   #so not Numpy
        from numbers import Number
        if isinstance(data,Number): return False
        elif isinstance(data,(unicode,str)): return True
        elif data is None: return False  #should be data are None :)
        else:
            for one_item in data:
                if not check_stringiness(one_item): return False
            return True   #all must be strings
    else:   #numerical python
        import numpy
        if data.ndim == 0:    #a bare value
            if data.dtype.kind in ['S','U']: return True
            else: return False
        else:
            for one_item in numpy.nditer(data):
                print('numpy data: ' + repr(one_item))
                if not check_stringiness(one_item): return False
            return True

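# Sketch of check_stringiness on plain Python data; the numpy branch above is
# only exercised when the argument carries a dtype attribute.
def _check_stringiness_examples():
    assert check_stringiness('abc') is True
    assert check_stringiness(['a', ['b', 'c']]) is True
    assert check_stringiness(['a', [1.0, 'c']]) is False   # a number anywhere fails
    assert check_stringiness(None) is False
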
def process_template(template_file):
    """Process a template datafile to formatting instructions"""
    template_as_cif = StarFile(template_file,grammar="2.0").first_block()
    if isinstance(template_file,(unicode,str)):
        template_string = open(template_file).read()
    else:   #a StringIO object
        template_file.seek(0)   #reset
        template_string = template_file.read()
    #template_as_lines = template_string.split("\n")
    #template_as_lines = [l for l in template_as_lines if len(l)>0 and l[0]!='#']
    #template_as_lines = [l for l in template_as_lines if l.split()[0] != 'loop_']
    #template_full_lines = dict([(l.split()[0],l) for l in template_as_lines if len(l.split())>0])
    form_hints = []   #ordered array of hint dictionaries
    find_indent = "^ +"
    for item in template_as_cif.item_order:  #order of input
        if not isinstance(item,int):    #not nested
            hint_dict = {"dataname":item}
            # find the line in the file
            start_pos = re.search(r"(^[ \t]*(?P<name>" + item + r")[ \t\n]+)(?P<spec>([\S]+)|(^;))",template_string,re.I|re.M)
            if start_pos.group("spec") is not None:
                spec_pos = start_pos.start("spec")-start_pos.start(0)
                spec_char = template_string[start_pos.start("spec"):start_pos.start("spec")+3]
                if spec_char[0] in '\'";':
                    hint_dict.update({"delimiter":spec_char[0]})
                    if spec_char == '"""' or spec_char == "'''":
                        hint_dict.update({"delimiter":spec_char})
                if spec_char[0] != ";":   #so we need to work out the column number
                    hint_dict.update({"column":spec_pos})
                else:                     #need to put in the carriage return
                    hint_dict.update({"delimiter":"\n;"})
                    # can we format the text?
                    text_val = template_as_cif[item]
                    hint_dict["reformat"] = "\n\t" in text_val or "\n " in text_val
                    if hint_dict["reformat"]:   #find the indentation
                        p = re.search(find_indent,text_val,re.M)
                        if p is not None:
                            hint_dict["reformat_indent"] = p.end() - p.start()
                if start_pos.group('name') is not None:
                    name_pos = start_pos.start('name') - start_pos.start(0)
                    hint_dict.update({"name_pos":name_pos})
            #print('%s: %s' % (item,repr(hint_dict)))
            form_hints.append(hint_dict)
        else:   #loop block
            testnames = template_as_cif.loops[item]
            total_items = len(template_as_cif.loops[item])
            testname = testnames[0]
            #find the loop spec line in the file
            loop_regex = r"(^[ \t]*(?P<loop>loop_)[ \t\n\r]+(?P<name>" + testname + r")([ \t\n\r]+_[\S]+){%d}[ \t]*$(?P<packet>(.(?!_loop|_[\S]+))*))" % (total_items - 1)
            loop_line = re.search(loop_regex,template_string,re.I|re.M|re.S)
            loop_so_far = loop_line.end()
            packet_text = loop_line.group('packet')
            loop_indent = loop_line.start('loop') - loop_line.start(0)
            form_hints.append({"dataname":'loop','name_pos':loop_indent})
            packet_regex = r"[ \t]*(?P<all>(?P<sqqq>'''([^\n\r\f']*)''')|(?P<sq>'([^\n\r\f']*)'+)|(?P<dq>\"([^\n\r\"]*)\"+)|(?P<none>[^\s]+))"
            packet_pos = re.finditer(packet_regex,packet_text)
            line_end_pos = re.finditer("^",packet_text,re.M)
            next_end = next(line_end_pos).end()
            last_end = next_end
            for loopname in testnames:
                #find the name in the file for name pos
                name_regex = "(^[ \t]*(?P<name>" + loopname + "))"
                name_match = re.search(name_regex,template_string,re.I|re.M|re.S)
                loop_name_indent = name_match.start('name') - name_match.start(0)
                hint_dict = {"dataname":loopname,"name_pos":loop_name_indent}
                #find the value
                thismatch = next(packet_pos)
                while thismatch.start('all') > next_end:
                    try:
                        last_end = next_end
                        next_end = next(line_end_pos).start()
                        print('next end %d' % next_end)
                    except StopIteration:
                        break
                print('Start %d, last_end %d' % (thismatch.start('all'),last_end))
                col_pos = thismatch.start('all') - last_end + 1
                if thismatch.group('none') is None:
                    if thismatch.group('sqqq') is not None:
                        hint_dict.update({'delimiter':"'''"})
                    else:
                        hint_dict.update({'delimiter':thismatch.groups()[0][0]})
                hint_dict.update({'column':col_pos})
                print('%s: %s' % (loopname,repr(hint_dict)))
                form_hints.append(hint_dict)
    return form_hints

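# Illustrative sketch only: driving process_template with an in-memory template.
# The template text is an assumed minimal CIF 2.0 formatting template; real
# templates are supplied by the calling application, and the exact hint values
# depend on the template's column layout.
def _process_template_example():
    from io import StringIO
    template = StringIO(
        "#\\#CIF_2.0\n"
        "data_template\n"
        "_cell.length_a          12.345\n"
    )
    hints = process_template(template)
    # hints is an ordered list of dicts, e.g. something like
    # [{'dataname': '_cell.length_a', 'column': 24, 'name_pos': 0}]
    return hints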

#No documentation flags
