Browse Source

Move the region offset inside RegionFile

Nicola Corna 6 years ago
parent
commit
604e1bf812
1 changed files with 89 additions and 83 deletions
  1. 89 83
      me_cleaner.py

+ 89 - 83
me_cleaner.py

@@ -59,27 +59,35 @@ class RegionFile:
         self.region_end = region_end
 
     def read(self, n):
-        return self.f.read(n)
+        if f.tell() + n <= self.region_end:
+            return self.f.read(n)
+        else:
+            raise OutOfRegionException()
 
     def readinto(self, b):
-        return self.f.readinto(b)
+        if f.tell() + len(b) <= self.region_end:
+            return self.f.readinto(b)
+        else:
+            raise OutOfRegionException()
 
     def seek(self, offset):
-        return self.f.seek(offset)
+        if self.region_start + offset <= self.region_end:
+            return self.f.seek(self.region_start + offset)
+        else:
+            raise OutOfRegionException()
 
     def write_to(self, offset, data):
-        if offset >= self.region_start and \
-           offset + len(data) <= self.region_end:
-            self.f.seek(offset)
+        if self.region_start + offset + len(data) <= self.region_end:
+            self.f.seek(self.region_start + offset)
             return self.f.write(data)
         else:
             raise OutOfRegionException()
 
     def fill_range(self, start, end, fill):
-        if start >= self.region_start and end <= self.region_end:
+        if self.region_start + end <= self.region_end:
             if start < end:
                 block = fill * 4096
-                self.f.seek(start)
+                self.f.seek(self.region_start + start)
                 self.f.writelines(itertools.repeat(block,
                                                    (end - start) // 4096))
                 self.f.write(block[:(end - start) % 4096])
@@ -87,31 +95,32 @@ class RegionFile:
             raise OutOfRegionException()
 
     def move_range(self, offset_from, size, offset_to, fill):
-        if offset_from >= self.region_start and \
-           offset_from + size <= self.region_end and \
-           offset_to >= self.region_start and \
-           offset_to + size <= self.region_end:
+        if self.region_start + offset_from + size <= self.region_end and \
+           self.region_start + offset_to + size <= self.region_end:
             for i in range(0, size, 4096):
-                self.f.seek(offset_from + i, 0)
+                self.f.seek(self.region_start + offset_from + i, 0)
                 block = self.f.read(min(size - i, 4096))
-                self.f.seek(offset_from + i, 0)
+                self.f.seek(self.region_start + offset_from + i, 0)
                 self.f.write(fill * len(block))
-                self.f.seek(offset_to + i, 0)
+                self.f.seek(self.region_start + offset_to + i, 0)
                 self.f.write(block)
         else:
             raise OutOfRegionException()
 
     def save(self, filename, size):
-        self.f.seek(self.region_start)
-        copyf = open(filename, "w+b")
-        for i in range(0, size, 4096):
-            copyf.write(self.f.read(min(size - i, 4096)))
-        return copyf
+        if self.region_start + size <= self.region_end:
+            self.f.seek(self.region_start)
+            copyf = open(filename, "w+b")
+            for i in range(0, size, 4096):
+                copyf.write(self.f.read(min(size - i, 4096)))
+            return copyf
+        else:
+            raise OutOfRegionException()
 
 
-def get_chunks_offsets(llut, me_start):
+def get_chunks_offsets(llut):
     chunk_count = unpack("<I", llut[0x04:0x08])[0]
-    huffman_stream_end = sum(unpack("<II", llut[0x10:0x18])) + me_start
+    huffman_stream_end = sum(unpack("<II", llut[0x10:0x18]))
     nonzero_offsets = [huffman_stream_end]
     offsets = []
 
@@ -120,7 +129,7 @@ def get_chunks_offsets(llut, me_start):
         offset = 0
 
         if chunk[3] != 0x80:
-            offset = unpack("<I", chunk[0:3] + b"\x00")[0] + me_start
+            offset = unpack("<I", chunk[0:3] + b"\x00")[0]
 
         offsets.append([offset, 0])
         if offset != 0:
@@ -176,7 +185,7 @@ def remove_modules(f, mod_headers, ftpr_offset, me_end):
                     chunk_size = unpack("<I", llut[0x30:0x34])[0]
 
                     llut += f.read(chunk_count * 4)
-                    chunks_offsets = get_chunks_offsets(llut, me_start)
+                    chunks_offsets = get_chunks_offsets(llut)
                 else:
                     sys.exit("Huffman modules found, but LLUT is not present")
 
@@ -255,14 +264,13 @@ def print_check_partition_signature(f, offset):
                  "ME/TXE image valid?")
 
 
-def relocate_partition(f, me_start, me_end, partition_header_offset,
+def relocate_partition(f, me_end, partition_header_offset,
                        new_offset, mod_headers):
 
     f.seek(partition_header_offset)
     name = f.read(4).rstrip(b"\x00").decode("ascii")
     f.seek(partition_header_offset + 0x8)
     old_offset, partition_size = unpack("<II", f.read(0x8))
-    old_offset += me_start
 
     llut_start = 0
     for mod_header in mod_headers:
@@ -278,8 +286,7 @@ def relocate_partition(f, me_start, me_end, partition_header_offset,
         f.seek(llut_start + 0x9)
         lut_start_corr = unpack("<H", f.read(2))[0]
         new_offset = max(new_offset,
-                         lut_start_corr + me_start - llut_start - 0x40 +
-                         old_offset)
+                         lut_start_corr - llut_start - 0x40 + old_offset)
         new_offset = ((new_offset + 0x1f) // 0x20) * 0x20
 
     offset_diff = new_offset - old_offset
@@ -289,15 +296,14 @@ def relocate_partition(f, me_start, me_end, partition_header_offset,
 
     print(" Adjusting FPT entry...")
     f.write_to(partition_header_offset + 0x8,
-               pack("<I", new_offset - me_start))
+               pack("<I", new_offset))
 
     if mod_headers:
         if llut_start != 0:
             f.seek(llut_start)
             if f.read(4) == b"LLUT":
                 print(" Adjusting LUT start offset...")
-                lut_offset = llut_start + offset_diff + 0x40 - \
-                    lut_start_corr - me_start
+                lut_offset = llut_start + offset_diff + 0x40 - lut_start_corr
                 f.write_to(llut_start + 0x0c, pack("<I", lut_offset))
 
                 print(" Adjusting Huffman start offset...")
@@ -330,7 +336,7 @@ def relocate_partition(f, me_start, me_end, partition_header_offset,
     return new_offset
 
 
-def check_and_remove_modules(f, me_start, me_end, offset, min_offset,
+def check_and_remove_modules(f, me_end, offset, min_offset,
                              relocate, keep_modules):
 
     f.seek(offset + 0x20)
@@ -358,9 +364,7 @@ def check_and_remove_modules(f, me_start, me_end, offset, min_offset,
                                           offset, me_end)
 
             if args.relocate:
-                new_offset = relocate_partition(f, me_start, me_end,
-                                                me_start + 0x30,
-                                                min_offset + me_start,
+                new_offset = relocate_partition(f, me_end, 0x30, min_offset,
                                                 mod_headers)
                 end_addr += new_offset - offset
                 offset = new_offset
@@ -377,7 +381,7 @@ def check_and_remove_modules(f, me_start, me_end, offset, min_offset,
     return -1, offset
 
 
-def check_and_remove_modules_me11(f, me_start, me_end, partition_offset,
+def check_and_remove_modules_me11(f, me_end, partition_offset,
                                   partition_length, min_offset, relocate,
                                   keep_modules):
 
@@ -435,8 +439,7 @@ def check_and_remove_modules_me11(f, me_start, me_end, partition_offset,
                 end_data = max(end_data, end)
 
     if relocate:
-        new_offset = relocate_partition(f, me_start, me_end, me_start + 0x30,
-                                        min_offset + me_start, [])
+        new_offset = relocate_partition(f, me_end, 0x30, min_offset, [])
         end_data += new_offset - partition_offset
         partition_offset = new_offset
 
@@ -538,9 +541,10 @@ if __name__ == "__main__":
            args.soft_disable or args.soft_disable_only:
             sys.exit("-d, -D, -M, -S and -s require a full dump")
 
-        me_start = 0
         f.seek(0, 2)
+        me_start = 0
         me_end = f.tell()
+        mef = RegionFile(f, me_start, me_end + 1)
 
     elif magic == b"\x5a\xa5\xf0\x0f":
         print("Full image detected")
@@ -564,8 +568,10 @@ if __name__ == "__main__":
         if me_start >= me_end:
             sys.exit("The ME/TXE region in this image has been disabled")
 
-        f.seek(me_start + 0x10)
-        if f.read(4) != b"$FPT":
+        mef = RegionFile(f, me_start, me_end + 1)
+
+        mef.seek(0x10)
+        if mef.read(4) != b"$FPT":
             sys.exit("The ME/TXE region is corrupted or missing")
 
         print("The ME/TXE region goes from {:#x} to {:#x}"
@@ -575,14 +581,14 @@ if __name__ == "__main__":
 
     end_addr = me_end
 
-    print("Found FPT header at {:#x}".format(me_start + 0x10))
+    print("Found FPT header at {:#x}".format(mef.region_start + 0x10))
 
-    f.seek(me_start + 0x14)
-    entries = unpack("<I", f.read(4))[0]
+    mef.seek(0x14)
+    entries = unpack("<I", mef.read(4))[0]
     print("Found {} partition(s)".format(entries))
 
-    f.seek(me_start + 0x30)
-    partitions = f.read(entries * 0x20)
+    mef.seek(0x30)
+    partitions = mef.read(entries * 0x20)
 
     ftpr_header = b""
 
@@ -595,20 +601,19 @@ if __name__ == "__main__":
         sys.exit("FTPR header not found, this image doesn't seem to be valid")
 
     ftpr_offset, ftpr_length = unpack("<II", ftpr_header[0x08:0x10])
-    ftpr_offset += me_start
     print("Found FTPR header: FTPR partition spans from {:#x} to {:#x}"
           .format(ftpr_offset, ftpr_offset + ftpr_length))
 
-    f.seek(ftpr_offset)
-    if f.read(4) == b"$CPD":
+    mef.seek(ftpr_offset)
+    if mef.read(4) == b"$CPD":
         me11 = True
-        num_entries = unpack("<I", f.read(4))[0]
+        num_entries = unpack("<I", mef.read(4))[0]
 
-        f.seek(ftpr_offset + 0x10)
+        mef.seek(ftpr_offset + 0x10)
         ftpr_mn2_offset = -1
 
         for i in range(0, num_entries):
-            data = f.read(0x18)
+            data = mef.read(0x18)
             name = data[0x0:0xc].rstrip(b"\x00").decode("ascii")
             offset = unpack("<I", data[0xc:0xf] + b"\x00")[0]
 
@@ -617,24 +622,24 @@ if __name__ == "__main__":
                 break
 
         if ftpr_mn2_offset >= 0:
-            check_mn2_tag(f, ftpr_offset + ftpr_mn2_offset)
+            check_mn2_tag(mef, ftpr_offset + ftpr_mn2_offset)
             print("Found FTPR manifest at {:#x}"
                   .format(ftpr_offset + ftpr_mn2_offset))
         else:
             sys.exit("Can't find the manifest of the FTPR partition")
 
     else:
-        check_mn2_tag(f, ftpr_offset)
+        check_mn2_tag(mef, ftpr_offset)
         me11 = False
         ftpr_mn2_offset = 0
 
-    f.seek(ftpr_offset + ftpr_mn2_offset + 0x24)
-    version = unpack("<HHHH", f.read(0x08))
+    mef.seek(ftpr_offset + ftpr_mn2_offset + 0x24)
+    version = unpack("<HHHH", mef.read(0x08))
     print("ME/TXE firmware version {}"
           .format('.'.join(str(i) for i in version)))
 
-    f.seek(ftpr_offset + ftpr_mn2_offset + 0x80)
-    pubkey_md5 = hashlib.md5(f.read(0x104)).hexdigest()
+    mef.seek(ftpr_offset + ftpr_mn2_offset + 0x80)
+    pubkey_md5 = hashlib.md5(mef.read(0x104)).hexdigest()
 
     if pubkey_md5 in pubkeys_md5:
         variant, pubkey_versions = pubkeys_md5[pubkey_md5]
@@ -655,10 +660,10 @@ if __name__ == "__main__":
         shutil.copy(args.file, args.output)
         f = open(args.output, "r+b")
 
-    mef = RegionFile(f, me_start, me_end)
+    mef = RegionFile(f, me_start, me_end + 1)
 
     if me_start > 0:
-        fdf = RegionFile(f, fd_start, fd_end)
+        fdf = RegionFile(f, fd_start, fd_end + 1)
 
     if not args.check:
         if not args.soft_disable_only:
@@ -708,32 +713,31 @@ if __name__ == "__main__":
                             extra_part_end = max(extra_part_end, part_end)
                         print("NOT removed")
                     else:
-                        mef.fill_range(me_start + part_start,
-                                       me_start + part_end, b"\xff")
+                        mef.fill_range(part_start, part_end, b"\xff")
                         print("removed")
 
             print("Removing partition entries in FPT...")
-            mef.write_to(me_start + 0x30, unremovable_part_fpt)
-            mef.write_to(me_start + 0x14,
+            mef.write_to(0x30, unremovable_part_fpt)
+            mef.write_to(0x14,
                          pack("<I", len(unremovable_part_fpt) // 0x20))
 
-            mef.fill_range(me_start + 0x30 + len(unremovable_part_fpt),
-                           me_start + 0x30 + len(partitions), b"\xff")
+            mef.fill_range(0x30 + len(unremovable_part_fpt),
+                           0x30 + len(partitions), b"\xff")
 
             if (not blacklist and "EFFS" not in whitelist) or \
                "EFFS" in blacklist:
                 print("Removing EFFS presence flag...")
-                mef.seek(me_start + 0x24)
+                mef.seek(0x24)
                 flags = unpack("<I", mef.read(4))[0]
                 flags &= ~(0x00000001)
-                mef.write_to(me_start + 0x24, pack("<I", flags))
+                mef.write_to(0x24, pack("<I", flags))
 
             if me11:
-                mef.seek(me_start + 0x10)
+                mef.seek(0x10)
                 header = bytearray(mef.read(0x20))
                 header[0x0b] = 0x00
             else:
-                mef.seek(me_start)
+                mef.seek(0)
                 header = bytearray(mef.read(0x30))
                 header[0x1b] = 0x00
             checksum = (0x100 - sum(header) & 0xff) & 0xff
@@ -743,19 +747,19 @@ if __name__ == "__main__":
             # 0x30 bytes in ME < 11 or bytes 0x10:0x30 in ME >= 11 (except for
             # 0x1b, the checksum itself). In other words, the sum of those
             # bytes must be always 0x00.
-            mef.write_to(me_start + 0x1b, pack("B", checksum))
+            mef.write_to(0x1b, pack("B", checksum))
 
             print("Reading FTPR modules list...")
             if me11:
                 end_addr, ftpr_offset = \
-                    check_and_remove_modules_me11(mef, me_start, me_end,
+                    check_and_remove_modules_me11(mef, me_end,
                                                   ftpr_offset, ftpr_length,
                                                   min_ftpr_offset,
                                                   args.relocate,
                                                   args.keep_modules)
             else:
                 end_addr, ftpr_offset = \
-                    check_and_remove_modules(mef, me_start, me_end, ftpr_offset,
+                    check_and_remove_modules(mef, me_end, ftpr_offset,
                                              min_ftpr_offset, args.relocate,
                                              args.keep_modules)
 
@@ -765,11 +769,12 @@ if __name__ == "__main__":
                 end_addr += spared_blocks * 0x1000
 
                 print("The ME minimum size should be {0} bytes "
-                      "({0:#x} bytes)".format(end_addr - me_start))
+                      "({0:#x} bytes)".format(end_addr))
 
                 if me_start > 0:
                     print("The ME region can be reduced up to:\n"
-                          " {:08x}:{:08x} me".format(me_start, end_addr - 1))
+                          " {:08x}:{:08x} me"
+                          .format(me_start, me_start + end_addr - 1))
                 elif args.truncate:
                     print("Truncating file at {:#x}...".format(end_addr))
                     f.truncate(end_addr)
@@ -808,12 +813,14 @@ if __name__ == "__main__":
             if bios_start == me_end:
                 print("Modifying the regions of the extracted descriptor...")
                 print(" {:08x}:{:08x} me   --> {:08x}:{:08x} me"
-                      .format(me_start, me_end - 1, me_start, end_addr - 1))
+                      .format(me_start, me_end - 1,
+                              me_start, me_start + end_addr - 1))
                 print(" {:08x}:{:08x} bios --> {:08x}:{:08x} bios"
-                      .format(bios_start, bios_end - 1, end_addr, bios_end - 1))
+                      .format(bios_start, bios_end - 1,
+                              me_start + end_addr, bios_end - 1))
 
-                flreg1 = start_end_to_flreg(end_addr, bios_end)
-                flreg2 = start_end_to_flreg(me_start, end_addr)
+                flreg1 = start_end_to_flreg(me_start + end_addr, bios_end)
+                flreg2 = start_end_to_flreg(me_start, me_start + end_addr)
 
                 fdf_copy.seek(frba + 0x4)
                 fdf_copy.write(pack("<II", flreg1, flreg2))
@@ -833,7 +840,7 @@ if __name__ == "__main__":
         if args.truncate:
             print("Extracting and truncating the ME image to \"{}\"..."
                   .format(args.extract_me))
-            mef_copy = mef.save(args.extract_me, end_addr - me_start)
+            mef_copy = mef.save(args.extract_me, end_addr)
         else:
             print("Extracting the ME image to \"{}\"..."
                   .format(args.extract_me))
@@ -841,12 +848,11 @@ if __name__ == "__main__":
 
         print("Checking the FTPR RSA signature of the extracted ME image... ",
               end="")
-        print_check_partition_signature(mef_copy, ftpr_offset +
-                                        ftpr_mn2_offset - me_start)
+        print_check_partition_signature(mef_copy, ftpr_offset + ftpr_mn2_offset)
         mef_copy.close()
 
     print("Checking the FTPR RSA signature... ", end="")
-    print_check_partition_signature(f, ftpr_offset + ftpr_mn2_offset)
+    print_check_partition_signature(mef, ftpr_offset + ftpr_mn2_offset)
 
     f.close()