SOAdvancedDissector

SOAdvancedDissector Git Source Tree

Root/objects.py

1# -*- coding: utf-8 -*-
2#
3# Copyright Grégory Soutadé
4
5# This file is part of SOAdvancedDissector
6
7# SOAdvancedDissector is free software: you can redistribute it and/or modify
8# it under the terms of the GNU General Public License as published by
9# the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# SOAdvancedDissector is distributed in the hope that it will be useful,
13# but WITHOUT ANY WARRANTY; without even the implied warranty of
14# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15# GNU General Public License for more details.
16#
17# You should have received a copy of the GNU General Public License
18# along with SOAdvancedDissector. If not, see <http://www.gnu.org/licenses/>.
19#
20
21import sys
22from cppprototypeparser import CPPPrototypeParser
23
# Module-wide printing state, changed through setPrintRawVirtualTable(),
# setPrintIndent() and the __str__ methods below:
#   print_raw_virtual_table -- print every virtual table entry instead of
#                              only the ones selected by
#                              Class.printOverloadedVirtualTable()
#   print_indent            -- prefix printed lines with _getIndent()
#   cur_indent              -- current nesting level used by _getIndent()
print_raw_virtual_table, print_indent, cur_indent = False, False, 0

# Single prototype parser instance shared by Function.getDependencies()
parser = CPPPrototypeParser()
29
def setPrintRawVirtualTable(value):
    """Select how Class.__str__ prints virtual tables.

    A true value prints every entry of the virtual table; a false value
    prints only the entries chosen by Class.printOverloadedVirtualTable().
    """
    global print_raw_virtual_table
    print_raw_virtual_table = value
33
def setPrintIndent(value):
    """Turn indentation of the generated declarations on or off.

    The flag is consulted by _getIndent() each time a line is printed.
    """
    global print_indent
    print_indent = value
37
38def _getIndent():
39 global print_indent
40 indent = ''
41 if print_indent:
42 indent = ' '*cur_indent
43 return indent
44
class Object:
    """Abstract object representation (simply a name).

    An Object compares equal to a plain string holding the same name, or to
    any other object exposing an equal ``name`` attribute. Objects are
    ordered by name.
    """
    def __init__(self, name):
        self.name = name

    def find(self, obj):
        """Return self if it matches obj (see __eq__), else None."""
        if self == obj:
            return self
        return None

    def getParametersDependencies(self):
        """A bare object has no parameter dependencies."""
        return None

    def getDependencies(self):
        """Get dependencies from other namespaces (none for a bare object)."""
        return None

    def __eq__(self, other):
        if isinstance(other, str):
            return self.name == other
        if not hasattr(other, 'name'):
            # Unrelated type (e.g. None): let Python fall back to its
            # default comparison instead of raising AttributeError.
            return NotImplemented
        return self.name == other.name

    def __lt__(self, other):
        """Order by name (used by sorted() when printing)."""
        if not hasattr(other, 'name'):
            return NotImplemented
        return self.name < other.name
70
class Attribute(Object):
    """Class attribute (member), optionally located at an address."""
    def __init__(self, name, address=0, namespace=''):
        Object.__init__(self, name)
        self.address = address  # 0 by default, presumably "address unknown"
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name (just name when there is no namespace)."""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        return self.name

    def __eq__(self, other):
        """Compare with a name (str) or with another object.

        When the other object also has an address and at least one of the
        two addresses is non-zero, the objects are equal if their addresses
        match (same symbol, possibly under another name). When both
        addresses are 0 (the default), names are compared instead.
        Otherwise every address-less attribute would compare equal to every
        other one, and membership tests such as the one in
        Namespace.addChild would drop them as duplicates.
        """
        if isinstance(other, str):
            return self.name == other
        if hasattr(other, 'address') and (self.address or other.address):
            return self.address == other.address
        if not hasattr(other, 'name'):
            return NotImplemented
        return self.name == other.name

    def __str__(self):
        return '{}void* {};\n'.format(_getIndent(), self.name)
96
97
class Function(Attribute):
    """Function (or method) description.

    Parameters
    ----------
    name : str
        Prototype, e.g. ``foo(int)``. A closing parenthesis is appended when
        the name contains '(' but ends neither with ')' nor with a const
        qualifier.
    address : int
        Symbol address (0 by default).
    virtual, pure_virtual : bool
        Printed as ``virtual ...;`` / ``virtual ... = 0;``.
    namespace : str
        Enclosing namespace, used to filter dependencies.
    """
    def __init__(self, name, address=0, virtual=False, pure_virtual=False, namespace=''):
        # Be sure we have () in name
        if '(' in name and not name.endswith((')', 'const', 'const&')):
            name += ')'
        Attribute.__init__(self, name, address, namespace)
        self.virtual = virtual
        self.pure_virtual = pure_virtual
        self.constructor = False

    def isPure(self):
        """Is function pure virtual"""
        return self.pure_virtual

    def setConstructor(self, value):
        """Set/clear constructor property (printed without return type)"""
        self.constructor = value

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Parse the prototype with the module-level parser and collect the
        namespace-qualified parameter types whose first component is neither
        this function's namespace nor one of its enclosing namespaces.
        '*', '&' and a trailing ' const' are stripped.

        Returns
        -------
        list or None
            Unique dependencies in first-seen order, None if there are none.
        """
        parser.parse(self.name)
        if not parser.has_parameters:
            return None
        dependencies = []
        for p in parser.parameters:
            # NOTE(review): p is presumably the parameter type split on '::';
            # only qualified types (2+ components) can be dependencies.
            if not isinstance(p, list) or len(p) < 2:
                continue
            # Skip types from our own namespace or an enclosing one. Compare
            # whole components so that namespace 'foobar' doesn't hide 'foo'.
            if self.namespace == p[0] or self.namespace.startswith(p[0] + '::'):
                continue
            dep = '::'.join(p).replace('*', '').replace('&', '')
            if dep.endswith(' const'):
                dep = dep[:-len(' const')]
            dependencies.append(dep)
        if dependencies:
            return list(dict.fromkeys(dependencies))
        return None

    def _getReturnType(self):
        """Get method return type"""
        type_ = 'void '
        if self.name.startswith('vtable_index'):
            type_ = 'void* '
        elif self.name.startswith('operator new'):
            type_ = 'void* '
        elif self.constructor:
            type_ = ''
        return type_

    def __str__(self):
        type_ = self._getReturnType()
        if self.pure_virtual:
            res = 'virtual {}{} = 0;\n'.format(type_, self.name)
        elif self.virtual:
            res = 'virtual {}{};\n'.format(type_, self.name)
        else:
            res = '{}{};\n'.format(type_, self.name)
        return '{}{}'.format(_getIndent(), res)
157
class Namespace(Object):
    """Namespace description: a named container of functions, classes,
    attributes and nested namespaces, printable as C++ declarations."""
    def __init__(self, name):
        Object.__init__(self, name)
        self.childs = []
        self.dependencies = [] # Dependencies from objects in other namespace

    def addChild(self, child):
        """Add child (function, class, attribute) to namespace, unless an
        equal child is already present"""
        if child not in self.childs:
            self.childs.append(child)

    def removeChild(self, child):
        """Remove child from namespace (ValueError if absent)"""
        self.childs.remove(child)

    def child(self, name):
        """Return the first direct child named name, or None (no recursion)"""
        for child in self.childs:
            if child.name == name:
                return child
        return None

    def find(self, obj):
        """Try to find obj in childs and their own child (with recursion)"""
        if self == obj:
            return self

        for child in self.childs:
            res = child.find(obj)
            if res is not None:
                return res
        return None

    def fillFrom(self, other):
        """Copy all childs from other object (no duplicate check).

        extend() also copes with other being self, which a for/append loop
        over the growing list would not.
        """
        self.childs.extend(other.childs)

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Returns
        -------
        list
            Unique dependencies of our childs that are not inside this
            namespace, in first-seen order (possibly empty).
        """
        own_prefix = '{}::'.format(self.name)
        dependencies = []
        for child in self.childs:
            for d in child.getDependencies() or []:
                if not d.startswith(own_prefix):
                    dependencies.append(d)
        return list(dict.fromkeys(dependencies))

    @staticmethod
    def _sortClassesByInheritance(classes):
        """Order classes so that a base class comes before every class of
        classes deriving from it (C++ requires a complete base definition).

        Classes are visited in name order, so unrelated classes keep their
        sorted order. Bases are matched with ==, like a membership test on
        inherit_from. Inheritance cycles are cut instead of recursing forever.
        """
        by_name = sorted(classes)
        ordered = []
        seen = set()  # ids of classes already visited

        def visit(class_):
            if id(class_) in seen:
                return
            seen.add(id(class_))
            for base in class_.inherit_from:
                for candidate in by_name:
                    if candidate == base:
                        visit(candidate)
            ordered.append(class_)

        for class_ in by_name:
            visit(class_)
        return ordered

    def __str__(self):
        global cur_indent
        if self.name != 'global':
            res = '{}namespace {} {{\n\n'.format(_getIndent(), self.name)
            cur_indent += 1
        else:
            res = ''

        namespaces = []
        classes = []
        functions = []
        other = []
        for child in self.childs:
            # Exact type tests on purpose: Class is a subclass of Namespace
            if type(child) is Namespace:
                namespaces.append(child)
            elif type(child) is Class:
                classes.append(child)
            elif type(child) is Function:
                functions.append(child)
            else:
                other.append(child)

        # Compute classes inheritance dependency
        classes_dep = self._sortClassesByInheritance(classes)

        # Add class declaration
        if len(classes_dep) > 1:
            for c in classes_dep:
                res += '{}class {};\n'.format(_getIndent(), c.name)
        if classes_dep:
            res += '\n\n'

        for namespace in sorted(namespaces):
            res += str(namespace)
        if namespaces:
            res += '\n'

        for c in classes_dep:
            res += str(c)
        if classes_dep:
            res += '\n'

        for func in sorted(functions):
            res += str(func)
        if functions:
            res += '\n'

        for child in sorted(other):
            res += str(child)
        if other:
            res += '\n'

        if self.name != 'global':
            cur_indent -= 1
            res += '{}}}\n'.format(_getIndent())
            res += '\n'

        return res
271
class Class(Namespace):
    """Class description.

    Besides the generic childs inherited from Namespace, a class tracks its
    constructors, destructors, base classes and virtual table.
    """
    def __init__(self, name, namespace=''):
        Namespace.__init__(self, name)

        self.constructors = []
        self.destructors = []
        self.inherit_from = []       # Base classes, in the order they were added
        self.virtual_functions = []  # Virtual table entries, in table order
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        return self.name

    def _checkConstructor(self, obj):
        """Check if obj is a constructor/destructor and
        set its property.

        Returns
        -------
        list
            Adequate list (constructors or destructors) or None
        """
        if type(obj) is not Function:
            return None
        if obj.name.startswith('~'):
            obj.setConstructor(True)
            return self.destructors
        if obj.name.startswith('{}('.format(self.name)):
            obj.setConstructor(True)
            return self.constructors
        return None

    def addVirtualFunction(self, obj):
        """Append obj to the virtual table.

        Entries with address 0 are always appended; other entries only if
        no equal entry is already present.
        """
        if obj.address == 0 or obj not in self.virtual_functions:
            self._checkConstructor(obj)
            self.virtual_functions.append(obj)

    def updateVirtualFunction(self, idx, obj):
        """Update virtual function at index idx.

        An invalid idx is reported on stderr and otherwise ignored; obj is
        then left untouched.
        """
        try:
            self.virtual_functions[idx] = obj
        except (IndexError, TypeError):
            sys.stderr.write('updateVirtualFunction Error, cur vtable size {}; idx {}, class {}, obj {}\n'.format(len(self.virtual_functions), idx, self.name, obj))
            sys.stderr.flush()
            return
        self._checkConstructor(obj)

    def addMember(self, obj):
        """Add a new member (routed to constructors/destructors when it is one),
        unless it is already known"""
        if obj in self.virtual_functions or\
           obj in self.constructors or\
           obj in self.destructors or\
           obj in self.childs:
            return

        targetList = self._checkConstructor(obj)
        if targetList is None:
            self.childs.append(obj)
        else:
            targetList.append(obj)

    def addChild(self, child):
        return self.addMember(child)

    def addBaseClass(self, obj):
        self.inherit_from.append(obj)

    def fixVirtualFunction(self, index, newName):
        """Set real name for virtfunc and unknown virtfunc
        generic names.

        Only pure virtual entries with a generic name are renamed.

        Returns
        -------
        bool
            True if the entry at index was renamed
        """
        if index >= len(self.virtual_functions):
            sys.stderr.write('FVF Error {} > {} for {}/{}\n'.format(index, len(self.virtual_functions), newName, self.fullname()))
            return False

        virtfunc = self.virtual_functions[index]

        if not virtfunc.isPure():
            return False

        if virtfunc.name.startswith(('virtfunc', 'unknown_virtfunc')):
            virtfunc.name = newName
            self._checkConstructor(virtfunc)
            return True

        return False

    def hasMultipleVirtualSections(self):
        """Check if we have a vtable_indexX (X>0) entry in our
        virtual functions table
        """
        return any(vfunc.name.startswith('vtable_index') and
                   vfunc.name != 'vtable_index0'
                   for vfunc in self.virtual_functions)

    def fixupInheritance(self):
        """Report virtual function name in base class
        if there are pure virtual.

        Bases are handled first (recursively); a base that received new
        names is handled again so that names keep propagating upwards.
        """
        if not self.inherit_from:
            return

        # First, handle all of our bases
        for base in self.inherit_from:
            base.fixupInheritance()

        if self.hasMultipleVirtualSections():
            self._fixupMultipleSections()
        else:
            self._fixupSingleSection()

    @staticmethod
    def _isVtableIndex(vfunc):
        """True for a 'vtable_indexN' marker entry"""
        return vfunc.name.startswith('vtable_index')

    def _fixupMultipleSections(self):
        """Map each base onto its own vtable_index-delimited section.

        Offset k inside a section (the vtable_index marker being offset 0)
        is matched with entry k of the base virtual table.
        """
        curIdx = 0
        for base in self.inherit_from:
            # First is vtable_index, skip it
            curIdx += 1
            targetIdx = 1
            updated = False
            while targetIdx < len(base.virtual_functions) and \
                    curIdx < len(self.virtual_functions):
                vfunc = self.virtual_functions[curIdx]
                if self._isVtableIndex(vfunc):
                    break  # End of this base's section
                if base.fixVirtualFunction(targetIdx, vfunc.name):
                    updated = True
                curIdx += 1
                targetIdx += 1
            if updated:
                base.fixupInheritance()
            # Go to next vtable index if we are not already on it
            while curIdx < len(self.virtual_functions) and \
                    not self._isVtableIndex(self.virtual_functions[curIdx]):
                curIdx += 1

    def _fixupSingleSection(self):
        """Match base virtual tables against ours, one base after another."""
        curIdx = 0
        for base in self.inherit_from:
            targetIdx = 0
            updated = False
            while targetIdx < len(base.virtual_functions) and \
                    curIdx < len(self.virtual_functions):
                vfunc = self.virtual_functions[curIdx]
                if base.fixVirtualFunction(targetIdx, vfunc.name):
                    updated = True
                curIdx += 1
                targetIdx += 1
            if updated:
                base.fixupInheritance()

    def looksLikeNamespace(self):
        """Empty specific class attributes looks like a namespace
        """
        return not self.constructors and\
            not self.destructors and\
            not self.inherit_from and\
            not self.virtual_functions

    def _getDependencies(self, targetList, dependencies):
        """Append to dependencies the dependencies of targetList objects
        that are outside this class's namespace."""
        # NOTE(review): a class without namespace reports no member
        # dependencies at all (unchanged behaviour) - confirm it is intended.
        if not self.namespace:
            return
        own_prefix = '{}::'.format(self.namespace)
        for obj in targetList:
            for d in obj.getDependencies() or []:
                # Compare on a '::' boundary: namespace 'foo' must not hide
                # 'foobar::X'
                if d != self.namespace and not d.startswith(own_prefix):
                    dependencies.append(d)

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Returns
        -------
        list or None
            Unique dependencies (namespaced bases first) in first-seen
            order, None if there are none.
        """
        dependencies = []

        for base in self.inherit_from:
            if base.namespace:
                dependencies.append(base.fullname())

        for targetList in (self.constructors, self.destructors,
                           self.virtual_functions, self.childs):
            self._getDependencies(targetList, dependencies)

        if dependencies:
            return list(dict.fromkeys(dependencies))
        return None

    def printOverloadedVirtualTable(self):
        """Only select overloaded methods from
        virtual table.

        vtable_index markers, typeinfo() entries and functions from another
        namespace are skipped, as are declarations already printed.
        """
        res = ''
        for vfunc in self.virtual_functions:
            if vfunc.name.startswith(('vtable_index', 'typeinfo()')):
                continue

            # Not overloaded by us
            if vfunc.namespace and vfunc.namespace != self.namespace:
                continue

            vfunc_res = str(vfunc)

            # Method already in result,
            # It's the case for virtual destructor
            if vfunc_res in res:
                continue

            res += vfunc_res

        return res

    def __str__(self):
        global cur_indent
        res = '{}class {}'.format(_getIndent(), self.name)
        if self.inherit_from:
            res += ': '
            res += ', '.join('public {}'.format(base.fullname())
                             for base in self.inherit_from)
        res += '\n{}{{\n'.format(_getIndent())
        res += '{}public:\n'.format(_getIndent())
        cur_indent += 1

        for constructor in sorted(self.constructors):
            res += str(constructor)
        if self.constructors:
            res += '\n'

        for destructor in sorted(self.destructors):
            res += str(destructor)
        if self.destructors:
            res += '\n'

        # Do not sort virtual tables !
        if print_raw_virtual_table:
            virtfuncs = ''.join(str(virtfunc) for virtfunc in self.virtual_functions)
        else:
            virtfuncs = self.printOverloadedVirtualTable()

        if virtfuncs:
            res += '{}\n'.format(virtfuncs)

        methods = []
        other = []

        for child in self.childs:
            if type(child) is Function:
                methods.append(child)
            else:
                other.append(child)

        for method in sorted(methods):
            res += str(method)
        if methods:
            res += '\n'

        for child in sorted(other):
            res += str(child)
        if other:
            res += '\n'

        cur_indent -= 1
        res += '{}}};\n\n'.format(_getIndent())

        return res

Archive Download this file

Branches