pw_tokenizer: Support archive files in elf_reader

- Support archive files by reading all ELFs contained within them.
- Make minor changes to support reading an ELF from within another file.
- Add tests for archive files.

Change-Id: I478c443c4f780a7399022543aef0d5f0aae6e2b8
diff --git a/pw_tokenizer/py/detokenize_test.py b/pw_tokenizer/py/detokenize_test.py
index be1a4b7..ef079a5 100755
--- a/pw_tokenizer/py/detokenize_test.py
+++ b/pw_tokenizer/py/detokenize_test.py
@@ -275,6 +275,7 @@
                              frozenset(detok.database.token_to_entries.keys()))
 
             # Open ELF by elf_reader.Elf
+            elf.seek(0)
             detok = detokenize.Detokenizer(elf_reader.Elf(elf))
             self.assertEqual(expected_tokens,
                              frozenset(detok.database.token_to_entries.keys()))
diff --git a/pw_tokenizer/py/elf_reader_test.py b/pw_tokenizer/py/elf_reader_test.py
index 050c7fc..e576598 100755
--- a/pw_tokenizer/py/elf_reader_test.py
+++ b/pw_tokenizer/py/elf_reader_test.py
@@ -70,20 +70,24 @@
   l (large), p (processor specific)
 """)
 
+TEST_ELF_PATH = os.path.join(os.path.dirname(__file__),
+                             'elf_reader_test_binary.elf')
+
 
 class ElfReaderTest(unittest.TestCase):
     """Tests the elf_reader.Elf class."""
     def setUp(self):
         super().setUp()
-        elf_path = os.path.join(os.path.dirname(__file__),
-                                'elf_reader_test_binary.elf')
-        self._elf_file = open(elf_path, 'rb')
+        self._elf_file = open(TEST_ELF_PATH, 'rb')
         self._elf = elf_reader.Elf(self._elf_file)
 
     def tearDown(self):
         super().tearDown()
         self._elf_file.close()
 
+    def _section(self, name):
+        return next(self._elf.sections_with_name(name))
+
     def test_readelf_comparison_using_the_readelf_binary(self):
         """Compares elf_reader to readelf's output."""
 
@@ -120,15 +124,15 @@
             self.assertEqual(section.offset, offset)
             self.assertEqual(section.size, size)
 
-    def test_dump_section(self):
-        self.assertEqual(self._elf.dump_section('.test_section_1'),
+    def test_dump_single_section(self):
+        self.assertEqual(self._elf.dump_sections(r'\.test_section_1'),
                          b'You cannot pass\0')
-        self.assertEqual(self._elf.dump_section('.test_section_2'),
+        self.assertEqual(self._elf.dump_sections(r'\.test_section_2'),
                          b'\xef\xbe\xed\xfe')
 
-    def test_dump_sections(self):
-        if (self._elf.sections_by_name['.test_section_1'].address <
-                self._elf.sections_by_name['.test_section_2'].address):
+    def test_dump_multiple_sections(self):
+        if (self._section('.test_section_1').address <
+                self._section('.test_section_2').address):
             contents = b'You cannot pass\0\xef\xbe\xed\xfe'
         else:
             contents = b'\xef\xbe\xed\xfeYou cannot pass\0'
@@ -136,11 +140,10 @@
         self.assertIn(self._elf.dump_sections(r'.test_section_\d'), contents)
 
     def test_read_values(self):
-        string_address = self._elf.sections_by_name['.test_section_1'].address
-        self.assertEqual(self._elf.read_value(string_address),
-                         b'You cannot pass')
+        address = self._section('.test_section_1').address
+        self.assertEqual(self._elf.read_value(address), b'You cannot pass')
 
-        int32_address = self._elf.sections_by_name['.test_section_2'].address
+        int32_address = self._section('.test_section_2').address
         self.assertEqual(self._elf.read_value(int32_address, 4),
                          b'\xef\xbe\xed\xfe')
 
@@ -152,6 +155,94 @@
         self.assertEqual(elf_reader.read_c_string(bytes_io), b'No terminator!')
         self.assertEqual(elf_reader.read_c_string(bytes_io), b'')
 
+    def test_compatible_file_for_elf(self):
+        self.assertTrue(elf_reader.compatible_file(self._elf_file))
+        self.assertTrue(elf_reader.compatible_file(io.BytesIO(b'\x7fELF')))
+
+    def test_compatible_file_for_elf_start_at_offset(self):
+        self._elf_file.seek(13)  # Seek ahead to get out of sync
+        self.assertTrue(elf_reader.compatible_file(self._elf_file))
+        self.assertEqual(13, self._elf_file.tell())
+
+    def test_compatible_file_for_invalid_elf(self):
+        self.assertFalse(elf_reader.compatible_file(io.BytesIO(b'\x7fELVESF')))
+
+
+def _archive_file(data: bytes) -> bytes:
+    return ('FILE ID 90123456'
+            'MODIFIED 012'
+            'OWNER '
+            'GROUP '
+            'MODE 678'
+            f'{len(data):10}'  # File size -- the only part that's needed.
+            '`\n'.encode() + data)
+
+
+class ArchiveTest(unittest.TestCase):
+    """Tests reading from archive files."""
+    def setUp(self):
+        super().setUp()
+
+        with open(TEST_ELF_PATH, 'rb') as fd:
+            self._elf_data = fd.read()
+
+        self._archive_entries = b'blah', b'hello', self._elf_data
+
+        self._archive_data = elf_reader.ARCHIVE_MAGIC + b''.join(
+            _archive_file(f) for f in self._archive_entries)
+        self._archive = io.BytesIO(self._archive_data)
+
+    def test_compatible_file_for_archive(self):
+        self.assertTrue(elf_reader.compatible_file(io.BytesIO(b'!<arch>\n')))
+        self.assertTrue(elf_reader.compatible_file(self._archive))
+
+    def test_compatible_file_for_invalid_archive(self):
+        self.assertFalse(elf_reader.compatible_file(io.BytesIO(b'!<arch>')))
+
+    def test_iterate_over_files(self):
+        for expected, size in zip(self._archive_entries,
+                                  elf_reader.files_in_archive(self._archive)):
+            self.assertEqual(expected, self._archive.read(size))
+
+    def test_iterate_over_empty_archive(self):
+        with self.assertRaises(StopIteration):
+            next(iter(elf_reader.files_in_archive(io.BytesIO(b'!<arch>\n'))))
+
+    def test_iterate_over_invalid_archive(self):
+        with self.assertRaises(elf_reader.FileDecodeError):
+            for _ in elf_reader.files_in_archive(
+                    io.BytesIO(b'!<arch>blah blahblah')):
+                pass
+
+    def test_iterate_over_archive_with_invalid_size(self):
+        data = elf_reader.ARCHIVE_MAGIC + _archive_file(b'$' * 3210)
+        file = io.BytesIO(data)
+
+        # Iterate over the file normally.
+        for size in elf_reader.files_in_archive(file):
+            self.assertEqual(b'$' * 3210, file.read(size))
+
+        # Replace the size with a hex number, which is not valid.
+        with self.assertRaises(elf_reader.FileDecodeError):
+            for _ in elf_reader.files_in_archive(
+                    io.BytesIO(data.replace(b'3210', b'0x99'))):
+                pass
+
+    def test_elf_reader_dump_single_section(self):
+        elf = elf_reader.Elf(self._archive)
+        self.assertEqual(elf.dump_sections(r'\.test_section_1'),
+                         b'You cannot pass\0')
+        self.assertEqual(elf.dump_sections(r'\.test_section_2'),
+                         b'\xef\xbe\xed\xfe')
+
+    def test_elf_reader_read_values(self):
+        elf = elf_reader.Elf(self._archive)
+        address = next(elf.sections_with_name('.test_section_1')).address
+        self.assertEqual(elf.read_value(address), b'You cannot pass')
+
+        int32_address = next(elf.sections_with_name('.test_section_2')).address
+        self.assertEqual(elf.read_value(int32_address, 4), b'\xef\xbe\xed\xfe')
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/pw_tokenizer/py/pw_tokenizer/database.py b/pw_tokenizer/py/pw_tokenizer/database.py
index a717321..cf9319d 100755
--- a/pw_tokenizer/py/pw_tokenizer/database.py
+++ b/pw_tokenizer/py/pw_tokenizer/database.py
@@ -54,7 +54,7 @@
 
 def read_tokenizer_metadata(elf) -> Dict[str, int]:
     """Reads the metadata entries from an ELF."""
-    sections = _elf_reader(elf).dump_section('.tokenized.meta')
+    sections = _elf_reader(elf).dump_sections(r'\.tokenized\.meta')
 
     metadata: Dict[str, int] = {}
     if sections is not None:
@@ -87,14 +87,14 @@
 
         # Read the path as an ELF file.
         with open(db, 'rb') as fd:
-            if elf_reader.file_is_elf(fd):
+            if elf_reader.compatible_file(fd):
                 return tokens.Database.from_strings(_read_strings_from_elf(fd))
 
         # Read the path as a packed binary or CSV file.
         return tokens.DatabaseFile(db)
 
     # Assume that it's a file object and check if it's an ELF.
-    if elf_reader.file_is_elf(db):
+    if elf_reader.compatible_file(db):
         return tokens.Database.from_strings(_read_strings_from_elf(db))
 
     # Read the database as CSV or packed binary from a file object's path.
diff --git a/pw_tokenizer/py/pw_tokenizer/elf_reader.py b/pw_tokenizer/py/pw_tokenizer/elf_reader.py
index 3788b05..079e808 100755
--- a/pw_tokenizer/py/pw_tokenizer/elf_reader.py
+++ b/pw_tokenizer/py/pw_tokenizer/elf_reader.py
@@ -17,14 +17,73 @@
 This module provides tools for dumping the contents of an ELF section. It can
 also be used to read values at a particular address. A command line interface
 for both of these features is provided.
+
+This module supports any ELF-format file, including .o and .so files. This
+module also has basic support for archive (.a) files. All ELF files in an
+archive are read as one unit.
 """
 
 import argparse
-import collections
 import re
 import struct
 import sys
-from typing import BinaryIO, Dict, Iterable, NamedTuple, Optional, Tuple, Union
+from typing import BinaryIO, Dict, Iterable, NamedTuple, Optional
+from typing import Pattern, Tuple, Union
+
+ARCHIVE_MAGIC = b'!<arch>\n'
+ELF_MAGIC = b'\x7fELF'
+
+
+def _check_next_bytes(fd: BinaryIO, expected: bytes, what: str) -> None:
+    actual = fd.read(len(expected))
+    if expected != actual:
+        raise FileDecodeError(
+            f'Invalid {what}: expected {expected!r}, found {actual!r}')
+
+
+def files_in_archive(fd: BinaryIO) -> Iterable[int]:
+    """Seeks to each file in an archive and yields its size."""
+
+    _check_next_bytes(fd, ARCHIVE_MAGIC, 'archive magic number')
+
+    while True:
+        # Each file in an archive is prefixed with an ASCII header:
+        #
+        #   16 B - file identifier (text)
+        #   12 B - file modification timestamp (decimal)
+        #    6 B - owner ID (decimal)
+        #    6 B - group ID (decimal)
+        #    8 B - file mode (octal)
+        #   10 B - file size in bytes (decimal)
+        #    2 B - ending characters (`\n)
+        #
+        # Skip the unused portions of the file header, then read the size.
+        fd.seek(16 + 12 + 6 + 6 + 8, 1)
+        size_str = fd.read(10)
+        if not size_str:
+            return
+
+        try:
+            size = int(size_str, 10)
+        except ValueError as exc:
+            raise FileDecodeError(
+                'Archive file sizes must be decimal integers') from exc
+
+        _check_next_bytes(fd, b'`\n', 'archive file header ending')
+        offset = fd.tell()  # Store offset in case the caller reads the file.
+
+        yield size
+
+        fd.seek(offset + size)
+
+
+def _elf_files_in_archive(fd: BinaryIO):
+    if _bytes_match(fd, ELF_MAGIC):
+        yield  # The value isn't used, so just yield None.
+    else:
+        for _ in files_in_archive(fd):
+            if _bytes_match(fd, ELF_MAGIC):
+                yield
 
 
 class Field(NamedTuple):
@@ -76,35 +135,41 @@
         string += byte
 
 
-def file_is_elf(fd: BinaryIO) -> bool:
-    """Returns true if the provided file starts with the ELF magic number."""
+def _bytes_match(fd: BinaryIO, expected: bytes) -> bool:
+    """Peeks at the next bytes to see if they match the expected."""
     try:
-        fd.seek(0)
-        magic_number = fd.read(4)
-        fd.seek(0)
-        return magic_number == b'\x7fELF'
+        offset = fd.tell()
+        data = fd.read(len(expected))
+        fd.seek(offset)
+        return data == expected
     except IOError:
         return False
 
 
-class ElfDecodeError(Exception):
+def compatible_file(fd: BinaryIO) -> bool:
+    """True if the file type is supported (ELF or archive)."""
+    offset = fd.tell()
+    fd.seek(0)
+    result = _bytes_match(fd, ELF_MAGIC) or _bytes_match(fd, ARCHIVE_MAGIC)
+    fd.seek(offset)
+    return result
+
+
+class FileDecodeError(Exception):
-    """Invalid data was read from an ELF file."""
+    """Invalid data was read from an ELF or archive file."""
 
 
 class FieldReader:
     """Reads ELF fields defined with a Field tuple from an ELF file."""
     def __init__(self, elf: BinaryIO):
-        if not file_is_elf(elf):
-            raise ElfDecodeError(r"ELF files must start with b'\x7fELF'")
-
         self._elf = elf
+        self.file_offset = self._elf.tell()
+
+        _check_next_bytes(self._elf, ELF_MAGIC, 'ELF file header')
+        size_field = self._elf.read(1)  # e_ident[EI_CLASS] (address size)
 
         int_unpacker = self._determine_integer_format()
 
-        # Set up decoding based on the address size
-        self._elf.seek(0x04)  # e_ident[EI_CLASS] (address size)
-        size_field = self._elf.read(1)
-
         if size_field == b'\x01':
             self.offset = lambda field: field.offset_32
             self._size = lambda field: field.size_32
@@ -114,18 +179,17 @@
             self._size = lambda field: field.size_64
             self._decode = lambda f, d: int_unpacker[f.size_64].unpack(d)[0]
         else:
-            raise ElfDecodeError('Unknown size {!r}'.format(size_field))
+            raise FileDecodeError('Unknown size {!r}'.format(size_field))
 
     def _determine_integer_format(self) -> Dict[int, struct.Struct]:
         """Returns a dict of structs used for converting bytes to integers."""
-        self._elf.seek(0x05)  # e_ident[EI_DATA] (endianness)
-        endianness_byte = self._elf.read(1)
+        endianness_byte = self._elf.read(1)  # e_ident[EI_DATA] (endianness)
         if endianness_byte == b'\x01':
             endianness = '<'
         elif endianness_byte == b'\x02':
             endianness = '>'
         else:
-            raise ElfDecodeError(
+            raise FileDecodeError(
                 'Unknown endianness {!r}'.format(endianness_byte))
 
         return {
@@ -136,24 +200,25 @@
         }
 
     def read(self, field: Field, base: int = 0) -> int:
-        self._elf.seek(base + self.offset(field))
+        self._elf.seek(self.file_offset + base + self.offset(field))
         data = self._elf.read(self._size(field))
         return self._decode(field, data)
 
-    def read_string(self, address: int) -> str:
-        self._elf.seek(address)
+    def read_string(self, offset: int) -> str:
+        self._elf.seek(self.file_offset + offset)
         return read_c_string(self._elf).decode()
 
 
 class Elf:
     """Represents an ELF file and the sections in it."""
-    class Section:
+    class Section(NamedTuple):
         """Info about a section in an ELF file."""
-        def __init__(self, name: str, address: int, offset: int, size: int):
-            self.name = name
-            self.address = address
-            self.offset = offset
-            self.size = size
+        name: str
+        address: int
+        offset: int
+        size: int
+
+        file_offset: int  # Starting place in the file; 0 unless in an archive.
 
         def range(self) -> range:
             return range(self.address, self.address + self.size)
@@ -161,45 +226,38 @@
         def __lt__(self, other) -> bool:
             return self.address < other.address
 
-        def __str__(self) -> str:
-            return ('Section(name={self.name}, address=0x{self.address:08x} '
-                    'offset=0x{self.offset:x} size=0x{self.size:x})').format(
-                        self=self)
-
-        def __repr__(self) -> str:
-            return str(self)
-
     def __init__(self, elf: BinaryIO):
         self._elf = elf
         self.sections: Tuple[Elf.Section, ...] = tuple(self._list_sections())
-        self.sections_by_name: Dict[str,
-                                    Elf.Section] = collections.OrderedDict(
-                                        (section.name, section)
-                                        for section in self.sections)
 
     def _list_sections(self) -> Iterable['Elf.Section']:
         """Reads the section headers to enumerate all ELF sections."""
-        reader = FieldReader(self._elf)
-        base = reader.read(FILE_HEADER.section_header_offset)
-        section_header_size = reader.offset(SECTION_HEADER.section_header_end)
+        for _ in _elf_files_in_archive(self._elf):
+            reader = FieldReader(self._elf)
+            base = reader.read(FILE_HEADER.section_header_offset)
+            section_header_size = reader.offset(
+                SECTION_HEADER.section_header_end)
 
-        # Find the section with the section names in it
-        names_section_header_base = base + section_header_size * reader.read(
-            FILE_HEADER.section_names_index)
-        names_table_base = reader.read(SECTION_HEADER.section_offset,
-                                       names_section_header_base)
+            # Find the section with the section names in it.
+            names_section_header_base = (
+                base + section_header_size *
+                reader.read(FILE_HEADER.section_names_index))
+            names_table_base = reader.read(SECTION_HEADER.section_offset,
+                                           names_section_header_base)
 
-        base = reader.read(FILE_HEADER.section_header_offset)
-        for _ in range(reader.read(FILE_HEADER.section_count)):
-            name_offset = reader.read(SECTION_HEADER.section_name_offset, base)
+            base = reader.read(FILE_HEADER.section_header_offset)
+            for _ in range(reader.read(FILE_HEADER.section_count)):
+                name_offset = reader.read(SECTION_HEADER.section_name_offset,
+                                          base)
 
-            yield self.Section(
-                reader.read_string(names_table_base + name_offset),
-                reader.read(SECTION_HEADER.section_address, base),
-                reader.read(SECTION_HEADER.section_offset, base),
-                reader.read(SECTION_HEADER.section_size, base))
+                yield self.Section(
+                    reader.read_string(names_table_base + name_offset),
+                    reader.read(SECTION_HEADER.section_address, base),
+                    reader.read(SECTION_HEADER.section_offset, base),
+                    reader.read(SECTION_HEADER.section_size, base),
+                    reader.file_offset)
 
-            base += section_header_size
+                base += section_header_size
 
     def section_by_address(self, address: int) -> Optional['Elf.Section']:
         """Returns the section that contains the provided address, if any."""
@@ -210,6 +268,11 @@
 
         return None
 
+    def sections_with_name(self, name: str) -> Iterable['Elf.Section']:
+        for section in self.sections:
+            if section.name == name:
+                yield section
+
     def read_value(self,
                    address: int,
                    size: Optional[int] = None) -> Union[None, bytes, int]:
@@ -219,31 +282,22 @@
             return None
 
         assert section.address <= address
-        self._elf.seek(section.offset + address - section.address)
+        self._elf.seek(section.file_offset + section.offset + address -
+                       section.address)
 
         if size is None:
             return read_c_string(self._elf)
 
         return self._elf.read(size)
 
-    def dump_section(self, name: str) -> Optional[bytes]:
-        """Dumps section contents as a byte string; None if no match."""
-        try:
-            section = self.sections_by_name[name]
-        except KeyError:
-            return None
-
-        self._elf.seek(section.offset)
-        return self._elf.read(section.size)
-
-    def dump_sections(self, name_regex) -> Optional[bytes]:
-        """Dumps a binary string containing the sections matching name_regex."""
-        name_regex = re.compile(name_regex)
+    def dump_sections(self, name: Union[str, Pattern[str]]) -> Optional[bytes]:
+        """Dumps a binary string containing the sections matching the regex."""
+        name_regex = re.compile(name)
 
         sections = []
         for section in self.sections:
             if name_regex.match(section.name):
-                self._elf.seek(section.offset)
+                self._elf.seek(section.file_offset + section.offset)
                 sections.append(self._elf.read(section.size))
 
         return b''.join(sections) if sections else None
@@ -268,15 +322,12 @@
         output(value)
 
 
-def _dump_sections(elf: Elf, output, name: str, regex) -> None:
-    if not name and not regex:
+def _dump_sections(elf: Elf, output, sections: Iterable[Pattern[str]]) -> None:
+    if not sections:
         output(elf.summary().encode())
         return
 
-    for section in name:
-        output(elf.dump_section(section))
-
-    for section_pattern in regex:
+    for section_pattern in sections:
         output(elf.dump_sections(section_pattern))
 
 
@@ -307,8 +358,11 @@
 
     section_parser = subparsers.add_parser('section')
     section_parser.set_defaults(handler=_dump_sections)
-    section_parser.add_argument('-n', '--name', default=[], action='append')
-    section_parser.add_argument('-r', '--regex', default=[], action='append')
+    section_parser.add_argument('sections',
+                                metavar='section_regex',
+                                nargs='*',
+                                type=re.compile,
+                                help='section name regular expression')
 
     address_parser = subparsers.add_parser('address')
     address_parser.set_defaults(handler=_read_addresses)