Browse Source

Add region checks

Each write has been replaced with a write_to, which checks that the
final write location is inside the ME region before writing.
Nicola Corna 7 years ago
parent
commit
b6e58f41f3
1 changed files with 88 additions and 57 deletions
  1. 88 57
      me_cleaner.py

+ 88 - 57
me_cleaner.py

@@ -26,6 +26,58 @@ new_ftpr_offset = 0x1000
 unremovable_modules = ("BUP", "ROMP")
 
 
+class OutOfRegionException(Exception):
+    pass
+
+
+class regionFile:
+    def __init__(self, f, region_start, region_end):
+        self.f = f
+        self.region_start = region_start
+        self.region_end = region_end
+
+    def read(self, n):
+        return self.f.read(n)
+
+    def readinto(self, b):
+        return self.f.readinto(b)
+
+    def seek(self, offset):
+        return self.f.seek(offset)
+
+    def write_to(self, offset, data):
+        if offset >= self.region_start and \
+           offset + len(data) <= self.region_end:
+            self.f.seek(offset)
+            return self.f.write(data)
+        else:
+            raise OutOfRegionException()
+
+    def fill_range(self, start, end, fill):
+        if start >= self.region_start and end <= self.region_end:
+            block = fill * 4096
+            self.f.seek(start)
+            self.f.writelines(itertools.repeat(block, (end - start) // 4096))
+            self.f.write(block[:(end - start) % 4096])
+        else:
+            raise OutOfRegionException()
+
+    def move_range(self, offset_from, size, offset_to, fill):
+        if offset_from >= self.region_start and \
+           offset_from + size <= self.region_end and \
+           offset_to >= self.region_start and \
+           offset_to + size <= self.region_end:
+            for i in range(0, size, 4096):
+                self.f.seek(offset_from + i, 0)
+                block = self.f.read(4096 if size - i >= 4096 else size - i)
+                self.f.seek(offset_from + i, 0)
+                self.f.write(fill * len(block))
+                self.f.seek(offset_to + i, 0)
+                self.f.write(block)
+        else:
+            raise OutOfRegionException()
+
+
 def get_chunks_offsets(llut, me_start):
     chunk_count = unpack("<I", llut[0x04:0x08])[0]
     huffman_stream_end = sum(unpack("<II", llut[0x10:0x18])) + me_start
@@ -52,13 +104,6 @@ def get_chunks_offsets(llut, me_start):
     return offsets
 
 
-def fill_range(f, start, end, fill):
-    block = fill * 4096
-    f.seek(start)
-    f.writelines(itertools.repeat(block, (end - start) // 4096))
-    f.write(block[:(end - start) % 4096])
-
-
 def remove_modules(f, mod_headers, ftpr_offset):
     comp_str = ("Uncomp.", "Huffman", "LZMA")
     unremovable_huff_chunks = []
@@ -84,7 +129,7 @@ def remove_modules(f, mod_headers, ftpr_offset):
                 end_addr = max(end_addr, offset + size)
                 print("NOT removed, essential")
             else:
-                fill_range(f, offset, offset + size, b"\xff")
+                f.fill_range(offset, offset + size, b"\xff")
                 print("removed")
 
         elif comp_type == 0x01:
@@ -133,7 +178,7 @@ def remove_modules(f, mod_headers, ftpr_offset):
 
         for removable_chunk in removable_huff_chunks:
             if removable_chunk[1] > removable_chunk[0]:
-                fill_range(f, removable_chunk[0], removable_chunk[1], b"\xff")
+                f.fill_range(removable_chunk[0], removable_chunk[1], b"\xff")
 
         end_addr = max(end_addr,
                        max(unremovable_huff_chunks, key=lambda x: x[1])[1])
@@ -161,16 +206,6 @@ def check_partition_signature(f, offset):
     return "{:#x}".format(decrypted_sig).endswith(sha256.hexdigest())   # FIXME
 
 
-def move_range(f, offset_from, size, offset_to, fill):
-    for i in range(0, size, 4096):
-        f.seek(offset_from + i, 0)
-        block = f.read(4096 if size - i >= 4096 else size - i)
-        f.seek(offset_from + i, 0)
-        f.write(fill * 4096 if size - i >= 4096 else fill * (size - i))
-        f.seek(offset_to + i, 0)
-        f.write(block)
-
-
 def relocate_partition(f, me_start, partition_header_offset, new_offset,
                        mod_headers):
     f.seek(partition_header_offset)
@@ -183,8 +218,8 @@ def relocate_partition(f, me_start, partition_header_offset, new_offset,
           .format(name, new_offset, new_offset + partition_size))
 
     print(" Adjusting FPT entry...")
-    f.seek(partition_header_offset + 0x8)
-    f.write(pack("<I", new_offset - me_start))
+    f.write_to(partition_header_offset + 0x8,
+               pack("<I", new_offset - me_start))
 
     llut_start = 0
     for mod_header in mod_headers:
@@ -193,24 +228,24 @@ def relocate_partition(f, me_start, partition_header_offset, new_offset,
             break
 
     if llut_start != 0:
-        f.seek(llut_start, 0)
+        f.seek(llut_start)
         if f.read(4) == b"LLUT":
             print(" Adjusting LUT start offset...")
-            f.seek(llut_start + 0x0c, 0)
+            f.seek(llut_start + 0x0c)
             old_lut_offset = unpack("<I", f.read(4))[0]
-            f.seek(llut_start + 0x0c, 0)
-            f.write(pack("<I", old_lut_offset + offset_diff))
+            f.write_to(llut_start + 0x0c,
+                       pack("<I", old_lut_offset + offset_diff))
 
             print(" Adjusting Huffman start offset...")
-            f.seek(llut_start + 0x14, 0)
+            f.seek(llut_start + 0x14)
             old_huff_offset = unpack("<I", f.read(4))[0]
-            f.seek(llut_start + 0x14, 0)
-            f.write(pack("<I", old_huff_offset + offset_diff))
+            f.write_to(llut_start + 0x14,
+                       pack("<I", old_huff_offset + offset_diff))
 
             print(" Adjusting chunks offsets...")
-            f.seek(llut_start + 0x4, 0)
+            f.seek(llut_start + 0x4)
             chunk_count = unpack("<I", f.read(4))[0]
-            f.seek(llut_start + 0x40, 0)
+            f.seek(llut_start + 0x40)
             chunks = bytearray(chunk_count * 4)
             f.readinto(chunks)
             for i in range(0, chunk_count * 4, 4):
@@ -218,15 +253,14 @@ def relocate_partition(f, me_start, partition_header_offset, new_offset,
                     chunks[i:i + 3] = \
                         pack("<I", unpack("<I", chunks[i:i + 3] +
                              b"\x00")[0] + offset_diff)[0:3]
-            f.seek(llut_start + 0x40, 0)
-            f.write(chunks)
+            f.write_to(llut_start + 0x40, chunks)
         else:
             sys.exit("Huffman modules present but no LLUT found!")
     else:
         print(" No Huffman modules found")
 
     print(" Moving data...")
-    move_range(f, old_offset, partition_size, new_offset, b"\xff")
+    f.move_range(old_offset, partition_size, new_offset, b"\xff")
 
 
 if __name__ == "__main__":
@@ -242,34 +276,34 @@ if __name__ == "__main__":
                         action="store_true")
     args = parser.parse_args()
 
-    with open(args.file, "rb" if args.check else "r+b") as f:
-        f.seek(0x10)
-        magic = f.read(4)
+    with open(args.file, "rb" if args.check else "r+b") as fu:
+        fu.seek(0x10)
+        magic = fu.read(4)
 
         if magic == b"$FPT":
             print("ME/TXE image detected")
             me_start = 0
-            f.seek(0, 2)
-            me_end = f.tell()
+            fu.seek(0, 2)
+            me_end = fu.tell()
 
         elif magic == b"\x5a\xa5\xf0\x0f":
             print("Full image detected")
-            f.seek(0x14)
-            flmap0 = unpack("<I", f.read(4))[0]
+            fu.seek(0x14)
+            flmap0 = unpack("<I", fu.read(4))[0]
             nr = flmap0 >> 24 & 0x7
             frba = flmap0 >> 12 & 0xff0
             if nr >= 2:
-                f.seek(frba + 0x8)
-                flreg2 = unpack("<I", f.read(4))[0]
+                fu.seek(frba + 0x8)
+                flreg2 = unpack("<I", fu.read(4))[0]
                 me_start = (flreg2 & 0x1fff) << 12
-                me_end = flreg2 >> 4 & 0x1fff000 | 0xfff
+                me_end = flreg2 >> 4 & 0x1fff000 | 0xfff + 1
 
                 if me_start >= me_end:
                     sys.exit("The ME/TXE region in this image has been "
                              "disabled")
 
-                f.seek(me_start + 0x10)
-                if f.read(4) != b"$FPT":
+                fu.seek(me_start + 0x10)
+                if fu.read(4) != b"$FPT":
                     sys.exit("The ME/TXE region is corrupted or missing")
 
                 print("The ME/TXE region goes from {:#x} to {:#x}"
@@ -282,6 +316,8 @@ if __name__ == "__main__":
 
         print("Found FPT header at {:#x}".format(me_start + 0x10))
 
+        f = regionFile(fu, me_start, me_end)
+
         f.seek(me_start + 0x14)
         entries = unpack("<I", f.read(4))[0]
         print("Found {} partition(s)".format(entries))
@@ -324,24 +360,20 @@ if __name__ == "__main__":
 
         if not args.check:
             print("Removing extra partitions...")
-
-            fill_range(f, me_start + 0x30, ftpr_offset, b"\xff")
-            fill_range(f, ftpr_offset + ftpr_lenght, me_end, b"\xff")
+            f.fill_range(me_start + 0x30, ftpr_offset, b"\xff")
+            f.fill_range(ftpr_offset + ftpr_lenght, me_end, b"\xff")
 
             print("Removing extra partition entries in FPT...")
-            f.seek(me_start + 0x30)
-            f.write(ftpr_header)
-            f.seek(me_start + 0x14)
-            f.write(pack("<I", 1))
+            f.write_to(me_start + 0x30, ftpr_header)
+            f.write_to(me_start + 0x14, pack("<I", 1))
 
             print("Removing EFFS presence flag...")
             f.seek(me_start + 0x24)
             flags = unpack("<I", f.read(4))[0]
             flags &= ~(0x00000001)
-            f.seek(me_start + 0x24)
-            f.write(pack("<I", flags))
+            f.write_to(me_start + 0x24, pack("<I", flags))
 
-            f.seek(me_start, 0)
+            f.seek(me_start)
             header = bytearray(f.read(0x30))
             checksum = (0x100 - (sum(header) - header[0x1b]) & 0xff) & 0xff
 
@@ -349,8 +381,7 @@ if __name__ == "__main__":
             # The checksum is just the two's complement of the sum of the
             # first 0x30 bytes (except for 0x1b, the checksum itself). In other
             # words, the sum of the first 0x30 bytes must be always 0x00.
-            f.seek(me_start + 0x1b)
-            f.write(pack("B", checksum))
+            f.write_to(me_start + 0x1b, pack("B", checksum))
 
             if not me11:
                 print("Reading FTPR modules list...")