Parcourir la source

Don't create pending jit entries

Fabian il y a 3 ans
Parent
commit
7826474fed
5 fichiers modifiés avec 121 ajouts et 169 suppressions
  1. 2 2
      src/browser/starter.js
  2. 7 16
      src/cpu.js
  3. 107 147
      src/rust/jit.rs
  4. 1 0
      src/rust/state_flags.rs
  5. 4 4
      src/rust/wasmgen/wasm_builder.rs

+ 2 - 2
src/browser/starter.js

@@ -163,8 +163,8 @@ function V86Starter(options)
             dbg_trace();
         },
 
-        "codegen_finalize": (wasm_table_index, start, end, first_opcode, state_flags) => {
-            cpu.codegen_finalize(wasm_table_index, start, end, first_opcode, state_flags);
+        "codegen_finalize": (wasm_table_index, start, state_flags, ptr, len) => {
+            cpu.codegen_finalize(wasm_table_index, start, state_flags, ptr, len);
         },
         "jit_clear_func": (wasm_table_index) => cpu.jit_clear_func(wasm_table_index),
         "jit_clear_all_funcs": () => cpu.jit_clear_all_funcs(),

+ 7 - 16
src/cpu.js

@@ -201,15 +201,6 @@ CPU.prototype.clear_opstats = function()
     this.wm.exports["profiler_init"]();
 };
 
-CPU.prototype.wasmgen_get_module_code = function()
-{
-    const ptr = this.jit_get_op_ptr();
-    const len = this.jit_get_op_len();
-
-    const output_buffer_view = new Uint8Array(this.wm.instance.exports.memory.buffer, ptr, len);
-    return output_buffer_view;
-};
-
 CPU.prototype.create_jit_imports = function()
 {
     // Set this.jit_imports as generated WASM modules will expect
@@ -315,9 +306,6 @@ CPU.prototype.wasm_patch = function(wm)
     this.jit_dirty_cache = get_import("jit_dirty_cache");
     this.codegen_finalize_finished = get_import("codegen_finalize_finished");
 
-    this.jit_get_op_ptr = get_import("jit_get_op_ptr");
-    this.jit_get_op_len = get_import("jit_get_op_len");
-
     this.allocate_memory = get_import("allocate_memory");
 };
 
@@ -1364,10 +1352,11 @@ CPU.prototype.cycle = function()
 var seen_code = {};
 var seen_code_uncompiled = {};
 
-CPU.prototype.codegen_finalize = function(wasm_table_index, start, end, first_opcode, state_flags)
+CPU.prototype.codegen_finalize = function(wasm_table_index, start, state_flags, ptr, len)
 {
     dbg_assert(wasm_table_index >= 0 && wasm_table_index < WASM_TABLE_SIZE);
-    const code = this.wasmgen_get_module_code();
+
+    const code = new Uint8Array(this.wm.instance.exports.memory.buffer, ptr, len);
 
     if(DEBUG)
     {
@@ -1379,6 +1368,8 @@ CPU.prototype.codegen_finalize = function(wasm_table_index, start, end, first_op
 
             if(DUMP_ASSEMBLY)
             {
+                let end = 0;
+
                 if((start ^ end) & ~0xFFF)
                 {
                     dbg_log("truncated disassembly start=" + h(start >>> 0) + " end=" + h(end >>> 0));
@@ -1418,7 +1409,7 @@ CPU.prototype.codegen_finalize = function(wasm_table_index, start, end, first_op
         const result = new WebAssembly.Instance(module, { "e": jit_imports });
         const f = result.exports["f"];
 
-        this.codegen_finalize_finished(wasm_table_index, start, end, first_opcode, state_flags);
+        this.codegen_finalize_finished(wasm_table_index, start, state_flags);
 
         this.wm.imports["env"][WASM_EXPORT_TABLE_NAME].set(wasm_table_index + WASM_TABLE_OFFSET, f);
 
@@ -1433,7 +1424,7 @@ CPU.prototype.codegen_finalize = function(wasm_table_index, start, end, first_op
     const result = WebAssembly.instantiate(code, { "e": jit_imports }).then(result => {
         const f = result.instance.exports["f"];
 
-        this.codegen_finalize_finished(wasm_table_index, start, end, first_opcode, state_flags);
+        this.codegen_finalize_finished(wasm_table_index, start, state_flags);
 
         this.wm.imports["env"][WASM_EXPORT_TABLE_NAME].set(wasm_table_index + WASM_TABLE_OFFSET, f);
 

+ 107 - 147
src/rust/jit.rs

@@ -18,13 +18,15 @@ use util::SafeToU16;
 use wasmgen::wasm_builder::{WasmBuilder, WasmLocal};
 
 mod unsafe_jit {
+    use ::jit::CachedStateFlags;
+
     extern "C" {
         pub fn codegen_finalize(
             wasm_table_index: u16,
             phys_addr: u32,
-            end_addr: u32,
-            first_opcode: u32,
-            state_flags: u32,
+            state_flags: CachedStateFlags,
+            ptr: u32,
+            len: u32,
         );
         pub fn jit_clear_func(wasm_table_index: u16);
         pub fn jit_clear_all_funcs();
@@ -34,19 +36,11 @@ mod unsafe_jit {
 fn codegen_finalize(
     wasm_table_index: u16,
     phys_addr: u32,
-    end_addr: u32,
-    first_opcode: u32,
     state_flags: CachedStateFlags,
+    ptr: u32,
+    len: u32,
 ) {
-    unsafe {
-        unsafe_jit::codegen_finalize(
-            wasm_table_index,
-            phys_addr,
-            end_addr,
-            first_opcode,
-            state_flags.to_u32(),
-        )
-    }
+    unsafe { unsafe_jit::codegen_finalize(wasm_table_index, phys_addr, state_flags, ptr, len) }
 }
 
 pub fn jit_clear_func(wasm_table_index: u16) {
@@ -99,7 +93,11 @@ pub struct Entry {
     pub initial_state: u16,
     pub wasm_table_index: u16,
     pub state_flags: CachedStateFlags,
-    pub pending: bool,
+}
+
+enum PageState {
+    Compiling { basic_blocks: Vec<BasicBlock> },
+    CompilingWritten,
 }
 
 pub struct JitState {
@@ -108,11 +106,11 @@ pub struct JitState {
     // or a compressed bitmap (likely faster)
     hot_pages: [u32; HASH_PRIME as usize],
     wasm_table_index_free_list: Vec<u16>,
-    wasm_table_index_pending_free: Vec<u16>,
     entry_points: HashMap<Page, HashSet<u16>>,
     wasm_builder: WasmBuilder,
 
     cache: BTreeMap<u32, Entry>,
+    page_has_pending_code: HashMap<Page, (u16, PageState)>,
 }
 
 impl JitState {
@@ -123,10 +121,10 @@ impl JitState {
         JitState {
             hot_pages: [0; HASH_PRIME as usize],
             wasm_table_index_free_list: Vec::from_iter(wasm_table_indices),
-            wasm_table_index_pending_free: vec![],
             entry_points: HashMap::new(),
             wasm_builder: WasmBuilder::new(),
             cache: BTreeMap::new(),
+            page_has_pending_code: HashMap::new(),
         }
     }
 }
@@ -195,16 +193,13 @@ pub fn jit_find_cache_entry(phys_address: u32, state_flags: CachedStateFlags) ->
 
     match ctx.cache.get(&phys_address) {
         Some(entry) => {
-            if entry.state_flags == state_flags && !entry.pending {
+            if entry.state_flags == state_flags {
                 return CachedCode {
                     wasm_table_index: entry.wasm_table_index,
                     initial_state: entry.initial_state,
                 };
             }
             else {
-                if entry.pending {
-                    profiler::stat_increment(stat::RUN_INTERPRETED_PENDING);
-                }
                 if entry.state_flags != state_flags {
                     profiler::stat_increment(stat::RUN_INTERPRETED_DIFFERENT_STATE);
                 }
@@ -230,10 +225,7 @@ pub fn jit_find_cache_entry_in_page(
 
     match ctx.cache.get(&phys_address) {
         Some(entry) => {
-            if entry.state_flags == state_flags
-                && !entry.pending
-                && entry.wasm_table_index == wasm_table_index
-            {
+            if entry.state_flags == state_flags && entry.wasm_table_index == wasm_table_index {
                 return entry.initial_state as i32;
             }
         },
@@ -542,7 +534,7 @@ fn jit_analyze_and_generate(
     cs_offset: u32,
     state_flags: CachedStateFlags,
 ) {
-    if jit_page_has_pending_code(ctx, page) {
+    if ctx.page_has_pending_code.contains_key(&page) {
         return;
     }
 
@@ -572,10 +564,7 @@ fn jit_analyze_and_generate(
         //}
 
         if ctx.wasm_table_index_free_list.is_empty() {
-            dbg_log!(
-                "wasm_table_index_free_list empty ({} pending_free), clearing cache",
-                ctx.wasm_table_index_pending_free.len(),
-            );
+            dbg_log!("wasm_table_index_free_list empty, clearing cache",);
 
             // When no free slots are available, delete all cached modules. We could increase the
             // size of the table, but this way the initial size acts as an upper bound for the
@@ -586,8 +575,7 @@ fn jit_analyze_and_generate(
             profiler::stat_increment(stat::INVALIDATE_ALL_MODULES_NO_FREE_WASM_INDICES);
 
             dbg_log!(
-                "after jit_clear_cache: {} pending_free {} free",
-                ctx.wasm_table_index_pending_free.len(),
+                "after jit_clear_cache: {} free",
                 ctx.wasm_table_index_free_list.len(),
             );
 
@@ -612,81 +600,30 @@ fn jit_analyze_and_generate(
             state_flags,
         );
 
-        // create entries for each basic block that is marked as an entry point
-        let mut entry_point_count = 0;
-
-        let mut check_for_unused_wasm_table_index = HashSet::new();
-        let mut check_for_unused_wasm_table_index_pending = HashSet::new();
-
-        for (i, block) in basic_blocks.iter().enumerate() {
-            profiler::stat_increment(stat::COMPILE_BASIC_BLOCK);
-
-            if block.is_entry_block && block.addr != block.end_addr {
-                dbg_assert!(block.addr != 0);
-
-                let initial_state = i.safe_to_u16();
-
-                let entry = Entry {
-                    wasm_table_index,
-                    initial_state,
-                    state_flags,
-                    pending: true,
-
-                    #[cfg(any(debug_assertions, feature = "profiler"))]
-                    len: block.end_addr - block.addr,
-
-                    #[cfg(debug_assertions)]
-                    opcode: memory::read32s(block.addr) as u32,
-                };
-
-                let old_entry = ctx.cache.insert(block.addr, entry);
-
-                if let Some(old_entry) = old_entry {
-                    if old_entry.pending {
-                        check_for_unused_wasm_table_index_pending
-                            .insert(old_entry.wasm_table_index);
-                    }
-                    else {
-                        check_for_unused_wasm_table_index.insert(old_entry.wasm_table_index);
-                    }
-                }
-
-                entry_point_count += 1;
-                profiler::stat_increment(stat::COMPILE_ENTRY_POINT);
-            }
-        }
-
-        for (_, entry) in ctx.cache.range(page.address_range()) {
-            check_for_unused_wasm_table_index.remove(&entry.wasm_table_index);
-            check_for_unused_wasm_table_index_pending.remove(&entry.wasm_table_index);
-        }
-
-        for index in check_for_unused_wasm_table_index {
-            free_wasm_table_index(ctx, index);
-        }
-        for index in check_for_unused_wasm_table_index_pending {
-            ctx.wasm_table_index_pending_free.push(index);
-        }
-
-        profiler::stat_increment_by(stat::COMPILE_WASM_TOTAL_BYTES, jit_get_op_len() as u64);
-
-        dbg_assert!(entry_point_count > 0);
+        profiler::stat_increment_by(
+            stat::COMPILE_WASM_TOTAL_BYTES,
+            ctx.wasm_builder.get_output_len() as u64,
+        );
 
         cpu::tlb_set_has_code(page, true);
 
         cpu::check_tlb_invariants();
 
-        let end_addr = 0;
-        let first_opcode = 0;
+        let previous_state = ctx.page_has_pending_code.insert(
+            page,
+            (wasm_table_index, PageState::Compiling { basic_blocks }),
+        );
+        dbg_assert!(previous_state.is_none());
+
         let phys_addr = page.to_address();
 
         // will call codegen_finalize_finished asynchronously when finished
         codegen_finalize(
             wasm_table_index,
             phys_addr,
-            end_addr,
-            first_opcode,
             state_flags,
+            ctx.wasm_builder.get_output_ptr() as u32,
+            ctx.wasm_builder.get_output_len(),
         );
 
         profiler::stat_increment(stat::COMPILE_SUCCESS);
@@ -701,34 +638,74 @@ fn jit_analyze_and_generate(
 pub fn codegen_finalize_finished(
     wasm_table_index: u16,
     phys_addr: u32,
-    _end_addr: u32,
-    _first_opcode: u32,
-    _state_flags: CachedStateFlags,
+    state_flags: CachedStateFlags,
 ) {
     let ctx = get_jit_state();
 
     dbg_assert!(wasm_table_index != 0);
 
-    match ctx
-        .wasm_table_index_pending_free
-        .iter()
-        .position(|i| *i == wasm_table_index)
-    {
-        Some(i) => {
-            ctx.wasm_table_index_pending_free.swap_remove(i);
+    let page = Page::page_of(phys_addr);
+
+    let basic_blocks = match ctx.page_has_pending_code.remove(&page) {
+        None => {
+            dbg_assert!(false);
+            return;
+        },
+        Some((in_progress_wasm_table_index, PageState::CompilingWritten)) => {
+            dbg_assert!(wasm_table_index == in_progress_wasm_table_index);
             free_wasm_table_index(ctx, wasm_table_index);
+            return;
         },
-        None => {
-            let page = Page::page_of(phys_addr);
+        Some((in_progress_wasm_table_index, PageState::Compiling { basic_blocks })) => {
+            dbg_assert!(wasm_table_index == in_progress_wasm_table_index);
+            basic_blocks
+        },
+    };
 
-            for (_phys_addr, entry) in ctx.cache.range_mut(page.address_range()) {
-                if entry.wasm_table_index == wasm_table_index {
-                    dbg_assert!(entry.pending);
-                    //dbg_log!("mark entry at {:x} not pending", phys_addr);
-                    entry.pending = false;
-                }
+    // create entries for each basic block that is marked as an entry point
+    let mut entry_point_count = 0;
+
+    let mut check_for_unused_wasm_table_index = HashSet::new();
+
+    for (i, block) in basic_blocks.iter().enumerate() {
+        profiler::stat_increment(stat::COMPILE_BASIC_BLOCK);
+
+        if block.is_entry_block && block.addr != block.end_addr {
+            dbg_assert!(block.addr != 0);
+
+            let initial_state = i.safe_to_u16();
+
+            let entry = Entry {
+                wasm_table_index,
+                initial_state,
+                state_flags,
+
+                #[cfg(any(debug_assertions, feature = "profiler"))]
+                len: block.end_addr - block.addr,
+
+                #[cfg(debug_assertions)]
+                opcode: memory::read32s(block.addr) as u32,
+            };
+
+            let old_entry = ctx.cache.insert(block.addr, entry);
+
+            if let Some(old_entry) = old_entry {
+                check_for_unused_wasm_table_index.insert(old_entry.wasm_table_index);
             }
-        },
+
+            entry_point_count += 1;
+            profiler::stat_increment(stat::COMPILE_ENTRY_POINT);
+        }
+    }
+
+    dbg_assert!(entry_point_count > 0);
+
+    for (_, entry) in ctx.cache.range(page.address_range()) {
+        check_for_unused_wasm_table_index.remove(&entry.wasm_table_index);
+    }
+
+    for index in check_for_unused_wasm_table_index {
+        free_wasm_table_index(ctx, index);
     }
 }
 
@@ -1093,30 +1070,17 @@ pub fn jit_dirty_page(ctx: &mut JitState, page: Page) {
         .collect();
 
     let mut index_to_free = HashSet::new();
-    let mut index_to_pending_free = HashSet::new();
 
     for phys_addr in entries {
         let entry = ctx.cache.remove(&phys_addr).unwrap();
         did_have_code = true;
-
-        if entry.pending {
-            dbg_assert!(!index_to_free.contains(&entry.wasm_table_index));
-            index_to_pending_free.insert(entry.wasm_table_index);
-        }
-        else {
-            dbg_assert!(!index_to_pending_free.contains(&entry.wasm_table_index));
-            index_to_free.insert(entry.wasm_table_index);
-        }
+        index_to_free.insert(entry.wasm_table_index);
     }
 
     for index in index_to_free {
         free_wasm_table_index(ctx, index)
     }
 
-    for index in index_to_pending_free {
-        ctx.wasm_table_index_pending_free.push(index);
-    }
-
     match ctx.entry_points.remove(&page) {
         None => {},
         Some(_entry_points) => {
@@ -1127,6 +1091,19 @@ pub fn jit_dirty_page(ctx: &mut JitState, page: Page) {
         },
     }
 
+    match ctx.page_has_pending_code.get(&page) {
+        None => {},
+        Some((_, PageState::CompilingWritten)) => {},
+        Some((wasm_table_index, PageState::Compiling { .. })) => {
+            let wasm_table_index = *wasm_table_index;
+            did_have_code = true;
+            ctx.page_has_pending_code
+                .insert(page, (wasm_table_index, PageState::CompilingWritten));
+        },
+    }
+
+    dbg_assert!(!jit_page_has_code(page));
+
     if did_have_code {
         cpu::tlb_set_has_code(page, false);
     }
@@ -1178,23 +1155,11 @@ pub fn jit_page_has_code(page: Page) -> bool {
     let ctx = get_jit_state();
     let mut entries = ctx.cache.range(page.address_range());
     // Does the page have compiled code
-    //jit_cache_array::get_page_index(page) != None ||
     entries.next().is_some() ||
     // Or are there any entry points that need to be removed on write to the page
     // (this function is used to mark the has_code bit in the tlb to optimise away calls jit_dirty_page)
-    ctx.entry_points.contains_key(&page)
-}
-
-pub fn jit_page_has_pending_code(ctx: &JitState, page: Page) -> bool {
-    let entries = ctx.cache.range(page.address_range());
-
-    for (_phys_addr, entry) in entries {
-        if entry.pending {
-            return true;
-        }
-    }
-
-    return false;
+    ctx.entry_points.contains_key(&page) ||
+    match ctx.page_has_pending_code.get(&page) { Some(&(_, PageState::Compiling { .. })) => true, _ => false }
 }
 
 #[no_mangle]
@@ -1226,11 +1191,6 @@ pub fn jit_get_wasm_table_index_free_list_count() -> u32 {
     }
 }
 
-#[no_mangle]
-pub fn jit_get_op_len() -> u32 { get_jit_state().wasm_builder.get_op_len() }
-#[no_mangle]
-pub fn jit_get_op_ptr() -> *const u8 { get_jit_state().wasm_builder.get_op_ptr() }
-
 #[cfg(feature = "profiler")]
 pub fn check_missed_entry_points(phys_address: u32, state_flags: CachedStateFlags) {
     let page = Page::page_of(phys_address);

+ 1 - 0
src/rust/state_flags.rs

@@ -1,4 +1,5 @@
 #[derive(Copy, Clone, PartialEq, Eq)]
+#[repr(transparent)]
 pub struct CachedStateFlags(u8);
 
 impl CachedStateFlags {

+ 4 - 4
src/rust/wasmgen/wasm_builder.rs

@@ -490,9 +490,9 @@ impl WasmBuilder {
         }
     }
 
-    pub fn get_op_ptr(&self) -> *const u8 { self.output.as_ptr() }
+    pub fn get_output_ptr(&self) -> *const u8 { self.output.as_ptr() }
 
-    pub fn get_op_len(&self) -> u32 { self.output.len() as u32 }
+    pub fn get_output_len(&self) -> u32 { self.output.len() as u32 }
 
     #[must_use = "local allocated but not used"]
     fn alloc_local(&mut self) -> WasmLocal {
@@ -938,8 +938,8 @@ mod tests {
 
         m.finish();
 
-        let op_ptr = m.get_op_ptr();
-        let op_len = m.get_op_len();
+        let op_ptr = m.get_output_ptr();
+        let op_len = m.get_output_len();
         dbg_log!("op_ptr: {:?}, op_len: {:?}", op_ptr, op_len);
 
         let mut f = File::create("build/dummy_output.wasm").expect("creating dummy_output.wasm");