SOAdvancedDissector/objects.py

537 lines
17 KiB
Python

# -*- coding: utf-8 -*-
#
# Copyright Grégory Soutadé
# This file is part of SOAdvancedDissector
# SOAdvancedDissector is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# SOAdvancedDissector is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with SOAdvancedDissector. If not, see <http://www.gnu.org/licenses/>.
#
import sys
from cppprototypeparser import CPPPrototypeParser
# Module-wide rendering state. The two flags are toggled through
# setPrintRawVirtualTable() / setPrintIndent(); cur_indent is the current
# nesting depth, read by _getIndent() and adjusted while printing
# namespaces and classes.
print_raw_virtual_table = print_indent = False
cur_indent = 0
# Prototype parser shared by Function.getDependencies()
parser = CPPPrototypeParser()
def setPrintRawVirtualTable(value):
    """Select whether classes print their whole virtual table (True)
    or only the entries they overload (False)."""
    global print_raw_virtual_table
    print_raw_virtual_table = value
def setPrintIndent(value):
    """Enable or disable indentation of the generated declarations."""
    global print_indent
    print_indent = value
def _getIndent():
    """Return the indentation prefix for the current nesting depth.

    One ' ' per level of cur_indent, or an empty string when indentation
    printing is disabled.
    """
    return ' ' * cur_indent if print_indent else ''
class Object:
    """Abstract object representation (simply a name).

    Equality is name based: an Object equals a plain string holding its
    name, or any object whose ``name`` attribute matches. Ordering is by
    name as well.
    """

    def __init__(self, name):
        self.name = name

    def find(self, obj):
        """Return self when it equals obj, otherwise None."""
        return self if self == obj else None

    def getParametersDependencies(self):
        """A bare object has no parameter dependencies."""
        return None

    def getDependencies(self):
        """Get dependencies from other namespaces (none for a bare object)."""
        return None

    def __eq__(self, other):
        other_name = other if type(other) == str else other.name
        return self.name == other_name

    def __lt__(self, other):
        return self.name < other.name
class Attribute(Object):
    """Class attribute (member).

    ``address`` defaults to 0, which stands for "address unknown";
    ``namespace`` is the (possibly empty) enclosing namespace.
    """

    def __init__(self, name, address=0, namespace=''):
        Object.__init__(self, name)
        self.address = address
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        else:
            return self.name

    def __eq__(self, other):
        """Compare with a name string or another object.

        Objects carrying an address are compared by address. When both
        addresses are the default 0 (unknown), names are compared instead:
        otherwise every address-less member would equal every other one,
        and membership tests such as ``obj in self.childs`` would drop or
        remove unrelated members.
        """
        if type(other) == str:
            return self.name == other
        if not hasattr(other, 'address'):
            return self.name == other.name
        if self.address == 0 and other.address == 0:
            return self.name == other.name
        return self.address == other.address

    def __str__(self):
        return '{}void* {};\n'.format(_getIndent(), self.name)
class Function(Attribute):
    """Function description.

    ``name`` holds the whole prototype (e.g. ``foo(int)``). Flags record
    whether the function is virtual, pure virtual, or printed without a
    return type (``constructor``, also set for destructors by Class).
    """

    def __init__(self, name, address=0, virtual=False, pure_virtual=False, namespace=''):
        # Make sure an opened parameter list is closed
        unterminated = '(' in name and \
            not name.endswith((')', 'const', 'const&'))
        if unterminated:
            name += ')'
        Attribute.__init__(self, name, address, namespace)
        self.virtual = virtual
        self.pure_virtual = pure_virtual
        self.constructor = False

    def isPure(self):
        """Tell whether the function is pure virtual"""
        return self.pure_virtual

    def setConstructor(self, value):
        """Set/clear the constructor flag (no return type when printed)"""
        self.constructor = value

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Parses the prototype and returns the de-duplicated qualified
        parameter types (with '*', '&' and a trailing ' const' removed)
        that do not belong to this function's namespace, or None.
        """
        parser.parse(self.name)
        if not parser.has_parameters:
            return None
        deps = []
        for param in parser.parameters:
            # Parameters given as a list of name parts (presumably the
            # namespace components) whose head is not our namespace
            if type(param) != list or len(param) <= 1 or \
                    self.namespace.startswith(param[0]):
                continue
            qualified = '::'.join(param)
            if '::' not in qualified:
                continue
            qualified = qualified.replace('*', '').replace('&', '')
            if qualified.endswith(' const'):
                qualified = qualified[:-len(' const')]
            deps.append(qualified)
        return list(set(deps)) if deps else None

    def _getReturnType(self):
        """Get method return type (with its trailing space)"""
        if self.name.startswith(('vtable_index', 'operator new')):
            return 'void* '
        if self.constructor:
            return ''
        return 'void '

    def __str__(self):
        prefix = 'virtual ' if (self.pure_virtual or self.virtual) else ''
        suffix = ' = 0' if self.pure_virtual else ''
        return '{}{}{}{}{};\n'.format(_getIndent(), prefix,
                                      self._getReturnType(), self.name,
                                      suffix)
class Namespace(Object):
    """Namespace description.

    Holds child objects (nested namespaces, classes, functions,
    attributes) and renders them as C++ declarations.
    """

    def __init__(self, name):
        Object.__init__(self, name)
        self.childs = []
        self.dependencies = []  # Dependencies from objects in other namespace

    def addChild(self, child):
        """Add child (function, class, attribute) to namespace unless an
        equal child is already present"""
        if not child in self.childs:
            self.childs.append(child)

    def removeChild(self, child):
        """Remove child from namespace (list.remove: ValueError if absent)"""
        self.childs.remove(child)

    def child(self, name):
        """Try to find name in childs without recursion"""
        for child in self.childs:
            if child.name == name:
                return child
        return None

    def find(self, obj):
        """Try to find obj in childs and their own child (with recursion)"""
        if self == obj:
            return self
        for child in self.childs:
            res = child.find(obj)
            if res:
                return res
        return None

    def fillFrom(self, other):
        """Copy all childs from other object (no duplicate check)"""
        for child in other.childs:
            self.childs.append(child)

    def getDependencies(self):
        """Get dependencies from other namespaces.

        Returns the de-duplicated children dependencies that do not start
        with ``<name>::``, or an empty list when there are none.
        """
        dependencies = []
        for child in self.childs:
            depend = child.getDependencies()
            if depend:
                for d in depend:
                    if not d.startswith('{}::'.format(self.name)):
                        dependencies.append(d)
        if dependencies:
            return list(set(dependencies))
        return []

    @staticmethod
    def _sortClassesByInheritance(classes):
        """Return classes sorted by name, then reordered so that every base
        class present in the list comes before all classes deriving from it.

        Bases are matched with ``==`` (name based), like the membership test
        previously used. Depth-first topological sort: unlike inserting a
        class before its first already-placed child, this also handles
        chains whose members are visited out of order (A->B, B->C visited
        as A, C, B). id() is used for bookkeeping because the objects define
        __eq__ and are therefore not reliably hashable.
        """
        candidates = sorted(classes)
        ordered = []
        done = set()         # ids of classes already emitted
        in_progress = set()  # ids on the current DFS path (cycle guard)

        def visit(class_):
            key = id(class_)
            if key in done or key in in_progress:
                return
            in_progress.add(key)
            for base in class_.inherit_from:
                for candidate in candidates:
                    if candidate == base:
                        visit(candidate)
            in_progress.discard(key)
            done.add(key)
            ordered.append(class_)

        for class_ in candidates:
            visit(class_)
        return ordered

    def __str__(self):
        global cur_indent
        if self.name != 'global':
            res = '{}namespace {} {{\n\n'.format(_getIndent(), self.name)
            cur_indent += 1
        else:
            res = ''
        namespaces = []
        classes = []
        functions = []
        other = []
        for child in self.childs:
            if type(child) == Namespace: namespaces.append(child)
            elif type(child) == Class: classes.append(child)
            elif type(child) == Function: functions.append(child)
            else: other.append(child)
        # A base class must be fully defined before any class deriving from
        # it (a forward declaration is not enough in C++)
        classes_dep = self._sortClassesByInheritance(classes)
        # Add class declaration
        if len(classes_dep) > 1:
            for c in classes_dep:
                res += '{}class {};\n'.format(_getIndent(), c.name)
        if classes_dep: res += '\n\n'
        for namespace in sorted(namespaces):
            res += namespace.__str__()
        if namespaces: res += '\n'
        for c in classes_dep:
            res += c.__str__()
        if classes_dep: res += '\n'
        for func in sorted(functions):
            res += func.__str__()
        if functions: res += '\n'
        for child in sorted(other):
            res += child.__str__()
        if other: res += '\n'
        if self.name != 'global':
            cur_indent -= 1
            res += '{}}}\n'.format(_getIndent())
            res += '\n'
        return res
class Class(Namespace):
    """Class description.

    Besides the generic children inherited from Namespace, a class keeps
    separate lists for constructors, destructors, base classes and its
    positional virtual functions table.
    """

    def __init__(self, name, namespace=''):
        Namespace.__init__(self, name)
        self.constructors = []
        self.destructors = []
        self.inherit_from = []
        self.virtual_functions = []
        self.namespace = namespace

    def fullname(self):
        """Return namespace::name"""
        if self.namespace:
            return '{}::{}'.format(self.namespace, self.name)
        else:
            return self.name

    def _checkConstructor(self, obj):
        """Check if obj is a constructor/destructor and
        set its property (printed without return type).

        Returns
        -------
        list
            Adequate list (constructors or destructors) or None
        """
        if type(obj) != Function: return None
        if obj.name.startswith('~'):
            obj.setConstructor(True)
            return self.destructors
        if obj.name.startswith('{}('.format(self.name)):
            obj.setConstructor(True)
            return self.constructors
        return None

    def addVirtualFunction(self, obj):
        """Append obj to the virtual functions table.

        Entries with address 0 are always appended; others only when no
        equal entry is already present.
        """
        if obj.address == 0 or not obj in self.virtual_functions:
            self._checkConstructor(obj)
            self.virtual_functions.append(obj)

    def updateVirtualFunction(self, idx, obj):
        """Update virtual function at index idx.

        A bad index is reported on stderr instead of being raised.
        """
        try:
            self._checkConstructor(obj)
            self.virtual_functions[idx] = obj
        except (IndexError, TypeError):
            sys.stderr.write('updateVirtualFunction Error, cur vtable size {}; idx {}, class {}, obj {}\n'.format(len(self.virtual_functions), idx, self.name, obj))
            sys.stderr.flush()

    def addMember(self, obj):
        """Add a new member (ignored if already known in any list)"""
        if obj in self.virtual_functions or\
           obj in self.constructors or\
           obj in self.destructors or\
           obj in self.childs:
            return
        targetList = self._checkConstructor(obj)
        if targetList is None:
            self.childs.append(obj)
        else:
            targetList.append(obj)

    def addChild(self, child):
        return self.addMember(child)

    def addBaseClass(self, obj):
        self.inherit_from.append(obj)

    def fixVirtualFunction(self, index, newName):
        """Set real name for virtfunc and unknown virtfunc
        generic names.

        Only pure virtual entries are renamed. Returns True when the entry
        at index was renamed.
        """
        if index >= len(self.virtual_functions):
            sys.stderr.write('FVF Error {} > {} for {}/{}\n'.format(index, len(self.virtual_functions), newName, self.fullname()))
            return False
        virtfunc = self.virtual_functions[index]
        if not virtfunc.isPure():
            return False
        if virtfunc.name.startswith('virtfunc') or\
           virtfunc.name.startswith('unknown_virtfunc'):
            virtfunc.name = newName
            self._checkConstructor(virtfunc)
            return True
        return False

    def hasMultipleVirtualSections(self):
        """Check if we have a vtable_indexX (X>0) entry in our
        virtual functions table
        """
        if not self.virtual_functions:
            return False
        for vfunc in self.virtual_functions:
            if vfunc.name.startswith('vtable_index') and\
               vfunc.name != 'vtable_index0':
                return True
        return False

    def fixupInheritance(self):
        """Report virtual function name in base class
        if there are pure virtual.

        With several vtable sections, each base is matched against its own
        section (sections are delimited by vtable_indexN entries); otherwise
        bases are matched against consecutive slices of our single table.
        """
        if not self.inherit_from:
            return
        # First, handle all of our bases
        for base in self.inherit_from:
            base.fixupInheritance()
        vfuncs = self.virtual_functions
        curIdx = 0
        if self.hasMultipleVirtualSections():
            for base in self.inherit_from:
                # First is vtable_index, skip it
                curIdx += 1
                targetIdx = 1
                updated = False
                # Bound both indexes: curIdx walks our table, targetIdx the
                # base's one
                while targetIdx < len(base.virtual_functions) and \
                        curIdx < len(vfuncs):
                    vfunc = vfuncs[curIdx]
                    # End of this base's section
                    if vfunc.name.startswith('vtable_index'):
                        break
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()
                # Go to next vtable index if we are not already on it
                while curIdx < len(vfuncs):
                    if vfuncs[curIdx].name.startswith('vtable_index'):
                        break
                    curIdx += 1
        else:
            for base in self.inherit_from:
                targetIdx = 0
                updated = False
                while targetIdx < len(base.virtual_functions) and \
                        curIdx < len(vfuncs):
                    vfunc = vfuncs[curIdx]
                    if base.fixVirtualFunction(targetIdx, vfunc.name):
                        updated = True
                    curIdx += 1
                    targetIdx += 1
                if updated:
                    base.fixupInheritance()

    def looksLikeNamespace(self):
        """Empty specific class attributes looks like a namespace
        """
        return not self.constructors and\
            not self.destructors and\
            not self.inherit_from and\
            not self.virtual_functions

    def _getDependencies(self, targetList, dependencies):
        """Append to dependencies those of targetList's objects that are
        outside our namespace (all of them for a class without namespace)"""
        own_prefix = '{}::'.format(self.namespace)
        for obj in targetList:
            res = obj.getDependencies()
            if res:
                for d in res:
                    if not self.namespace or not d.startswith(own_prefix):
                        dependencies.append(d)

    def getDependencies(self):
        """Get dependencies from other namespaces"""
        dependencies = []
        for base in self.inherit_from:
            if base.namespace:
                dependencies.append(base.fullname())
        self._getDependencies(self.constructors, dependencies)
        self._getDependencies(self.destructors, dependencies)
        self._getDependencies(self.virtual_functions, dependencies)
        self._getDependencies(self.childs, dependencies)
        if dependencies:
            return list(set(dependencies))
        return None

    def printOverloadedVirtualTable(self):
        """Only select overloaded methods from
        virtual table
        """
        parts = []
        printed = set()
        for vfunc in self.virtual_functions:
            if vfunc.name.startswith('vtable_index') or\
               vfunc.name.startswith('typeinfo()'):
                continue
            # Not overloaded by us
            if vfunc.namespace and vfunc.namespace != self.namespace:
                continue
            vfunc_res = vfunc.__str__()
            # Identical declaration already printed (presumably the
            # duplicated virtual destructor entries)
            if vfunc_res in printed:
                continue
            printed.add(vfunc_res)
            parts.append(vfunc_res)
        return ''.join(parts)

    def __str__(self):
        global cur_indent
        res = '{}class {}'.format(_getIndent(), self.name)
        if self.inherit_from:
            res += ': '
            res += ', '.join(['public {}'.format(base.fullname())
                              for base in self.inherit_from])
        res += '\n{}{{\n'.format(_getIndent())
        res += '{}public:\n'.format(_getIndent())
        cur_indent += 1
        for constructor in sorted(self.constructors):
            res += constructor.__str__()
        if len(self.constructors): res += '\n'
        for destructor in sorted(self.destructors):
            res += destructor.__str__()
        if len(self.destructors): res += '\n'
        # Do not sort virtual tables !
        if print_raw_virtual_table:
            virtfuncs = ''.join([v.__str__() for v in self.virtual_functions])
        else:
            virtfuncs = self.printOverloadedVirtualTable()
        if len(virtfuncs): res += '{}\n'.format(virtfuncs)
        methods = []
        other = []
        for child in self.childs:
            if type(child) == Function: methods.append(child)
            else: other.append(child)
        for method in sorted(methods):
            res += method.__str__()
        if methods: res += '\n'
        for child in sorted(other):
            res += child.__str__()
        if other: res += '\n'
        cur_indent -= 1
        res += '{}}};\n\n'.format(_getIndent())
        return res