#!/usr/bin/env python3
import argparse
import elftools.elf.elffile
import elftools.dwarf.descriptions
from collections import namedtuple
from struct import unpack
import os

from lib.dump import decode_dump
from lib.avr import *


# A named memory region: symbol name, location, size in bytes and the
# 'file:line' position of its declaration (empty string when unknown).
Entry = namedtuple('Entry', 'name loc size declpos')
# A structure member: name, offset from the start of the structure and size
# in bytes.
Member = namedtuple('Member', 'name off size')


def array_inc(loc, dim, idx=0):
    """Advance the multi-dimensional index `loc` to the next array element.

    `loc` is updated in place. The *last* dimension varies fastest, matching
    the row-major memory layout of C arrays (DWARF lists the subranges with
    the leftmost dimension first), so successive calls yield the indices of
    elements at successively increasing memory offsets.

    Args:
        loc: mutable list holding the current index in each dimension.
        dim: list with the number of elements in each dimension.
        idx: number of trailing (fastest-varying) dimensions to leave
            untouched; callers normally keep the default of 0.

    Returns:
        True when the index wrapped around past the last element (every
        touched dimension is back at 0), False otherwise.
    """
    for pos in range(len(dim) - 1 - idx, -1, -1):
        loc[pos] += 1
        if loc[pos] != dim[pos]:
            return False
        # this dimension overflowed: reset it and carry into the next one
        loc[pos] = 0
    return True

def get_type_size(type_DIE):
    """Find the first DIE along the DW_AT_type chain that has a byte size.

    Starting at `type_DIE` itself, DW_AT_type references are followed until
    a DIE carrying DW_AT_byte_size is found.

    Returns:
        A `(DIE, byte_size)` tuple for the DIE that carries the size, or
        None when the chain ends before any size is found.
    """
    current = type_DIE
    while 'DW_AT_byte_size' not in current.attributes:
        if 'DW_AT_type' not in current.attributes:
            # end of the chain without size information
            return None
        current = current.get_DIE_from_attribute('DW_AT_type')
    return current, current.attributes.get('DW_AT_byte_size').value

def get_type_arrsize(type_DIE):
    """Return the storage size of `type_DIE`, multiplied by array dimensions.

    The base size is the one found by get_type_size(). Unless the sized DIE
    is a pointer type, the first DW_TAG_array_type along the type chain (if
    any) is looked up and the base size is multiplied by each of its
    subrange lengths (DW_AT_upper_bound + 1).

    Returns:
        The size in bytes, or None when no size information is available.
    """
    resolved = get_type_size(type_DIE)
    if resolved is None:
        return None

    sized_DIE, total = resolved
    if sized_DIE.tag == 'DW_TAG_pointer_type':
        # the pointer itself is what gets stored
        return total

    array_DIE = get_type_def(type_DIE, 'DW_TAG_array_type')
    if array_DIE is None:
        return total

    for child in array_DIE.iter_children():
        if child.tag != 'DW_TAG_subrange_type':
            continue
        if 'DW_AT_upper_bound' not in child.attributes:
            continue
        total *= child.attributes['DW_AT_upper_bound'].value + 1
    return total

def get_type_def(type_DIE, type_tag):
    """Find the first DIE with tag `type_tag` along the DW_AT_type chain.

    The search starts at `type_DIE` itself and follows DW_AT_type
    references.

    Returns:
        The matching DIE, or None when the chain ends without a match.
    """
    node = type_DIE
    while node.tag != type_tag:
        if 'DW_AT_type' not in node.attributes:
            return None
        node = node.get_DIE_from_attribute('DW_AT_type')
    return node

def get_FORM_block1(attr):
    """Decode a simple DWARF expression stored in a DW_FORM_block1 attribute.

    Only two single-operation expressions are understood:

    * DW_OP_addr (0x03): the remaining bytes are read as a little-endian
      address.
    * DW_OP_plus_uconst (0x23): the operand is decoded as ULEB128.

    Args:
        attr: attribute object exposing `form` and `value` (the sequence of
            expression bytes).

    Returns:
        The decoded integer, or None when the form is not DW_FORM_block1,
        the block is empty (an empty location description) or the first
        operation is not one of the two listed above.
    """
    if attr.form != 'DW_FORM_block1':
        return None

    expr = attr.value
    if not expr:
        # empty location description: nothing to decode
        return None

    opcode = expr[0]
    if opcode == 0x03:  # DW_OP_addr
        return int.from_bytes(expr[1:], 'little')
    if opcode == 0x23:  # DW_OP_plus_uconst (ULEB128)
        value = 0
        shift = 0
        for byte in expr[1:]:
            value |= (byte & 0x7f) << shift
            if byte & 0x80 == 0:
                # the high bit is clear on the last byte of the number
                break
            shift += 7
        return value
    return None


def get_array_dims(DIE):
    """Return the array dimensions of `DIE` as a list of element counts.

    The first DW_TAG_array_type along the DW_AT_type chain is used; every
    DW_TAG_subrange_type child with a DW_AT_upper_bound contributes one
    dimension of length `upper_bound + 1`. An empty list is returned when
    no array type is found.
    """
    array_DIE = get_type_def(DIE, 'DW_TAG_array_type')
    if array_DIE is None:
        return []

    return [child.attributes['DW_AT_upper_bound'].value + 1
            for child in array_DIE.iter_children()
            if child.tag == 'DW_TAG_subrange_type'
            and 'DW_AT_upper_bound' in child.attributes]


def get_struct_members(DIE, entry, expand_structs, struct_gaps):
    """List the members of the structure type behind `DIE` (one level only).

    Member arrays are expanded into one Member per element, except for
    one-dimensional arrays of single-byte elements, which are treated as
    strings and reported as a single 'name[]' member. For multi-dimensional
    arrays of single-byte elements the last dimension is folded into the
    element size.

    Members whose offset or size cannot be determined (no
    DW_AT_data_member_location, e.g. static data members; an offset
    expression get_FORM_block1() cannot decode; no size information) are
    skipped instead of raising.

    Args:
        DIE: DIE whose DW_AT_type chain is searched for a
            DW_TAG_structure_type.
        entry: Entry describing the variable; only `entry.size` is read, to
            compute the trailing gap.
        expand_structs: unused here; kept for interface compatibility.
        struct_gaps: when true, '*UNKNOWN*' members are added for the
            unnamed regions between members and after the last member, and
            the result is ordered by offset.

    Returns:
        A list of Member tuples; empty when no structure type is found.
    """
    struct_DIE = get_type_def(DIE, 'DW_TAG_structure_type')
    if struct_DIE is None:
        return []

    members = []
    for member_DIE in struct_DIE.iter_children():
        if member_DIE.tag != 'DW_TAG_member' or \
           'DW_AT_name' not in member_DIE.attributes:
            continue
        if 'DW_AT_data_member_location' not in member_DIE.attributes:
            # no offset within the instance (e.g. static member)
            continue

        m_name = member_DIE.attributes['DW_AT_name'].value.decode('ascii')
        m_off = get_FORM_block1(member_DIE.attributes['DW_AT_data_member_location'])
        m_size = get_type_size(member_DIE)
        if m_off is None or m_size is None:
            # cannot place this member; it will show up as a gap if requested
            continue
        m_byte_size = m_size[1]

        # still expand member arrays
        m_array_dim = get_array_dims(member_DIE)

        if m_byte_size == 1 and len(m_array_dim) > 1:
            # likely string, remove one dimension
            m_byte_size *= m_array_dim.pop()
        if len(m_array_dim) == 0 or (len(m_array_dim) == 1 and m_array_dim[0] == 1):
            # plain entry
            members.append(Member(m_name, m_off, m_byte_size))
        elif len(m_array_dim) == 1 and m_byte_size == 1:
            # likely string, avoid expansion
            members.append(Member(m_name + '[]', m_off, m_array_dim[0]))
        else:
            # expand array entries
            m_array_pos = m_off
            m_array_loc = [0] * len(m_array_dim)
            while True:
                # location index, zero-padded to the width of the last index
                sfx = ''.join(
                    '[{}]'.format(str(pos).rjust(len(str(count - 1)), '0'))
                    for pos, count in zip(m_array_loc, m_array_dim))
                members.append(Member(m_name + sfx, m_array_pos, m_byte_size))
                # advance
                if array_inc(m_array_loc, m_array_dim):
                    break
                m_array_pos += m_byte_size

    if struct_gaps and members:
        # fill gaps in the middle
        members.sort(key=lambda x: x.off)
        gaps = []
        last_end = 0
        for member in members:
            if member.off > last_end:
                gaps.append(Member('*UNKNOWN*', last_end, member.off - last_end))
            last_end = member.off + member.size
        members = sorted(members + gaps, key=lambda x: x.off)

        # fill gap at the end
        last = members[-1]
        last_end = last.off + last.size
        if entry.size > last_end:
            members.append(Member('*UNKNOWN*', last_end, entry.size - last_end))

    return members


def get_elf_globals(path, expand_structs, struct_gaps=True):
    """Collect the global variables described by the DWARF info of an ELF.

    Every DW_TAG_variable whose DW_AT_location decodes (through
    get_FORM_block1) to an address in [SRAM_OFFSET, EEPROM_OFFSET) is
    reported with its location made relative to SRAM_OFFSET. Arrays are
    expanded into one entry per element, except single-byte element arrays,
    which are treated as strings. Structures are optionally expanded into
    their members (one level only).

    Args:
        path: path of the ELF file to read.
        expand_structs: when true, structure variables are split into one
            entry per member.
        struct_gaps: passed to get_struct_members() to request '*UNKNOWN*'
            filler members.

    Returns:
        A list of Entry tuples in discovery order, or None when the ELF
        file carries no DWARF information.

    Raises:
        OSError: if `path` cannot be opened.
    """
    grefs = []

    def expand_members(entry, members):
        # emit either the entry itself or one entry per structure member
        if len(members) == 0:
            grefs.append(entry)
        else:
            for member in members:
                grefs.append(Entry(entry.name + '.' + member.name,
                                   entry.loc + member.off, member.size,
                                   entry.declpos))

    # the ELF/DWARF objects read from the stream on demand, so all the
    # processing has to happen while the file is still open
    with open(path, "rb") as fd:
        elffile = elftools.elf.elffile.ELFFile(fd)
        if not elffile.has_dwarf_info():
            return None

        # probably not needed, since we're decoding expressions manually
        elftools.dwarf.descriptions.set_global_machine_arch(elffile.get_machine_arch())
        dwarfinfo = elffile.get_dwarf_info()

        for CU in dwarfinfo.iter_CUs():
            # a CU without a line program has no file table to resolve
            # declaration positions against
            line_program = dwarfinfo.line_program_for_CU(CU)
            if line_program is None:
                file_entries = []
            else:
                file_entries = line_program.header["file_entry"]

            for DIE in CU.iter_DIEs():
                # handle only variable types
                if DIE.tag != 'DW_TAG_variable':
                    continue
                if 'DW_AT_location' not in DIE.attributes:
                    continue
                if 'DW_AT_name' not in DIE.attributes and \
                   'DW_AT_abstract_origin' not in DIE.attributes:
                    continue

                # handle locations encoded directly as DW_OP_addr (leaf globals)
                loc = get_FORM_block1(DIE.attributes['DW_AT_location'])
                if loc is None or loc < SRAM_OFFSET or loc >= EEPROM_OFFSET:
                    continue
                loc -= SRAM_OFFSET

                # variable name/type
                if 'DW_AT_name' not in DIE.attributes and \
                   'DW_AT_abstract_origin' in DIE.attributes:
                    DIE = DIE.get_DIE_from_attribute('DW_AT_abstract_origin')
                    if 'DW_AT_location' in DIE.attributes:
                        # duplicate reference (handled directly), skip
                        continue
                if 'DW_AT_name' not in DIE.attributes:
                    continue
                if 'DW_AT_type' not in DIE.attributes:
                    continue

                name = DIE.attributes['DW_AT_name'].value.decode('ascii')

                # get final storage size
                size = get_type_size(DIE)
                if size is None:
                    continue
                byte_size = size[1]

                # location of main definition
                declpos = ''
                if 'DW_AT_decl_file' in DIE.attributes and \
                   'DW_AT_decl_line' in DIE.attributes:
                    line = DIE.attributes['DW_AT_decl_line'].value
                    fname = DIE.attributes['DW_AT_decl_file'].value
                    # file numbers are used as 1-based indices into the table
                    if fname and fname - 1 < len(file_entries):
                        fname = file_entries[fname-1].name.decode('ascii')
                        declpos = '{}:{}'.format(fname, line)

                # fetch array dimensions (if known)
                array_dim = get_array_dims(DIE)

                # fetch structure members (one level only)
                entry = Entry(name, loc, byte_size, declpos)
                if not expand_structs or size[0].tag == 'DW_TAG_pointer_type':
                    members = []
                else:
                    members = get_struct_members(DIE, entry, expand_structs, struct_gaps)

                if byte_size == 1 and len(array_dim) > 1:
                    # likely string, remove one dimension
                    byte_size *= array_dim.pop()
                if len(array_dim) == 0 or (len(array_dim) == 1 and array_dim[0] == 1):
                    # plain entry
                    expand_members(entry, members)
                elif len(array_dim) == 1 and byte_size == 1:
                    # likely string, avoid expansion
                    grefs.append(Entry(entry.name + '[]', entry.loc,
                                       array_dim[0], entry.declpos))
                else:
                    # expand array entries
                    array_pos = loc
                    array_loc = [0] * len(array_dim)
                    while True:
                        # location index, zero-padded to the width of the last index
                        sfx = ''.join(
                            '[{}]'.format(str(pos).rjust(len(str(count - 1)), '0'))
                            for pos, count in zip(array_loc, array_dim))
                        expand_members(Entry(entry.name + sfx, array_pos,
                                             byte_size, entry.declpos), members)
                        # advance
                        if array_inc(array_loc, array_dim):
                            break
                        array_pos += byte_size

    return grefs


def annotate_refs(grefs, addr, data, width, gaps=True, overlaps=True):
    """Print the contents of `data` annotated with the symbols in `grefs`.

    One line is printed for every entry lying entirely inside the dumped
    range, with its address, name, size, a decoded representation and the
    raw bytes in hex. Values are decoded as little-endian: buffers of 1, 2
    or 4 bytes as unsigned integers, buffers of 4 or 8 bytes as floats.

    Args:
        grefs: iterable of entries exposing `name`, `loc` and `size`,
            expected to be ordered by `loc`.
        addr: address corresponding to the first byte of `data`.
        data: dumped memory; must support slicing, and its slices must
            support `.hex()` (e.g. bytes).
        width: width of the name column.
        gaps: when true, bytes between the furthest end reached so far and
            the next entry are printed as '*UNKNOWN*'.
        overlaps: when true, an '*OVERLAP*' line (with a negative size) is
            printed when an entry starts more than one byte before the
            furthest end reached so far.
    """
    last_end = None
    for entry in grefs:
        # only entries fully contained in the dump can be decoded
        if entry.loc < addr:
            continue
        if entry.loc + entry.size > addr + len(data):
            continue

        pos = entry.loc - addr
        end_pos = pos + entry.size
        buf = data[pos:end_pos]

        buf_repr = ''
        if len(buf) in (1, 2, 4):
            # attempt to decode as integers
            buf_repr += ' I:' + str(int.from_bytes(buf, 'little')).rjust(10)
        if len(buf) in (4, 8):
            # attempt to decode as floats; force little-endian so that the
            # result matches the integer decoding on any host
            typ = '<f' if len(buf) == 4 else '<d'
            buf_repr += ' F:' + '{:10.3f}'.format(unpack(typ, buf)[0])

        if last_end is not None:
            if gaps and last_end < pos:
                # decode gaps
                gap_size = pos - last_end
                gap_buf = data[last_end:pos]
                print('{:04x} {} {:4} R:{}'.format(addr+last_end, "*UNKNOWN*".ljust(width),
                                                   gap_size, gap_buf.hex()))
            if overlaps and last_end > pos + 1:
                gap_size = pos - last_end
                print('{:04x} {} {:4}'.format(addr+last_end, "*OVERLAP*".ljust(width), gap_size))

        print('{:04x} {} {:4}{} R:{}'.format(entry.loc, entry.name.ljust(width),
                                             entry.size, buf_repr, buf.hex()))

        # keep the furthest end seen: an entry nested inside a previous one
        # must not make already covered bytes look like a gap
        last_end = end_pos if last_end is None else max(last_end, end_pos)


def print_map(grefs):
    """Print a tab-separated table of all entries in `grefs`.

    A header line is printed first, followed by one line per entry with its
    location (hexadecimal), size, name and declaration position.
    """
    row_format = '{:x}\t{}\t{}\t{}'
    print('OFFSET\tSIZE\tNAME\tDECLPOS')
    for ref in grefs:
        row = row_format.format(ref.loc, ref.size, ref.name, ref.declpos)
        print(row)


def print_qdirstat(grefs):
    """Print the entries in `grefs` as a qdirstat cache file.

    Entry names are split on '[', ']' and '.' into path components: the
    leading components become directories and the last one a file whose
    size is the entry size. Entries located below SRAM_START are left out.
    When the last component already exists in its directory, the full entry
    name suffixed with the hexadecimal location is used instead.
    """
    # `re` is not imported at the top of this file; import it here so the
    # function does not depend on another module re-exporting it
    import re

    print('[qdirstat 1.0 cache file]')

    tree = {}
    for entry in grefs:
        # do not output registers when looking at space usage
        if entry.loc < SRAM_START:
            continue

        paths = [part for part in re.split(r'[\[\].]', entry.name) if part]
        base = tree
        for name in paths[:-1]:
            base = base.setdefault(name, {})
        name = paths[-1]
        if name in base:
            # keep both entries by making the new name unique
            name = '{}_{:x}'.format(entry.name, entry.loc)
        base[name] = entry.size

    def walker(root, prefix):
        files = []
        dirs = []

        # sizes (int) are files, nested dictionaries are directories
        for name, node in root.items():
            if isinstance(node, int):
                files.append([name, node])
            else:
                dirs.append([name, node])

        # print files
        print('D\t{}\t{}\t0x0'.format(prefix, 0))
        for name, size in files:
            print('F\t{}\t{}\t0x0'.format(name, size))

        # recurse directories
        for name, node in dirs:
            walker(node, prefix + '/' + name)

    walker(tree, '/')


def main():
    """Command-line entry point.

    Parses the arguments, reads the globals from the ELF file and, depending
    on the selected mode, prints the memory map, the qdirstat size map or
    the annotated contents of a memory dump.

    Returns:
        os.EX_DATAERR when the dump cannot be decoded, None otherwise. An
        ELF file without DWARF information is reported through the argument
        parser, which terminates the program.
    """
    ap = argparse.ArgumentParser(description="""
        Generate a symbol table map starting directly from an ELF
        firmware with DWARF3 debugging information.
        When used along with a memory dump obtained from the D2/D21/D23 g-code,
        show the value of each symbol which is within the address range.
    """)
    ap.add_argument('elf', help='ELF file containing DWARF debugging information')
    ap.add_argument('--no-gaps', action='store_true',
                    help='do not dump memory inbetween known symbols')
    ap.add_argument('--no-expand-structs', action='store_true',
                    help='do not decode structure data')
    ap.add_argument('--overlaps', action='store_true',
                    help='annotate overlaps greater than 1 byte')
    ap.add_argument('--name-width', type=int, default=50,
                    help='set name column width')
    g = ap.add_mutually_exclusive_group(required=True)
    g.add_argument('dump', nargs='?', help='RAM dump obtained from D2 g-code')
    g.add_argument('--map', action='store_true', help='dump global memory map')
    g.add_argument('--qdirstat', action='store_true',
                   help='dump qdirstat-compatible size usage map')
    args = ap.parse_args()

    grefs = get_elf_globals(args.elf, expand_structs=not args.no_expand_structs)
    if grefs is None:
        # get_elf_globals() returns None when there is no DWARF information;
        # report it instead of failing while sorting
        ap.error('{}: no DWARF debugging information found'.format(args.elf))
    grefs = sorted(grefs, key=lambda x: x.loc)

    if args.map:
        print_map(grefs)
    elif args.qdirstat:
        print_qdirstat(grefs)
    else:
        # fetch the memory data
        dump = decode_dump(args.dump)
        if dump is None:
            return os.EX_DATAERR

        # strip padding, if present
        addr_start = dump.ranges[0][0]
        addr_end = dump.ranges[-1][0]+dump.ranges[-1][1]
        data = dump.data[addr_start:addr_end]

        annotate_refs(grefs, addr_start, data,
                      width=args.name_width,
                      gaps=not args.no_gaps,
                      overlaps=args.overlaps)

if __name__ == '__main__':
    # raise SystemExit directly: the `exit` helper is installed by the
    # `site` module and is not guaranteed to exist (e.g. with `python -S`)
    raise SystemExit(main())