From 904ba7ceaf4bc29cad00bfb23380eb3404148609 Mon Sep 17 00:00:00 2001 From: Marius Wachtler Date: Sat, 14 Mar 2026 17:19:20 +0100 Subject: [PATCH 1/2] JIT: port the x86_64 Linux stencil JIT to DynASM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the copy-and-patch relocation engine with a DynASM-based pipeline. Instead of manually copying pre-compiled stencil blobs and patching GOT entries / trampolines at runtime, Clang-generated assembly is converted at build time into DynASM .dasc source, which is then compiled into a C header (jit_stencils.h). At runtime the DynASM assembler encodes native x86-64 directly, resolving all labels, jumps, and data references in a single pass. Key changes: Build pipeline (Tools/jit/): - _asm_to_dasc.py: New peephole optimizer that converts Clang AT&T asm to DynASM Intel-syntax .dasc. Uses typed operand classes (Reg, Mem, Imm) with Python 3.10+ match/case for pattern matching. Includes 15+ optimization patterns (immediate narrowing, test-self elimination, indexed memory folding, ALU immediate folding, redundant stack reload elimination, dead label removal, etc.). - _dasc_writer.py: Generates jit_stencils.h with DynASM preamble, emit helpers (emit_mov_imm, emit_call_ext, emit_cmp_reg_imm, emit_test/and/or/xor_reg_imm), and per-stencil emit functions. - _targets.py: Reworked to drive the DynASM pipeline — compiles stencils, converts asm, generates .dasc, runs the DynASM preprocessor, and produces the final header. - _stencils.py: Adds COLD_CODE HoleValue for hot/cold section splitting. - _optimizers.py: Extended with stencil frame-size tracking and frame-group merging infrastructure. - build.py: Adds --peephole-stats flag for optimization statistics. - test_peephole.py: unit tests covering peephole patterns and the line classification infrastructure. - Lib/test/test_jit_peephole.py: Hooks peephole tests into make test. 
Runtime (Python/jit.c): - Complete rewrite of _PyJIT_Compile: uses DynASM dasm_init / dasm_setup / per-stencil emit / dasm_link / dasm_encode instead of memcpy+patch. - Hot/cold code splitting: cold (error) paths are placed in a separate DynASM section after the hot code, improving i-cache locality. - Frame merging: stencils share a single prologue/epilogue, eliminating redundant rsp adjustments. - SET_IP delta encoding: incremental IP updates avoid redundant full address loads. - Hint-based mmap: jit_alloc() places JIT code near the CPython text segment for short (±2 GB) RIP-relative calls and LEAs. - jit_shrink(): releases unused pages at the end of each compiled trace. - emit_call_ext: emits direct RIP-relative call when target is within ±2 GB, otherwise falls back to indirect call through register. - emit_mov_imm: picks the shortest encoding (xor/mov32/mov64/lea rip) based on the runtime value. Freelist inlining (Tools/jit/jit.h + template.c): - Macro overrides redirect float/int allocation and deallocation to JIT-inlined versions that directly access the thread-state freelists, avoiding function call overhead for the most common object types. - _PyJIT_FloatFromDouble / _PyJIT_FloatDealloc: inline float freelist. - _PyJIT_LongDealloc / _PyJIT_FastDealloc: inline int/generic dealloc. - _PyJIT_CompactLong_{Add,Subtract,Multiply}: inline compact long ops. - PyStackRef_CLOSE / Py_DECREF overrides use the fast dealloc path. LuaJIT submodule: - Added as Tools/jit/LuaJIT for the DynASM assembler (dynasm/ only used at build time; no LuaJIT runtime code is linked). This is an experimental port, currently tested on x86_64 Linux only. The approach is a hybrid between Pyston's fully hand-written DynASM JIT (https://github.com/pyston/pyston/blob/pyston_main/Python/aot_ceval_jit.c) and CPython's Clang-generated stencils: Clang produces the stencil assembly, and DynASM handles encoding and relocation at runtime. 
--- .gitmodules | 3 + Lib/test/test_jit_peephole.py | 33 + Makefile.pre.in | 5 +- Python/jit.c | 798 ++++-------- Python/optimizer.c | 26 +- Tools/jit/LuaJIT | 1 + Tools/jit/_asm_to_dasc.py | 2093 +++++++++++++++++++++++++++++++ Tools/jit/_asm_to_dasc_amd64.py | 1464 +++++++++++++++++++++ Tools/jit/_dasc_writer.py | 448 +++++++ Tools/jit/_optimizers.py | 533 +++++--- Tools/jit/_schema.py | 1 + Tools/jit/_stencils.py | 28 +- Tools/jit/_targets.py | 269 +++- Tools/jit/_writer.py | 2 +- Tools/jit/build.py | 7 + Tools/jit/jit_fold_pass.cpp | 682 ++++++++++ Tools/jit/template.c | 412 +++++- Tools/jit/test_optimizers.py | 86 ++ Tools/jit/test_peephole.py | 687 ++++++++++ 19 files changed, 6797 insertions(+), 781 deletions(-) create mode 100644 .gitmodules create mode 100644 Lib/test/test_jit_peephole.py create mode 160000 Tools/jit/LuaJIT create mode 100644 Tools/jit/_asm_to_dasc.py create mode 100644 Tools/jit/_asm_to_dasc_amd64.py create mode 100644 Tools/jit/_dasc_writer.py create mode 100644 Tools/jit/jit_fold_pass.cpp create mode 100644 Tools/jit/test_optimizers.py create mode 100644 Tools/jit/test_peephole.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000000000..671c95aead403d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "Tools/jit/LuaJIT"] + path = Tools/jit/LuaJIT + url = https://github.com/LuaJIT/LuaJIT.git diff --git a/Lib/test/test_jit_peephole.py b/Lib/test/test_jit_peephole.py new file mode 100644 index 00000000000000..bc0ee72d84b6c8 --- /dev/null +++ b/Lib/test/test_jit_peephole.py @@ -0,0 +1,33 @@ +"""Wrapper to run the JIT peephole optimizer tests via 'make test'. + +The actual tests live in Tools/jit/test_peephole.py. This module +adds Tools/jit to sys.path and imports the test cases so they are +discovered by the standard test runner. +""" + +import os +import sys +import unittest + +# Tools/jit is not on the default path — add it so test_peephole can +# import _asm_to_dasc. 
+_jit_tools_dir = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), + "Tools", "jit", +) + +# Skip entirely if Tools/jit doesn't exist (e.g. minimal install). +if not os.path.isfile(os.path.join(_jit_tools_dir, "test_peephole.py")): + raise unittest.SkipTest("Tools/jit/test_peephole.py not found") + +_saved_path = sys.path[:] +try: + if _jit_tools_dir not in sys.path: + sys.path.insert(0, _jit_tools_dir) + # Import all test classes from the real test module. + from test_peephole import * # noqa: F401,F403 +finally: + sys.path[:] = _saved_path + +if __name__ == "__main__": + unittest.main() diff --git a/Makefile.pre.in b/Makefile.pre.in index f4119abf324fca..1285a177768e3d 100644 --- a/Makefile.pre.in +++ b/Makefile.pre.in @@ -3173,6 +3173,7 @@ Python/emscripten_trampoline_wasm.c: Python/emscripten_trampoline_inner.wasm JIT_DEPS = \ $(srcdir)/Tools/jit/*.c \ $(srcdir)/Tools/jit/*.py \ + $(srcdir)/Tools/jit/LuaJIT/dynasm \ $(srcdir)/Python/executor_cases.c.h \ pyconfig.h @@ -3180,7 +3181,7 @@ jit_stencils.h @JIT_STENCILS_H@: $(JIT_DEPS) @REGEN_JIT_COMMAND@ Python/jit.o: $(srcdir)/Python/jit.c @JIT_STENCILS_H@ - $(CC) -c $(PY_CORE_CFLAGS) -o $@ $< + $(CC) -c $(PY_CORE_CFLAGS) -I$(srcdir)/Tools/jit/LuaJIT/dynasm -o $@ $< .PHONY: regen-jit regen-jit: @@ -3305,7 +3306,7 @@ clean-profile: clean-retain-profile clean-bolt # gh-141808: The JIT stencils are deliberately kept in clean-profile .PHONY: clean-jit-stencils clean-jit-stencils: - -rm -f jit_stencils*.h + -rm -f jit_stencils*.h jit_stencils*.dasc .PHONY: clean clean: clean-profile clean-jit-stencils diff --git a/Python/jit.c b/Python/jit.c index 3e0a0aa8bfcc81..31b95755757b5a 100644 --- a/Python/jit.c +++ b/Python/jit.c @@ -30,6 +30,7 @@ #include "pycore_unicodeobject.h" #include "pycore_jit.h" +#include "pycore_uop_metadata.h" // Memory management stuff: //////////////////////////////////////////////////// @@ -103,13 +104,15 @@ _PyJIT_AddressInJitCode(PyInterpreterState 
*interp, uintptr_t addr) return 0; } +// Next mmap hint address for placing JIT code near CPython text. +// File-scope so jit_shrink() can rewind it when releasing unused pages. +#if defined(__linux__) && defined(__x86_64__) +static uintptr_t jit_next_hint = 0; +#endif + static unsigned char * jit_alloc(size_t size) { - if (size > PY_MAX_JIT_CODE_SIZE) { - jit_error("code too big; refactor bytecodes.c to keep uop size down, or reduce maximum trace length."); - return NULL; - } assert(size); assert(size % get_page_size() == 0); #ifdef MS_WINDOWS @@ -119,8 +122,30 @@ jit_alloc(size_t size) #else int flags = MAP_ANONYMOUS | MAP_PRIVATE; int prot = PROT_READ | PROT_WRITE; - unsigned char *memory = mmap(NULL, size, prot, flags, -1, 0); + void *hint = NULL; +#if defined(__linux__) && defined(__x86_64__) + // Allocate JIT code near CPython text so emit_call_ext and emit_mov_imm + // can use short RIP-relative encodings (within ±2GB). + { + if (jit_next_hint == 0) { + size_t page_size = get_page_size(); + extern char _end[]; + // Start 25MB after the end of CPython text, rounded up to the next page. + jit_next_hint = ((uintptr_t)_end + 25000000 + page_size - 1) & ~(uintptr_t)(page_size - 1); + } + hint = (void *)jit_next_hint; + } +#endif + unsigned char *memory = mmap(hint, size, prot, flags, -1, 0); + if (memory == MAP_FAILED && hint != NULL) { + memory = mmap(NULL, size, prot, flags, -1, 0); + } int failed = memory == MAP_FAILED; +#if defined(__linux__) && defined(__x86_64__) + if (!failed) { + jit_next_hint = (uintptr_t)memory + size; + } +#endif if (!failed) { (void)_PyAnnotateMemoryMap(memory, size, "cpython:jit"); } @@ -132,6 +157,30 @@ jit_alloc(size_t size) return memory; } +// Shrink a JIT allocation by releasing unused tail pages back to the OS. +// Updates jit_next_hint so the next allocation continues right after the +// trimmed region (avoids leaving gaps in the address space). 
+static void +jit_shrink(unsigned char *memory, size_t alloc_size, size_t used_size) +{ + assert(used_size <= alloc_size); + assert(used_size % get_page_size() == 0); + assert(alloc_size % get_page_size() == 0); + if (used_size < alloc_size) { +#ifdef MS_WINDOWS + VirtualFree(memory + used_size, alloc_size - used_size, MEM_DECOMMIT); +#else + munmap(memory + used_size, alloc_size - used_size); +#endif +#if defined(__linux__) && defined(__x86_64__) + // Rewind hint so the next allocation fills the gap we just freed. + if (jit_next_hint == (uintptr_t)memory + alloc_size) { + jit_next_hint = (uintptr_t)memory + used_size; + } +#endif + } +} + static int jit_free(unsigned char *memory, size_t size) { @@ -178,592 +227,257 @@ mark_executable(unsigned char *memory, size_t size) } // JIT compiler stuff: ///////////////////////////////////////////////////////// +// +// DynASM-based JIT: We use Clang to compile each uop template to optimized +// assembly at build time, convert the assembly to DynASM directives via +// _asm_to_dasc.py, and run the DynASM preprocessor (dynasm.lua) to produce +// jit_stencils.h containing an action list and per-uop emit functions. +// +// At runtime, DynASM's tiny encoding engine (dasm_x86.h) assembles the trace +// by replaying the action list with concrete operand values, resolving labels +// and branches automatically. This replaces the entire copy-and-patch +// relocation layer: no more patch_* functions, no trampolines, no GOT. + +#include "dasm_proto.h" + +// DynASM configuration: Dst is always dasm_State** passed as first argument +// to emit functions. 
+#define Dst_DECL dasm_State **Dst +#define Dst_REF (*Dst) + +#include "dasm_x86.h" +#include "jit_stencils.h" -#define GOT_SLOT_SIZE sizeof(uintptr_t) -#define SYMBOL_MASK_WORDS 8 - -typedef uint32_t symbol_mask[SYMBOL_MASK_WORDS]; - -typedef struct { - unsigned char *mem; - symbol_mask mask; - size_t size; -} symbol_state; - -typedef struct { - symbol_state trampolines; - symbol_state got_symbols; - uintptr_t instruction_starts[UOP_MAX_TRACE_LENGTH]; -} jit_state; - -// Warning! AArch64 requires you to get your hands dirty. These are your gloves: - -// value[value_start : value_start + len] -static uint32_t -get_bits(uint64_t value, uint8_t value_start, uint8_t width) -{ - assert(width <= 32); - return (value >> value_start) & ((1ULL << width) - 1); -} - -// *loc[loc_start : loc_start + width] = value[value_start : value_start + width] +// Compiles executor in-place using DynASM. +// +// The DynASM flow: +// 1. Initialize DynASM state and pre-allocate PC labels for all uops +// plus their internal branch targets. +// 2. Emit each uop stencil via the generated emit_*() functions. These +// call dasm_put() to append encoded instructions to the action buffer, +// using PC labels for inter-uop jumps and DynASM sections for hot/cold +// code separation. +// 3. Append a _FATAL_ERROR sentinel after the last uop to catch overruns. +// 4. dasm_link() computes the final code layout and resolves all labels. +// 5. Allocate executable memory (page-aligned) and dasm_encode() into it. +// 6. Mark memory executable and shrink unused pages. +// +// This replaces the old copy-and-patch approach and eliminates all manual +// relocation patching, GOT/trampoline generation. + +/* Emit all uop stencils (Phase 3-4) into the DynASM state. + * + * Handles _SET_IP delta encoding, shared trace cleanup stubs, and the + * _FATAL_ERROR sentinel. 
+ */ static void -set_bits(uint32_t *loc, uint8_t loc_start, uint64_t value, uint8_t value_start, - uint8_t width) -{ - assert(loc_start + width <= 32); - uint32_t temp_val; - // Use memcpy to safely read the value, avoiding potential alignment - // issues and strict aliasing violations. - memcpy(&temp_val, loc, sizeof(temp_val)); - // Clear the bits we're about to patch: - temp_val &= ~(((1ULL << width) - 1) << loc_start); - assert(get_bits(temp_val, loc_start, width) == 0); - // Patch the bits: - temp_val |= get_bits(value, value_start, width) << loc_start; - assert(get_bits(temp_val, loc_start, width) == get_bits(value, value_start, width)); - // Safely write the modified value back to memory. - memcpy(loc, &temp_val, sizeof(temp_val)); -} - -// See https://developer.arm.com/documentation/ddi0602/2023-09/Base-Instructions -// for instruction encodings: -#define IS_AARCH64_ADD_OR_SUB(I) (((I) & 0x11C00000) == 0x11000000) -#define IS_AARCH64_ADRP(I) (((I) & 0x9F000000) == 0x90000000) -#define IS_AARCH64_BRANCH(I) (((I) & 0x7C000000) == 0x14000000) -#define IS_AARCH64_BRANCH_COND(I) (((I) & 0x7C000000) == 0x54000000) -#define IS_AARCH64_BRANCH_ZERO(I) (((I) & 0x7E000000) == 0x34000000) -#define IS_AARCH64_TEST_AND_BRANCH(I) (((I) & 0x7E000000) == 0x36000000) -#define IS_AARCH64_LDR_OR_STR(I) (((I) & 0x3B000000) == 0x39000000) -#define IS_AARCH64_MOV(I) (((I) & 0x9F800000) == 0x92800000) - -// LLD is a great reference for performing relocations... just keep in -// mind that Tools/jit/build.py does filtering and preprocessing for us! 
-// Here's a good place to start for each platform: -// - aarch64-apple-darwin: -// - https://github.com/llvm/llvm-project/blob/main/lld/MachO/Arch/ARM64.cpp -// - https://github.com/llvm/llvm-project/blob/main/lld/MachO/Arch/ARM64Common.cpp -// - https://github.com/llvm/llvm-project/blob/main/lld/MachO/Arch/ARM64Common.h -// - aarch64-pc-windows-msvc: -// - https://github.com/llvm/llvm-project/blob/main/lld/COFF/Chunks.cpp -// - aarch64-unknown-linux-gnu: -// - https://github.com/llvm/llvm-project/blob/main/lld/ELF/Arch/AArch64.cpp -// - i686-pc-windows-msvc: -// - https://github.com/llvm/llvm-project/blob/main/lld/COFF/Chunks.cpp -// - x86_64-apple-darwin: -// - https://github.com/llvm/llvm-project/blob/main/lld/MachO/Arch/X86_64.cpp -// - x86_64-pc-windows-msvc: -// - https://github.com/llvm/llvm-project/blob/main/lld/COFF/Chunks.cpp -// - x86_64-unknown-linux-gnu: -// - https://github.com/llvm/llvm-project/blob/main/lld/ELF/Arch/X86_64.cpp - - -// Get the symbol slot memory location for a given symbol ordinal. -static unsigned char * -get_symbol_slot(int ordinal, symbol_state *state, int size) -{ - const uint32_t symbol_mask = 1U << (ordinal % 32); - const uint32_t state_mask = state->mask[ordinal / 32]; - assert(symbol_mask & state_mask); - - // Count the number of set bits in the symbol mask lower than ordinal - size_t index = _Py_popcount32(state_mask & (symbol_mask - 1)); - for (int i = 0; i < ordinal / 32; i++) { - index += _Py_popcount32(state->mask[i]); - } - - unsigned char *slot = state->mem + index * size; - assert((size_t)(index + 1) * size <= state->size); - return slot; -} - -// Return the address of the GOT slot for the requested symbol ordinal. 
-static uintptr_t -got_symbol_address(int ordinal, jit_state *state) -{ - return (uintptr_t)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE); -} - -// Many of these patches are "relaxing", meaning that they can rewrite the -// code they're patching to be more efficient (like turning a 64-bit memory -// load into a 32-bit immediate load). These patches have an "x" in their name. -// Relative patches have an "r" in their name. - -// 32-bit absolute address. -void -patch_32(unsigned char *location, uint64_t value) -{ - // Check that we're not out of range of 32 unsigned bits: - assert(value < (1ULL << 32)); - uint32_t final_value = (uint32_t)value; - memcpy(location, &final_value, sizeof(final_value)); -} - -// 32-bit relative address. -void -patch_32r(unsigned char *location, uint64_t value) +emit_trace(dasm_State **Dst, + const _PyUOpInstruction *trace, size_t length) { - value -= (uintptr_t)location; - // Check that we're not out of range of 32 signed bits: - assert((int64_t)value >= -(1LL << 31)); - assert((int64_t)value < (1LL << 31)); - uint32_t final_value = (uint32_t)value; - memcpy(location, &final_value, sizeof(final_value)); -} - -// 64-bit absolute address. -void -patch_64(unsigned char *location, uint64_t value) -{ - memcpy(location, &value, sizeof(value)); -} + int sentinel_label = (int)length; + int label_base = sentinel_label + 1; + uintptr_t last_ip = 0; // track last _SET_IP value for delta encoding -// 12-bit low part of an absolute address. Pairs nicely with patch_aarch64_21r -// (below). -void -patch_aarch64_12(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_LDR_OR_STR(*loc32) || IS_AARCH64_ADD_OR_SUB(*loc32)); - // There might be an implicit shift encoded in the instruction: - uint8_t shift = 0; - if (IS_AARCH64_LDR_OR_STR(*loc32)) { - shift = (uint8_t)get_bits(*loc32, 30, 2); - // If both of these are set, the shift is supposed to be 4. 
- // That's pretty weird, and it's never actually been observed... - assert(get_bits(*loc32, 23, 1) == 0 || get_bits(*loc32, 26, 1) == 0); - } - value = get_bits(value, 0, 12); - assert(get_bits(value, 0, shift) == 0); - set_bits(loc32, 10, value, shift, 12); -} - -// Relaxable 12-bit low part of an absolute address. Pairs nicely with -// patch_aarch64_21rx (below). -void -patch_aarch64_12x(unsigned char *location, uint64_t value) -{ - // This can *only* be relaxed if it occurs immediately before a matching - // patch_aarch64_21rx. If that happens, the JIT build step will replace both - // calls with a single call to patch_aarch64_33rx. Otherwise, we end up - // here, and the instruction is patched normally: - patch_aarch64_12(location, value); -} - -// 16-bit low part of an absolute address. -void -patch_aarch64_16a(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_MOV(*loc32)); - // Check the implicit shift (this is "part 0 of 3"): - assert(get_bits(*loc32, 21, 2) == 0); - set_bits(loc32, 5, value, 0, 16); -} + emit_trace_entry_frame(Dst); -// 16-bit middle-low part of an absolute address. -void -patch_aarch64_16b(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_MOV(*loc32)); - // Check the implicit shift (this is "part 1 of 3"): - assert(get_bits(*loc32, 21, 2) == 1); - set_bits(loc32, 5, value, 16, 16); -} - -// 16-bit middle-high part of an absolute address. -void -patch_aarch64_16c(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_MOV(*loc32)); - // Check the implicit shift (this is "part 2 of 3"): - assert(get_bits(*loc32, 21, 2) == 2); - set_bits(loc32, 5, value, 32, 16); -} - -// 16-bit high part of an absolute address. 
-void -patch_aarch64_16d(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_MOV(*loc32)); - // Check the implicit shift (this is "part 3 of 3"): - assert(get_bits(*loc32, 21, 2) == 3); - set_bits(loc32, 5, value, 48, 16); -} - -// 21-bit count of pages between this page and an absolute address's page... I -// know, I know, it's weird. Pairs nicely with patch_aarch64_12 (above). -void -patch_aarch64_21r(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - value = (value >> 12) - ((uintptr_t)location >> 12); - // Check that we're not out of range of 21 signed bits: - assert((int64_t)value >= -(1 << 20)); - assert((int64_t)value < (1 << 20)); - // value[0:2] goes in loc[29:31]: - set_bits(loc32, 29, value, 0, 2); - // value[2:21] goes in loc[5:26]: - set_bits(loc32, 5, value, 2, 19); -} - -// Relaxable 21-bit count of pages between this page and an absolute address's -// page. Pairs nicely with patch_aarch64_12x (above). -void -patch_aarch64_21rx(unsigned char *location, uint64_t value) -{ - // This can *only* be relaxed if it occurs immediately before a matching - // patch_aarch64_12x. If that happens, the JIT build step will replace both - // calls with a single call to patch_aarch64_33rx. Otherwise, we end up - // here, and the instruction is patched normally: - patch_aarch64_21r(location, value); -} - -// 21-bit relative branch. -void -patch_aarch64_19r(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_BRANCH_COND(*loc32) || IS_AARCH64_BRANCH_ZERO(*loc32)); - value -= (uintptr_t)location; - // Check that we're not out of range of 21 signed bits: - assert((int64_t)value >= -(1 << 20)); - assert((int64_t)value < (1 << 20)); - // Since instructions are 4-byte aligned, only use 19 bits: - assert(get_bits(value, 0, 2) == 0); - set_bits(loc32, 5, value, 2, 19); -} - -// 28-bit relative branch. 
-void -patch_aarch64_26r(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - assert(IS_AARCH64_BRANCH(*loc32)); - value -= (uintptr_t)location; - // Check that we're not out of range of 28 signed bits: - assert((int64_t)value >= -(1 << 27)); - assert((int64_t)value < (1 << 27)); - // Since instructions are 4-byte aligned, only use 26 bits: - assert(get_bits(value, 0, 2) == 0); - set_bits(loc32, 0, value, 2, 26); -} - -// A pair of patch_aarch64_21rx and patch_aarch64_12x. -void -patch_aarch64_33rx(unsigned char *location, uint64_t value) -{ - uint32_t *loc32 = (uint32_t *)location; - // Try to relax the pair of GOT loads into an immediate value: - assert(IS_AARCH64_ADRP(*loc32)); - unsigned char reg = get_bits(loc32[0], 0, 5); - assert(IS_AARCH64_LDR_OR_STR(loc32[1])); - // There should be only one register involved: - assert(reg == get_bits(loc32[1], 0, 5)); // ldr's output register. - assert(reg == get_bits(loc32[1], 5, 5)); // ldr's input register. 
- uint64_t relaxed = *(uint64_t *)value; - if (relaxed < (1UL << 16)) { - // adrp reg, AAA; ldr reg, [reg + BBB] -> movz reg, XXX; nop - loc32[0] = 0xD2800000 | (get_bits(relaxed, 0, 16) << 5) | reg; - loc32[1] = 0xD503201F; - return; - } - if (relaxed < (1ULL << 32)) { - // adrp reg, AAA; ldr reg, [reg + BBB] -> movz reg, XXX; movk reg, YYY - loc32[0] = 0xD2800000 | (get_bits(relaxed, 0, 16) << 5) | reg; - loc32[1] = 0xF2A00000 | (get_bits(relaxed, 16, 16) << 5) | reg; - return; - } - int64_t page_delta = (relaxed >> 12) - ((uintptr_t)location >> 12); - if (page_delta >= -(1L << 20) && - page_delta < (1L << 20)) - { - // adrp reg, AAA; ldr reg, [reg + BBB] -> adrp reg, AAA; add reg, reg, BBB - patch_aarch64_21rx(location, relaxed); - loc32[1] = 0x91000000 | get_bits(relaxed, 0, 12) << 10 | reg << 5 | reg; - return; - } - relaxed = value - (uintptr_t)location; - if ((relaxed & 0x3) == 0 && - (int64_t)relaxed >= -(1L << 19) && - (int64_t)relaxed < (1L << 19)) - { - // adrp reg, AAA; ldr reg, [reg + BBB] -> ldr reg, XXX; nop - loc32[0] = 0x58000000 | (get_bits(relaxed, 2, 19) << 5) | reg; - loc32[1] = 0xD503201F; - return; - } - // Couldn't do it. Just patch the two instructions normally: - patch_aarch64_21rx(location, value); - patch_aarch64_12x(location + 4, value); -} - -// Relaxable 32-bit relative address. 
-void -patch_x86_64_32rx(unsigned char *location, uint64_t value) -{ - uint8_t *loc8 = (uint8_t *)location; - // Try to relax the GOT load into an immediate value: - uint64_t relaxed; - memcpy(&relaxed, (void *)(value + 4), sizeof(relaxed)); - relaxed -= 4; - - if ((int64_t)relaxed - (int64_t)location >= -(1LL << 31) && - (int64_t)relaxed - (int64_t)location + 1 < (1LL << 31)) - { - if (loc8[-2] == 0x8B) { - // mov reg, dword ptr [rip + AAA] -> lea reg, [rip + XXX] - loc8[-2] = 0x8D; - value = relaxed; + for (size_t i = 0; i < length; i++) { + const _PyUOpInstruction *instruction = &trace[i]; + int uop_label = (int)i; + int continue_label = (int)(i + 1); + + int opcode = instruction->opcode; + if ((opcode == _SET_IP_r00 || opcode == _SET_IP_r11 + || opcode == _SET_IP_r22 || opcode == _SET_IP_r33) + && last_ip != 0) + { + uintptr_t new_ip = (uintptr_t)instruction->operand0; + intptr_t delta = (intptr_t)(new_ip - last_ip); + if (delta != 0 + && delta >= INT32_MIN && delta <= INT32_MAX) + { + emit_set_ip_delta(Dst, uop_label, delta); + label_base += jit_internal_label_count(opcode); + last_ip = new_ip; + // SET_IP delta only modifies [r13+56], preserves rax + continue; + } } - else if (loc8[-2] == 0xFF && loc8[-1] == 0x15) { - // call qword ptr [rip + AAA] -> nop; call XXX - loc8[-2] = 0x90; - loc8[-1] = 0xE8; - value = relaxed; + + jit_emit_one(Dst, instruction->opcode, instruction, + uop_label, continue_label, label_base); + label_base += jit_internal_label_count(instruction->opcode); + if (opcode == _SET_IP_r00 || opcode == _SET_IP_r11 + || opcode == _SET_IP_r22 || opcode == _SET_IP_r33) + { + last_ip = (uintptr_t)instruction->operand0; } - else if (loc8[-2] == 0xFF && loc8[-1] == 0x25) { - // jmp qword ptr [rip + AAA] -> nop; jmp XXX - loc8[-2] = 0x90; - loc8[-1] = 0xE9; - value = relaxed; + else if (jit_invalidates_ip(opcode)) { + last_ip = 0; } } - patch_32r(location, value); -} -void patch_got_symbol(jit_state *state, int ordinal); -void 
patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state); -void patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state); - -#include "jit_stencils.h" - -#if defined(__aarch64__) || defined(_M_ARM64) - #define TRAMPOLINE_SIZE 16 - #define DATA_ALIGN 8 -#elif defined(__x86_64__) && defined(__APPLE__) - // LLVM 20 on macOS x86_64 debug builds: GOT entries may exceed ±2GB PC-relative - // range. - #define TRAMPOLINE_SIZE 16 // 14 bytes + 2 bytes padding for alignment - #define DATA_ALIGN 8 -#else - #define TRAMPOLINE_SIZE 0 - #define DATA_ALIGN 1 -#endif - -// Populate the GOT entry for the given symbol ordinal with its resolved address. -void -patch_got_symbol(jit_state *state, int ordinal) -{ - uint64_t value = (uintptr_t)symbols_map[ordinal]; - unsigned char *location = (unsigned char *)get_symbol_slot(ordinal, &state->got_symbols, GOT_SLOT_SIZE); - patch_64(location, value); -} - -// Generate and patch AArch64 trampolines. The symbols to jump to are stored -// in the jit_stencils.h in the symbols_map. -void -patch_aarch64_trampoline(unsigned char *location, int ordinal, jit_state *state) -{ - - uint64_t value = (uintptr_t)symbols_map[ordinal]; - int64_t range = value - (uintptr_t)location; - - // If we are in range of 28 signed bits, we patch the instruction with - // the address of the symbol. 
- if (range >= -(1 << 27) && range < (1 << 27)) { - patch_aarch64_26r(location, (uintptr_t)value); - return; + // Emit _FATAL_ERROR sentinel after the last uop to catch overruns + { + _PyUOpInstruction sentinel = {0}; + sentinel.opcode = _FATAL_ERROR_r00; + int sentinel_continue = sentinel_label; + jit_emit_one(Dst, _FATAL_ERROR_r00, &sentinel, + sentinel_label, sentinel_continue, label_base); } - - // Out of range - need a trampoline - uint32_t *p = (uint32_t *)get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE); - - /* Generate the trampoline - 0: 58000048 ldr x8, 8 - 4: d61f0100 br x8 - 8: 00000000 // The next two words contain the 64-bit address to jump to. - c: 00000000 - */ - p[0] = 0x58000048; - p[1] = 0xD61F0100; - p[2] = value & 0xffffffff; - p[3] = value >> 32; - - patch_aarch64_26r(location, (uintptr_t)p); -} - -// Generate and patch x86_64 trampolines. -void -patch_x86_64_trampoline(unsigned char *location, int ordinal, jit_state *state) -{ - uint64_t value = (uintptr_t)symbols_map[ordinal]; - int64_t range = (int64_t)value - 4 - (int64_t)location; - - // If we are in range of 32 signed bits, we can patch directly - if (range >= -(1LL << 31) && range < (1LL << 31)) { - patch_32r(location, value - 4); - return; - } - - // Out of range - need a trampoline - unsigned char *trampoline = get_symbol_slot(ordinal, &state->trampolines, TRAMPOLINE_SIZE); - - /* Generate the trampoline (14 bytes, padded to 16): - 0: ff 25 00 00 00 00 jmp *(%rip) - 6: XX XX XX XX XX XX XX XX (64-bit target address) - - Reference: https://wiki.osdev.org/X86-64_Instruction_Encoding#FF (JMP r/m64) - */ - trampoline[0] = 0xFF; - trampoline[1] = 0x25; - memset(trampoline + 2, 0, 4); - memcpy(trampoline + 6, &value, 8); - - // Patch the call site to call the trampoline instead - patch_32r(location, (uintptr_t)trampoline - 4); } +/* Initialize a DynASM state for trace compilation. 
*/ static void -combine_symbol_mask(const symbol_mask src, symbol_mask dest) +init_dasm(dasm_State **Dst, int total_labels) { - // Calculate the union of the trampolines required by each StencilGroup - for (size_t i = 0; i < SYMBOL_MASK_WORDS; i++) { - dest[i] |= src[i]; - } + dasm_init(Dst, DASM_MAXSECTION); + dasm_setup(Dst, jit_actionlist); + dasm_growpc(Dst, total_labels); } -// Compiles executor in-place. Don't forget to call _PyJIT_Free later! int _PyJIT_Compile(_PyExecutorObject *executor, const _PyUOpInstruction trace[], size_t length) { - const StencilGroup *group; - // Loop once to find the total compiled size: - size_t code_size = 0; - size_t data_size = 0; - jit_state state = {0}; + // Phase 1: Count total PC labels needed. + // Labels [0..length-1] are uop entry points; additional labels are + // allocated for internal branch targets within each stencil. + int total_labels = (int)length; for (size_t i = 0; i < length; i++) { - const _PyUOpInstruction *instruction = &trace[i]; - group = &stencil_groups[instruction->opcode]; - state.instruction_starts[i] = code_size; - code_size += group->code_size; - data_size += group->data_size; - combine_symbol_mask(group->trampoline_mask, state.trampolines.mask); - combine_symbol_mask(group->got_mask, state.got_symbols.mask); - } - group = &stencil_groups[_FATAL_ERROR_r00]; - code_size += group->code_size; - data_size += group->data_size; - combine_symbol_mask(group->trampoline_mask, state.trampolines.mask); - combine_symbol_mask(group->got_mask, state.got_symbols.mask); - // Calculate the size of the trampolines required by the whole trace - for (size_t i = 0; i < Py_ARRAY_LENGTH(state.trampolines.mask); i++) { - state.trampolines.size += _Py_popcount32(state.trampolines.mask[i]) * TRAMPOLINE_SIZE; - } - for (size_t i = 0; i < Py_ARRAY_LENGTH(state.got_symbols.mask); i++) { - state.got_symbols.size += _Py_popcount32(state.got_symbols.mask[i]) * GOT_SLOT_SIZE; - } - // Round up to the nearest page: + total_labels += 
jit_internal_label_count(trace[i].opcode); + } + // One extra label for the _FATAL_ERROR sentinel. + total_labels += 1; + // Extra internal labels for _FATAL_ERROR + total_labels += jit_internal_label_count(_FATAL_ERROR_r00); + + // Phase 2–6: Single-pass JIT compilation. + // + // Allocate PY_MAX_JIT_CODE_SIZE up front. Since jit_alloc() places + // code near CPython text (via mmap hints on Linux x86-64), the real + // allocation address is always usable as jit_code_base — emit_mov_imm() + // and emit_call_ext() will use short RIP-relative encodings. + // + // After encoding, unused tail pages are released back to the OS and + // jit_next_hint is rewound so the next allocation fills the gap. + dasm_State *d; + size_t code_size; + int status; + size_t page_size = get_page_size(); assert((page_size & (page_size - 1)) == 0); - size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1)); - size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1)); - size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding; - unsigned char *memory = jit_alloc(total_size); + size_t alloc_size = (PY_MAX_JIT_CODE_SIZE + page_size - 1) & ~(page_size - 1); + unsigned char *memory = jit_alloc(alloc_size); if (memory == NULL) { return -1; } + + jit_code_base = (uintptr_t)memory; + + init_dasm(&d, total_labels); + emit_trace(&d, trace, length); + status = dasm_link(&d, &code_size); + if (status != DASM_S_OK) { + jit_free(memory, alloc_size); + dasm_free(&d); + PyErr_Format(PyExc_RuntimeWarning, + "JIT DynASM link failed (status %d)", status); + return -1; + } + if (code_size > PY_MAX_JIT_CODE_SIZE) { + // Trace too large — give up on this trace. 
+ jit_free(memory, alloc_size); + dasm_free(&d); + jit_error("code too big; refactor bytecodes.c to keep uop size down, or reduce maximum trace length."); + return -1; + } + if (code_size > alloc_size) { + // Trace too large — give up on this trace. + jit_free(memory, alloc_size); + dasm_free(&d); + PyErr_Format(PyExc_RuntimeWarning, + "JIT code too large (%zu bytes)", code_size); + return -1; + } + + // Phase 7: Encode — writes final machine code into memory. + status = dasm_encode(&d, memory); + if (status != DASM_S_OK) { + jit_free(memory, alloc_size); + dasm_free(&d); + PyErr_Format(PyExc_RuntimeWarning, + "JIT DynASM encode failed (status %d)", status); + return -1; + } + + dasm_free(&d); + + // Release unused tail pages and rewind jit_next_hint. + size_t total_size = (code_size + page_size - 1) & ~(page_size - 1); + jit_shrink(memory, alloc_size, total_size); + // Collect memory stats OPT_STAT_ADD(jit_total_memory_size, total_size); OPT_STAT_ADD(jit_code_size, code_size); - OPT_STAT_ADD(jit_trampoline_size, state.trampolines.size); - OPT_STAT_ADD(jit_data_size, data_size); - OPT_STAT_ADD(jit_got_size, state.got_symbols.size); - OPT_STAT_ADD(jit_padding_size, padding); + OPT_STAT_ADD(jit_padding_size, total_size - code_size); OPT_HIST(total_size, trace_total_memory_hist); - // Update the offsets of each instruction: - for (size_t i = 0; i < length; i++) { - state.instruction_starts[i] += (uintptr_t)memory; - } - // Loop again to emit the code: - unsigned char *code = memory; - state.trampolines.mem = memory + code_size; - unsigned char *data = memory + code_size + state.trampolines.size + code_padding; - assert(trace[0].opcode == _START_EXECUTOR_r00 || trace[0].opcode == _COLD_EXIT_r00 || trace[0].opcode == _COLD_DYNAMIC_EXIT_r00); - state.got_symbols.mem = data + data_size; - for (size_t i = 0; i < length; i++) { - const _PyUOpInstruction *instruction = &trace[i]; - group = &stencil_groups[instruction->opcode]; - group->emit(code, data, executor, instruction, 
&state); - code += group->code_size; - data += group->data_size; - } - // Protect against accidental buffer overrun into data: - group = &stencil_groups[_FATAL_ERROR_r00]; - group->emit(code, data, executor, NULL, &state); - code += group->code_size; - data += group->data_size; - assert(code == memory + code_size); - assert(data == memory + code_size + state.trampolines.size + code_padding + data_size); + if (mark_executable(memory, total_size)) { jit_free(memory, total_size); return -1; } + executor->jit_code = memory; executor->jit_size = total_size; return 0; } -/* One-off compilation of the jit entry shim - * We compile this once only as it effectively a normal - * function, but we need to use the JIT because it needs - * to understand the jit-specific calling convention. - * Don't forget to call _PyJIT_Fini later! +/* One-off compilation of the jit entry shim. + * + * The shim bridges the native C calling convention to the JIT's internal + * calling convention. It is compiled once and shared across all traces. + * Uses DynASM just like trace compilation, but with a single emit_shim() + * call instead of a loop over uops. */ static _PyJitEntryFuncPtr compile_shim(void) { - _PyExecutorObject dummy; - const StencilGroup *group; - size_t code_size = 0; - size_t data_size = 0; - jit_state state = {0}; - group = &shim; - code_size += group->code_size; - data_size += group->data_size; - combine_symbol_mask(group->trampoline_mask, state.trampolines.mask); - combine_symbol_mask(group->got_mask, state.got_symbols.mask); - // Round up to the nearest page: + int total_labels = 1 + jit_internal_label_count_shim(); + dasm_State *d; + size_t code_size; + int status; + + // The shim is tiny (~100 bytes). Allocate one page, compile once. 
size_t page_size = get_page_size(); - assert((page_size & (page_size - 1)) == 0); - size_t code_padding = DATA_ALIGN - ((code_size + state.trampolines.size) & (DATA_ALIGN - 1)); - size_t padding = page_size - ((code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size) & (page_size - 1)); - size_t total_size = code_size + state.trampolines.size + code_padding + data_size + state.got_symbols.size + padding; - unsigned char *memory = jit_alloc(total_size); + size_t alloc_size = page_size; + unsigned char *memory = jit_alloc(alloc_size); if (memory == NULL) { return NULL; } - unsigned char *code = memory; - state.trampolines.mem = memory + code_size; - unsigned char *data = memory + code_size + state.trampolines.size + code_padding; - state.got_symbols.mem = data + data_size; - // Compile the shim, which handles converting between the native - // calling convention and the calling convention used by jitted code - // (which may be different for efficiency reasons). 
- group = &shim; - group->emit(code, data, &dummy, NULL, &state); - code += group->code_size; - data += group->data_size; - assert(code == memory + code_size); - assert(data == memory + code_size + state.trampolines.size + code_padding + data_size); - if (mark_executable(memory, total_size)) { - jit_free(memory, total_size); + + jit_code_base = (uintptr_t)memory; + + init_dasm(&d, total_labels); + emit_shim(&d, 0, 1); + status = dasm_link(&d, &code_size); + if (status != DASM_S_OK) { + jit_free(memory, alloc_size); + dasm_free(&d); + return NULL; + } + assert(code_size <= alloc_size); + + status = dasm_encode(&d, memory); + dasm_free(&d); + if (status != DASM_S_OK) { + jit_free(memory, alloc_size); + return NULL; + } + + if (mark_executable(memory, alloc_size)) { + jit_free(memory, alloc_size); return NULL; } - _Py_jit_shim_size = total_size; + _Py_jit_shim_size = alloc_size; return (_PyJitEntryFuncPtr)memory; } diff --git a/Python/optimizer.c b/Python/optimizer.c index 466729b158d345..09936cafbb0a98 100644 --- a/Python/optimizer.c +++ b/Python/optimizer.c @@ -1469,10 +1469,19 @@ stack_allocate(_PyUOpInstruction *buffer, _PyUOpInstruction *output, int length) if (uop == _NOP) { continue; } + if (uop <= 0 || uop > MAX_UOP_ID) { + return 0; + } int new_depth = _PyUop_Caching[uop].best[depth]; + if (new_depth < 0 || new_depth > MAX_CACHED_REGISTER) { + return 0; + } if (new_depth != depth) { - write->opcode = _PyUop_SpillsAndReloads[depth][new_depth]; - assert(write->opcode != 0); + uint16_t spill_reload = _PyUop_SpillsAndReloads[depth][new_depth]; + if (spill_reload == 0 || spill_reload > MAX_UOP_REGS_ID) { + return 0; + } + write->opcode = spill_reload; write->format = UOP_FORMAT_TARGET; write->oparg = 0; write->target = 0; @@ -1481,10 +1490,16 @@ stack_allocate(_PyUOpInstruction *buffer, _PyUOpInstruction *output, int length) } *write = buffer[i]; uint16_t new_opcode = _PyUop_Caching[uop].entries[depth].opcode; - assert(new_opcode != 0); + if (new_opcode == 0 || 
new_opcode > MAX_UOP_REGS_ID) { + return 0; + } write->opcode = new_opcode; write++; - depth = _PyUop_Caching[uop].entries[depth].output; + int output_depth = _PyUop_Caching[uop].entries[depth].output; + if (output_depth < 0 || output_depth > MAX_CACHED_REGISTER) { + return 0; + } + depth = output_depth; } return (int)(write - output); } @@ -1542,6 +1557,9 @@ uop_optimize( OPT_HIST(effective_trace_length(buffer, length), optimized_trace_length_hist); _PyUOpInstruction *output = &_tstate->jit_tracer_state->uop_array[0]; length = stack_allocate(buffer, output, length); + if (length <= 0) { + return 0; + } buffer = output; length = prepare_for_execution(buffer, length); assert(length <= UOP_MAX_TRACE_LENGTH); diff --git a/Tools/jit/LuaJIT b/Tools/jit/LuaJIT new file mode 160000 index 00000000000000..659a61693aa3b8 --- /dev/null +++ b/Tools/jit/LuaJIT @@ -0,0 +1 @@ +Subproject commit 659a61693aa3b87661864ad0f12eee14c865cd7f diff --git a/Tools/jit/_asm_to_dasc.py b/Tools/jit/_asm_to_dasc.py new file mode 100644 index 00000000000000..8062fe1b7c770f --- /dev/null +++ b/Tools/jit/_asm_to_dasc.py @@ -0,0 +1,2093 @@ +"""Convert Intel-syntax x86-64 assembly (from Clang) to DynASM .dasc format. + +This module transforms the optimized .s files produced by Clang (Intel syntax, +medium code model, -fno-pic -fno-plt) into DynASM directives suitable for the +DynASM Lua preprocessor (dynasm.lua). + +All labels (uop entry points, internal branch targets, JIT jump/error targets) +use DynASM PC labels (=>N), which are dynamically allocated. The label +numbering scheme is: + + [0 .. trace_len-1] : uop entry point labels + [trace_len .. trace_len+K-1] : internal stencil labels (allocated per-emit) + +External symbol references (function pointers, type addresses) use +``emit_call_ext()`` for direct calls and ``emit_mov_imm()`` for address loads, +both of which generate optimal encodings at JIT compile time. 
+ +JIT Register Roles +~~~~~~~~~~~~~~~~~~ + +The preserve_none calling convention assigns fixed register roles (see +REG_FRAME, REG_STACK_PTR, REG_TSTATE, REG_EXECUTOR constants below). +Frame struct offsets (FRAME_IP_OFFSET, FRAME_STACKPOINTER_OFFSET) are +also defined as constants to avoid hardcoded magic numbers. + +Peephole Optimization +~~~~~~~~~~~~~~~~~~~~~ + +After converting each stencil to DynASM assembly, a multi-pass peephole +optimizer folds emit_mov_imm sequences with subsequent instructions. +Since emit_mov_imm values are C expressions evaluated at JIT compile time, +folding allows moving work from runtime to compile time. + +Two categories of patterns: + + 1. **emit_mov_imm chain patterns** (Patterns 1-15): Start from an + emit_mov_imm call and attempt to fold the loaded value into subsequent + instructions (truncation, arithmetic, branch elimination, memory + indexing, ALU folding, etc.). Handled by ``_fold_mov_imm()``. + + 2. **Standalone patterns** (SP0-SP3): Independent patterns that operate + on raw DynASM assembly lines: + SP0 — Preserve flags across immediate loads + SP1 — Store-reload elimination (hot jcc fallthrough) + SP2 — Cold-path reload insertion (for __del__ safety) + SP3 — Inverted store-reload deferral (hot jcc jump-to-merge) + Registered in ``_STANDALONE_PATTERNS``. + +Use ``--peephole-stats`` in ``build.py`` to see how often each fires. + +Cross-Stencil Optimizations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Some optimizations span stencil boundaries and are handled at JIT compile +time (in jit.c) rather than at build time in this module: + + - **Frame merging**: Consecutive stencils with matching frame sizes + can elide the epilogue/prologue pair. Managed by ``emit_trace()`` + in jit.c. + + - **SET_IP delta encoding**: When consecutive SET_IP values are close, + emit ``add qword [frame+56], delta`` instead of a full mov. 
+""" + +from __future__ import annotations + +import dataclasses +import enum +import re +import typing + + +# ── Register name mapping ─────────────────────────────────────────────── +# REX-prefix byte registers that DynASM doesn't natively understand. +# We teach DynASM these names via .define directives in the .dasc header +# (see _dasc_writer.py), so we keep them as-is in the assembly output +# for readability — no Rb(N) substitution needed. +_REX_BYTE_REGS = frozenset({ + "spl", "bpl", "sil", "dil", + "r8b", "r9b", "r10b", "r11b", "r12b", "r13b", "r14b", "r15b", +}) + +# Mapping from 64-bit register name → human-readable register index +# constant name (used for emit_mov_imm calls). Prefixed with JREG_ +# to avoid collisions with system headers (e.g. ucontext.h REG_R8). +_REG_IDX_NAME: dict[str, str] = { + "rax": "JREG_RAX", "rcx": "JREG_RCX", "rdx": "JREG_RDX", "rbx": "JREG_RBX", + "rsp": "JREG_RSP", "rbp": "JREG_RBP", "rsi": "JREG_RSI", "rdi": "JREG_RDI", + "r8": "JREG_R8", "r9": "JREG_R9", "r10": "JREG_R10", "r11": "JREG_R11", + "r12": "JREG_R12", "r13": "JREG_R13", "r14": "JREG_R14", "r15": "JREG_R15", +} + + +# ── _JIT_* symbol → C expression ─────────────────────────────────────── +_JIT_SYMBOL_EXPR: dict[str, str] = { + # The stencil template uses PATCH_VALUE(TYPE, NAME, ALIAS) which + # expands to ``TYPE NAME = (TYPE)(uintptr_t)&ALIAS;``. The compiler + # generates ``movabs REG, offset _JIT_*`` to load the symbol's + # address directly into a register — there is NO dereference. The + # original stencil JIT patches the movabs immediate with the VALUE + # itself (not a pointer), so here we emit the value too. 
+ "_JIT_OPERAND0": "instruction->operand0", + "_JIT_OPERAND1": "instruction->operand1", + "_JIT_OPARG": "instruction->oparg", + "_JIT_OPARG_16": "instruction->oparg", + "_JIT_OPERAND0_16": "instruction->operand0", + "_JIT_OPERAND0_32": "instruction->operand0", + "_JIT_OPERAND1_16": "instruction->operand1", + "_JIT_OPERAND1_32": "instruction->operand1", + "_JIT_TARGET": "instruction->target", +} + +# Map 64-bit register name → 32-bit register name. +_REG64_TO_REG32: dict[str, str] = { + "rax": "eax", "rbx": "ebx", "rcx": "ecx", "rdx": "edx", + "rsi": "esi", "rdi": "edi", "rbp": "ebp", "rsp": "esp", + "r8": "r8d", "r9": "r9d", "r10": "r10d", "r11": "r11d", + "r12": "r12d", "r13": "r13d", "r14": "r14d", "r15": "r15d", +} + +# Map 64-bit register name → DynASM register index for Rq()/Rd() macros. +_REG64_TO_IDX: dict[str, int] = { + "rax": 0, "rcx": 1, "rdx": 2, "rbx": 3, + "rsp": 4, "rbp": 5, "rsi": 6, "rdi": 7, + "r8": 8, "r9": 9, "r10": 10, "r11": 11, + "r12": 12, "r13": 13, "r14": 14, "r15": 15, +} + +# Map any register name (64-bit, 32-bit, 16-bit) to DynASM index +_ANY_REG_TO_IDX: dict[str, int] = {**_REG64_TO_IDX} +# Map any register name to the human-readable REG_* constant name +_ANY_REG_TO_NAME: dict[str, str] = {**_REG_IDX_NAME} +# 16-bit register names +_REG16_NAMES: dict[str, str] = { + "rax": "ax", "rbx": "bx", "rcx": "cx", "rdx": "dx", + "rsi": "si", "rdi": "di", "rbp": "bp", "rsp": "sp", + "r8": "r8w", "r9": "r9w", "r10": "r10w", "r11": "r11w", + "r12": "r12w", "r13": "r13w", "r14": "r14w", "r15": "r15w", +} +for _r64, _idx in list(_REG64_TO_IDX.items()): + _r32 = _REG64_TO_REG32[_r64] + _r16 = _REG16_NAMES[_r64] + _ANY_REG_TO_IDX[_r32] = _idx + _ANY_REG_TO_IDX[_r16] = _idx + _name = _REG_IDX_NAME[_r64] + _ANY_REG_TO_NAME[_r32] = _name + _ANY_REG_TO_NAME[_r16] = _name + +# Map register index → set of all alias names (for liveness analysis) +_IDX_TO_ALL_NAMES: dict[int, set[str]] = {} +for _name, _idx in _ANY_REG_TO_IDX.items(): + 
_IDX_TO_ALL_NAMES.setdefault(_idx, set()).add(_name) +# Add 8-bit register names manually +_8BIT_NAMES: dict[int, list[str]] = { + 0: ["al", "ah"], 1: ["cl", "ch"], 2: ["dl", "dh"], 3: ["bl", "bh"], + 4: ["spl"], 5: ["bpl"], 6: ["sil"], 7: ["dil"], + 8: ["r8b"], 9: ["r9b"], 10: ["r10b"], 11: ["r11b"], + 12: ["r12b"], 13: ["r13b"], 14: ["r14b"], 15: ["r15b"], +} +for _idx, _names in _8BIT_NAMES.items(): + for _n in _names: + _IDX_TO_ALL_NAMES.setdefault(_idx, set()).add(_n) + _ANY_REG_TO_IDX[_n] = _idx + +# ── Compiled regexes ─────────────────────────────────────────────────── + +# movabs REG, offset SYMBOL or movabs REG, offset SYMBOL+N +_RE_MOVABS = re.compile( + r"^\s*movabs\s+(\w+),\s*offset\s+([\w.]+)(?:\+(\d+))?\s*(?:#.*)?$" +) + +# movabs REG, IMM (plain integer immediate, no "offset" keyword) +_RE_MOVABS_IMM = re.compile( + r"^\s*movabs\s+(\w+),\s*(-?\d+)\s*(?:#.*)?$" +) + +# call/jmp qword ptr [rip + SYM@GOTPCREL] +_RE_GOTPCREL_CALL = re.compile( + r"^\s*(call|jmp)\s+qword\s+ptr\s+\[rip\s*\+\s*([\w.]+)@GOTPCREL\]\s*(?:#.*)?$" +) + +# Generic instruction with GOTPCREL in a memory operand +_RE_GOTPCREL_MEM = re.compile( + r"^(\s*\w+\s+)(.*?)(byte|word|dword|qword)\s+ptr\s+" + r"\[rip\s*\+\s*([\w.]+)@GOTPCREL\](.*?)$" +) + +# jmp/jcc to _JIT_JUMP_TARGET or _JIT_ERROR_TARGET +_RE_JIT_BRANCH = re.compile( + r"^\s*(j\w+)\s+(_JIT_JUMP_TARGET|_JIT_ERROR_TARGET)\s*(?:#.*)?$" +) + +# jmp/jcc to _JIT_CONTINUE or .L_JIT_CONTINUE +_RE_JIT_CONTINUE = re.compile( + r"^\s*(j\w+)\s+(?:\.L)?_JIT_CONTINUE\s*(?:#.*)?$" +) + +# Pattern for recognized local labels from LLVM (broad match for first pass) +# This matches any label-like definition that is NOT a _JIT_* special symbol. 
+_RE_ANY_LABEL_DEF = re.compile(r"^([\w.]+):\s*(?:#.*)?$")
+
+# Local branch: jmp/jcc/call to a non-_JIT_ label (matched dynamically)
+# (Compiled after the first pass discovers which labels are local)
+
+# Local label definitions (compiled after the first pass)
+# These label regexes are re-used later as local_map lookups
+
+_RE_ENTRY = re.compile(r"^_JIT_ENTRY:\s*(?:#.*)?$")
+_RE_CONTINUE_LABEL = re.compile(r"^(?:\.L)?_JIT_CONTINUE:\s*(?:#.*)?$")
+_RE_FUNC_END = re.compile(r"^\.Lfunc_end\d+:\s*$")
+
+# Section directives (hot text vs. LLVM-split cold text vs. rodata)
+_RE_COLD_SECTION = re.compile(r'^\s*\.section\s+(?:\.text\.cold|__llvm_cold)')
+_RE_TEXT_SECTION = re.compile(r"^\s*\.text\s*$")
+_RE_RODATA_SECTION = re.compile(r"^\s*\.section\s+\.l?rodata")
+
+# Data definitions inside the rodata section
+_RE_ASCIZ = re.compile(r'^\s*\.asciz\s+"(.*?)"')
+_RE_DATA_LABEL = re.compile(r"^(\.L[\w.]+):\s*(?:#.*)?$")
+_RE_BYTE_DATA = re.compile(r"^\s*\.(byte|short|long|quad)\s+(.*)")
+
+# Directives to skip entirely. NOTE(review): this also matches .p2align, so _RE_ALIGN below only fires if tried before _RE_SKIP — confirm scan order.
+_RE_SKIP = re.compile(
+    r"^\s*\.(file|globl|type|size|addrsig|addrsig_sym|hidden|ident|"
+    r"intel_syntax|section\s+\"\.note|p2align|cfi_\w+)\b"
+)
+
+_RE_BLANK = re.compile(r"^\s*(?:#.*)?$")
+_RE_ALIGN = re.compile(r"^\s*\.p2align\s+(\d+)")
+
+# LLVM JIT fold pass inline-asm markers (emitted as annotated nops).
+# Format: nop # @@JIT_MOV_IMM %reg, @@
+_RE_JIT_MARKER = re.compile(
+    r"^\s*nop\s+#\s*@@(JIT_MOV_IMM|JIT_TEST|JIT_CMP|JIT_FRAME_ANCHOR)(?:\s+(%?\w+)(?:,\s*(.+?))?)?@@\s*$"
+)
+
+# ── Peephole optimization patterns ────────────────────────────────
+
+# emit_mov_imm(Dst, REG_NAME_OR_IDX, EXPR);
+# emit_mov_imm_preserve_flags(Dst, REG_NAME_OR_IDX, EXPR);
+_RE_EMIT_MOV_IMM = re.compile(
+    r"^(\s*)emit_mov_imm(?:_preserve_flags)?\(Dst,\s*(\w+),\s*(.+?)\);$"
+)
+
+# ── Regexes for parse_line() ───────────────────────────────────────────
+#
+# These patterns are used by parse_line() to classify DynASM output lines
+# into typed Line objects (Asm, CCall, Label, Section, FuncDef, Blank).
+
+# C helper calls: emit_mov_imm(Dst, REG, EXPR);
+# emit_mov_imm_preserve_flags(Dst, REG, EXPR);
+# Like _RE_EMIT_MOV_IMM above, but a distinct pattern whose groups are (indent, helper name, args). NOTE(review): parse_line() matches against text.strip(), so the indent group is always empty there.
+_RE_C_CALL_MOV_IMM = re.compile(
+    r"^(\s*)(emit_mov_imm(?:_preserve_flags)?)\(Dst,\s*(.+)\);$"
+)
+# C helper calls: emit_call_ext(Dst, ARGS);
+_RE_C_CALL_EXT = re.compile(r"^(\s*)emit_call_ext\(Dst,\s*(.+)\);$")
+# C helper calls: emit_cmp_reg_imm(Dst, ARGS);
+_RE_C_CALL_CMP = re.compile(r"^(\s*)emit_cmp_reg_imm\(Dst,\s*(.+)\);$")
+# C helper calls: emit_cmp_mem64_imm(Dst, ARGS);
+_RE_C_CALL_CMP_MEM64 = re.compile(r"^(\s*)emit_cmp_mem64_imm\(Dst,\s*(.+)\);$")
+# C helper calls: emit_{test,and,or,xor}_reg_imm(Dst, ARGS);
+_RE_C_CALL_ALU = re.compile(
+    r"^(\s*)(emit_(?:test|and|or|xor)_reg_imm)\(Dst,\s*(.+)\);$"
+)
+
+# DynASM label definition: |=>LABEL_NAME:
+_RE_DASC_LABEL = re.compile(r"^\s*\|\s*=>\s*(.+?)\s*:\s*$")
+# DynASM section directive: |.code, |.cold, |.data
+_RE_DASC_SECTION = re.compile(r"^\s*\|\s*\.(code|cold|data)\b")
+# DynASM assembly instruction: | mnemonic [operands]
+_RE_ASM_LINE = re.compile(r"^\s*\|\s*(\w+)(?:\s+(.+))?\s*$")
+# Function definition: static void emit_OPNAME(...)
+_RE_DASC_FUNC_DEF = re.compile(r"^\s*static\s+void\s+emit_\w+\s*\(")
+
+# ── Typed operand and line classification ──────────────────────────────
+#
+# Instead of parsing lines into flat strings and then re-matching with
+# per-pattern regexes, we parse each line *once* into typed objects:
+#
+# Operand types: Reg, Mem, Imm (what instructions operate on)
+# Line types: Asm, CCall, Label, Section, FuncDef, Blank, CCode
+#
+# Pattern functions use Python 3.10+ structural pattern matching
+# (match/case) to destructure these objects directly. For example:
+#
+# match parse_line(text):
+# case Asm("mov", dst=Reg(name="rax"), src=Mem(size="qword")):
+# ... # handle mov rax, qword [...]
+# case CCall(kind=CCallKind.CALL_EXT):
+# ... # handle emit_call_ext(...)
+# +# This replaces the old LineKind enum + monolithic Line dataclass with +# proper typing that enables exhaustive matching and IDE autocompletion. +# +# Design principles: +# - Operand types are frozen (immutable, hashable) for safe matching +# - Each line type carries only the fields relevant to that type +# - Raw text preserved in every line type for output generation +# - Helper functions (is_call, is_branch, etc.) work across types + +# ── Operand types ────────────────────────────────────────────────────── + + +@dataclasses.dataclass(frozen=True, slots=True) +class Reg: + """Register operand (e.g., rax, eax, al, r14d). + + Attributes: + name: Register name as it appears in the assembly (case-preserved). + """ + + name: str + + @property + def idx(self) -> int | None: + """Canonical register index (0=rax, 1=rcx, ..., 15=r15).""" + return _ANY_REG_TO_IDX.get(self.name.lower()) + + @property + def bits(self) -> int: + """Register width in bits (8, 16, 32, or 64).""" + return _reg_bits(self.name) + + @property + def jreg(self) -> str | None: + """JREG_* constant name (e.g., "JREG_RAX"), or None.""" + return _ANY_REG_TO_NAME.get(self.name.lower()) + + def __str__(self) -> str: + return self.name + + +@dataclasses.dataclass(frozen=True, slots=True) +class Mem: + """Memory operand (e.g., qword [r14 + 8], byte [rax]). + + Attributes: + size: Size prefix ("qword", "dword", "word", "byte") or None. + base: Base register name, or None for complex addressing. + offset: Displacement (default 0). + index: Index register name, or None. + scale: Scale factor (1, 2, 4, 8), or None. + expr: Full bracket expression for output (e.g., "[r14 + 8]"). 
+ """ + + size: str | None + base: str | None + offset: int = 0 + index: str | None = None + scale: int | None = None + expr: str = "" + + def __str__(self) -> str: + return f"{self.size} {self.expr}" if self.size else self.expr + + +@dataclasses.dataclass(frozen=True, slots=True) +class Imm: + """Immediate operand (e.g., 42, -1, 0xff). + + Attributes: + value: Numeric value of the immediate. + text: Original text representation (preserved for output). + """ + + value: int + text: str = "" + + def __str__(self) -> str: + return self.text if self.text else str(self.value) + + +# Union of all operand types, for type annotations. +Op = Reg | Mem | Imm + + +# ── Line types ───────────────────────────────────────────────────────── + + +@enum.unique +class CCallKind(enum.Enum): + """Sub-classification for C helper calls.""" + + MOV_IMM = "emit_mov_imm" + CALL_EXT = "emit_call_ext" + CMP_REG_IMM = "emit_cmp_reg_imm" + CMP_MEM64_IMM = "emit_cmp_mem64_imm" + ALU_REG_IMM = "emit_alu_reg_imm" # test/and/or/xor_reg_imm + OTHER = "other" + + +@dataclasses.dataclass(slots=True) +class Asm: + """Assembly instruction (e.g., ``| mov rax, qword [r14 + 8]``). + + Attributes: + mnemonic: Instruction mnemonic (e.g., "mov", "cmp", "je"). + dst: First (destination) operand as typed Reg/Mem/Imm, or None. + src: Second (source) operand as typed Reg/Mem/Imm, or None. + target: Branch target for jmp/jcc (e.g., "=>L(3)"), or None. + raw: Original line text (preserved for output). + """ + + mnemonic: str + dst: Op | None = None + src: Op | None = None + target: str | None = None + raw: str = "" + + def __str__(self) -> str: + return self.raw + + +@dataclasses.dataclass(slots=True) +class CCall: + """C helper call (e.g., ``emit_call_ext(Dst, ...)``). + + Attributes: + kind: Which helper (MOV_IMM, CALL_EXT, CMP_REG_IMM). + helper: Helper function name as emitted in the C source. + args: Raw argument string inside parentheses. + argv: Parsed argument tokens split at top-level commas. 
+ indent: Leading whitespace (for replacement line generation). + raw: Original line text. + """ + + kind: CCallKind + helper: str = "" + args: str = "" + argv: tuple[str, ...] = () + indent: str = "" + raw: str = "" + + +@dataclasses.dataclass(slots=True) +class Label: + """Label definition (e.g., ``|=>L(3):``). + + Attributes: + name: Label identifier (e.g., "L(3)", "uop_label"). + raw: Original line text. + """ + + name: str + raw: str = "" + + +@dataclasses.dataclass(slots=True) +class Section: + """Section directive (e.g., ``|.code``, ``|.cold``). + + Attributes: + name: Section name ("code", "cold", or "data"). + raw: Original line text. + """ + + name: str + raw: str = "" + + +@dataclasses.dataclass(slots=True) +class FuncDef: + """Function definition (e.g., ``static void emit_BINARY_OP_...``). + + Attributes: + raw: Original line text. + """ + + raw: str = "" + + +@dataclasses.dataclass(slots=True) +class Blank: + """Empty line or comment. + + Attributes: + raw: Original line text. + """ + + raw: str = "" + + +@dataclasses.dataclass(slots=True) +class CCode: + """C code line (if/else/braces/etc.). + + Attributes: + raw: Original line text. + """ + + raw: str = "" + + +# Union of all line types — use as type annotation for parsed lines. +Line = Asm | CCall | Label | Section | FuncDef | Blank | CCode + + +# ── Operand parsing ─────────────────────────────────────────────────── + +_SIZE_PREFIXES = frozenset(("qword", "dword", "word", "byte")) + +_ALL_REGS = frozenset(_ANY_REG_TO_IDX.keys()) + +_RE_MEM_TERM_SCALED = re.compile(r"^(\d+)\s*\*\s*(\w+)$") +_RE_MEM_TERM_SCALED_REV = re.compile(r"^(\w+)\s*\*\s*(\d+)$") + + +def _parse_mem_expr(inner: str) -> tuple[str | None, int, str | None, int | None]: + """Parse the expression inside memory brackets. 
+ + Examples: + "r14" → ("r14", 0, None, None) + "r14 + 8" → ("r14", 8, None, None) + "r14 - 8" → ("r14", -8, None, None) + "rdi + rcx*4" → ("rdi", 0, "rcx", 4) + "rdi + 4*rcx + 8" → ("rdi", 8, "rcx", 4) + "rcx*8+0" → (None, 0, "rcx", 8) + """ + base = None + offset = 0 + index = None + scale = None + + # Split on + and - while preserving the sign operator + terms = re.split(r"\s*([+-])\s*", inner.strip()) + sign = 1 + for term in terms: + term = term.strip() + if not term: + continue + if term == "+": + sign = 1 + continue + if term == "-": + sign = -1 + continue + # scale*register (e.g., "4*rcx") + m = _RE_MEM_TERM_SCALED.match(term) + if m: + scale = int(m.group(1)) + index = m.group(2) + continue + # register*scale (e.g., "rcx*4") + m = _RE_MEM_TERM_SCALED_REV.match(term) + if m: + index = m.group(1) + scale = int(m.group(2)) + continue + # Numeric displacement + try: + offset += sign * int(term, 0) + continue + except ValueError: + pass + # Register — assign to base or index + if base is None: + base = term + elif index is None: + index = term + scale = 1 + + return base, offset, index, scale + + +def _parse_operand(text: str) -> Op: + """Parse a single operand string into a typed Reg, Mem, or Imm. + + Examples: + "rax" → Reg("rax") + "qword [r14 + 8]" → Mem(size="qword", base="r14", offset=8, ...) + "[rax]" → Mem(size=None, base="rax", ...) 
+ "42" → Imm(42, "42") + "-1" → Imm(-1, "-1") + """ + text = text.strip() + + # Memory operand: contains brackets + if "[" in text: + size = None + rest = text + for s in _SIZE_PREFIXES: + if rest.lower().startswith(s + " "): + size = s + rest = rest[len(s) :].strip() + break + # Extract bracket contents + bracket_start = rest.index("[") + bracket_end = rest.rindex("]") + inner = rest[bracket_start + 1 : bracket_end].strip() + expr = rest[bracket_start : bracket_end + 1] + base, offset, index, scale = _parse_mem_expr(inner) + return Mem( + size=size, + base=base, + offset=offset, + index=index, + scale=scale, + expr=expr, + ) + + # Register: check all known register names + if text.lower() in _ALL_REGS: + return Reg(text) + + # Immediate: try parsing as integer + try: + return Imm(int(text, 0), text) + except ValueError: + pass + + # DynASM label reference (e.g., "=>L(3)") — treat as Imm-like + # (used by branch targets, but normally handled via target field) + return Imm(0, text) + + +def _split_operands(operands: str) -> list[str]: + """Split operand string on commas, respecting brackets. + + ``"qword [r13 + 64], r14"`` → ``["qword [r13 + 64]", "r14"]`` + ``"rax"`` → ``["rax"]`` + ``"qword [rax + rbx*8 + 16]"`` → ``["qword [rax + rbx*8 + 16]"]`` + """ + parts: list[str] = [] + depth = 0 + start = 0 + for j, ch in enumerate(operands): + if ch == "[": + depth += 1 + elif ch == "]": + depth -= 1 + elif ch == "," and depth == 0: + parts.append(operands[start:j]) + start = j + 1 + parts.append(operands[start:]) + return parts + + +def _split_call_args(args: str) -> tuple[str, ...]: + """Split C helper arguments on top-level commas. + + Unlike ``_split_operands()``, this helper understands parentheses as well + as brackets so expressions like ``(uintptr_t)&PyType_Type`` stay intact. 
+ """ + parts: list[str] = [] + depth = 0 + start = 0 + for j, ch in enumerate(args): + if ch in "([": + depth += 1 + elif ch in ")]": + depth = max(0, depth - 1) + elif ch == "," and depth == 0: + parts.append(args[start:j].strip()) + start = j + 1 + parts.append(args[start:].strip()) + return tuple(part for part in parts if part) + + +# ── DynASM assembly line parser (produces typed Asm / CCall / etc.) ──── + +_BRANCH_MNEMONICS = frozenset(( + "jmp", "je", "jne", "jz", "jnz", "ja", "jae", "jb", "jbe", + "jg", "jge", "jl", "jle", "js", "jns", "jo", "jno", "jp", "jnp", +)) + + +def parse_line(text: str) -> Line: + """Parse a DynASM output line into a typed Line object. + + Returns one of: Asm, CCall, Label, Section, FuncDef, Blank, CCode. + Each type carries only the fields relevant to that line kind, with + structured operands (Reg/Mem/Imm) for assembly instructions. + + Classification priority: + 1. C helper calls (emit_mov_imm, emit_call_ext, emit_cmp_reg_imm) + 2. DynASM labels (|=>NAME:) + 3. DynASM section directives (|.code, |.cold) + 4. DynASM assembly instructions (| mnemonic ...) + 5. Function definitions (static void emit_...) + 6. Blanks / comments + 7. 
Everything else (C code) + """ + stripped = text.strip() + + # ── C helper calls ── + m = _RE_C_CALL_MOV_IMM.match(stripped) + if m: + args = m.group(3) + return CCall( + kind=CCallKind.MOV_IMM, + helper=m.group(2), + indent=m.group(1), + args=args, + argv=_split_call_args(args), + raw=text, + ) + m = _RE_C_CALL_EXT.match(stripped) + if m: + args = m.group(2) + return CCall( + kind=CCallKind.CALL_EXT, + helper="emit_call_ext", + indent=m.group(1), + args=args, + argv=_split_call_args(args), + raw=text, + ) + m = _RE_C_CALL_CMP.match(stripped) + if m: + args = m.group(2) + return CCall( + kind=CCallKind.CMP_REG_IMM, + helper="emit_cmp_reg_imm", + indent=m.group(1), + args=args, + argv=_split_call_args(args), + raw=text, + ) + m = _RE_C_CALL_CMP_MEM64.match(stripped) + if m: + args = m.group(2) + return CCall( + kind=CCallKind.CMP_MEM64_IMM, + helper="emit_cmp_mem64_imm", + indent=m.group(1), + args=args, + argv=_split_call_args(args), + raw=text, + ) + m = _RE_C_CALL_ALU.match(stripped) + if m: + args = m.group(3) + return CCall( + kind=CCallKind.ALU_REG_IMM, + helper=m.group(2), + indent=m.group(1), + args=args, + argv=_split_call_args(args), + raw=text, + ) + + # ── DynASM labels ── + m = _RE_DASC_LABEL.match(stripped) + if m: + return Label(name=m.group(1), raw=text) + + # ── DynASM section directives ── + m = _RE_DASC_SECTION.match(stripped) + if m: + return Section(name=m.group(1), raw=text) + + # ── DynASM assembly instructions ── + m = _RE_ASM_LINE.match(stripped) + if m: + mnemonic = m.group(1) + operands_str = m.group(2) + dst: Op | None = None + src: Op | None = None + target: str | None = None + + if operands_str: + # Branch instructions: operand is a target label, not a dst/src + if mnemonic in _BRANCH_MNEMONICS: + target = operands_str.strip() + else: + parts = _split_operands(operands_str) + if parts: + dst = _parse_operand(parts[0]) + if len(parts) > 1: + src = _parse_operand(parts[1]) + + return Asm(mnemonic=mnemonic, dst=dst, src=src, + 
target=target, raw=text)

    # ── Function definitions ──
    if _RE_DASC_FUNC_DEF.match(stripped):
        return FuncDef(raw=text)

    # ── Blanks / comments ──
    if not stripped or stripped.startswith("//"):
        return Blank(raw=text)

    # ── Everything else (C code) ──
    return CCode(raw=text)


def parse_lines(lines: list[str]) -> list[Line]:
    """Batch-parse a list of DynASM output lines into typed objects."""
    return [parse_line(text) for text in lines]


# ── Query helpers (work across the Line type hierarchy) ────────────────


def is_call(line: Line) -> bool:
    """Is this line a function call (ASM 'call' or C emit_call_ext)?"""
    match line:
        case CCall(kind=CCallKind.CALL_EXT):
            return True
        case Asm(mnemonic="call"):
            return True
    return False


def is_branch(line: Line) -> bool:
    """Is this a conditional jump (jne, je, jae, etc.)?"""
    # Any j* mnemonic except the unconditional "jmp" counts as a branch.
    match line:
        case Asm(mnemonic=m) if m.startswith("j") and m != "jmp":
            return True
    return False


def is_jump(line: Line) -> bool:
    """Is this an unconditional jmp?"""
    return isinstance(line, Asm) and line.mnemonic == "jmp"


def branch_target(line: Line) -> str | None:
    """Extract branch/jump target label (e.g., "=>L(3)"), or None."""
    match line:
        case Asm(target=t) if t and t.startswith("=>"):
            return t
    return None


def line_raw(line: Line) -> str:
    """Get the original text of any Line type."""
    return line.raw


@dataclasses.dataclass(frozen=True, slots=True)
class _LineEffect:
    """Structured dataflow summary for one parsed line."""

    # Canonical register indices read / fully written / partially written
    # by the instruction; partial writes (e.g. 8/16-bit sub-registers)
    # keep the rest of the register live.
    reads: frozenset[int] = dataclasses.field(default_factory=frozenset)
    full_writes: frozenset[int] = dataclasses.field(default_factory=frozenset)
    partial_writes: frozenset[int] = dataclasses.field(default_factory=frozenset)
    uses_flags: bool = False
    writes_flags: bool = False


@dataclasses.dataclass(frozen=True, slots=True)
class _BasicBlock:
    """Basic block inside one emitted stencil function."""

    # Half-open line range [start, end) into the parsed line list.
    start: int
    end: int
    labels: tuple[str, ...] = ()
    successors: tuple[int, ...] = ()


@dataclasses.dataclass(frozen=True, slots=True)
class _PeepholeFunction:
    """Structured view of one emitted stencil function."""

    start: int
    end: int
    blocks: tuple[_BasicBlock, ...] = ()


def _operand_regs(op: Op | None) -> frozenset[int]:
    """Registers read while evaluating an operand."""
    regs: set[int] = set()
    match op:
        case Reg(idx=idx) if idx is not None:
            regs.add(idx)
        case Mem(base=base, index=index):
            # Both the base and the index register contribute a read.
            for name in (base, index):
                if name is None:
                    continue
                idx = Reg(name).idx
                if idx is not None:
                    regs.add(idx)
    return frozenset(regs)


def _mem_uses_reg(mem: Mem, reg_idx: int) -> bool:
    """Does a memory address depend on the given canonical register?"""
    return reg_idx in _operand_regs(mem)


def _compute_c_depth(lines: list[Line]) -> list[int]:
    """Track inline-C nesting depth for each emitted line.

    A line ending in "{" opens a block; "}" closes one. "} else {" both
    closes and reopens, so its net effect on depth is zero.
    """
    c_depth = [0] * len(lines)
    depth = 0
    for i, line in enumerate(lines):
        stripped = line.raw.strip()
        if stripped.endswith("{"):
            depth += 1
        c_depth[i] = depth
        if stripped == "}" or stripped == "} else {":
            depth = max(0, depth - 1)
            c_depth[i] = depth
    return c_depth


def _build_blocks(
    parsed: list[Line],
    label_to_line: dict[str, int],
    start: int,
    end: int,
) -> tuple[_BasicBlock, ...]:
    """Split one stencil function into coarse basic blocks."""

    if start >= end:
        return ()

    # Block leaders: the function entry, every label/section line, and
    # every line following a branch, jump, or ret.
    starts = {start}
    for i in range(start + 1, end):
        if isinstance(parsed[i], (Label, Section)):
            starts.add(i)
        prev = parsed[i - 1]
        if isinstance(prev, Asm) and (is_branch(prev) or is_jump(prev) or prev.mnemonic == "ret"):
            starts.add(i)

    ordered = sorted(starts)
    blocks: list[_BasicBlock] = []
    for idx, block_start in enumerate(ordered):
        block_end = ordered[idx + 1] if idx + 1 < len(ordered) else end
        labels: list[str] = []
        j = block_start
        while j < block_end and isinstance(parsed[j], Label):
            labels.append(parsed[j].name)
            j += 1

        # Successor edges depend on the block's terminating instruction:
        # branch → fall-through + target, jmp → target only,
        # ret → none, anything else → fall-through.
        successors: list[int] = []
        last = parsed[block_end - 1]
        if isinstance(last, Asm):
            target = branch_target(last)
            if is_branch(last):
                if block_end < end:
                    successors.append(block_end)
                if target and target.startswith("=>"):
                    target_idx = label_to_line.get(target[2:])
                    if target_idx is not None:
                        successors.append(target_idx)
            elif is_jump(last):
                if target and target.startswith("=>"):
                    target_idx = label_to_line.get(target[2:])
                    if target_idx is not None:
                        successors.append(target_idx)
            elif last.mnemonic != "ret" and block_end < end:
                successors.append(block_end)
        elif block_end < end:
            successors.append(block_end)

        blocks.append(
            _BasicBlock(
                start=block_start,
                end=block_end,
                labels=tuple(labels),
                # dict.fromkeys deduplicates while preserving order.
                successors=tuple(dict.fromkeys(successors)),
            )
        )
    return tuple(blocks)


@dataclasses.dataclass(slots=True)
class _PeepholeProgram:
    """Parsed view of the current emitted DynASM function bodies."""

    lines: list[str]                       # raw text lines
    parsed: list[Line]                     # typed lines, parallel to `lines`
    c_depth: list[int]                     # inline-C nesting depth per line
    label_to_line: dict[str, int]          # label name → defining line index
    effects: list[_LineEffect]             # dataflow summary per line
    functions: tuple[_PeepholeFunction, ...]
    function_starts: frozenset[int]        # line indices of FuncDef lines
    function_end_by_line: list[int]        # enclosing function end per line

    @classmethod
    def from_lines(cls, lines: list[str]) -> "_PeepholeProgram":
        parsed = parse_lines(lines)
        c_depth = _compute_c_depth(parsed)
        label_to_line = {
            line.name: i for i, line in enumerate(parsed) if isinstance(line, Label)
        }
        effects = [_line_effect(line) for line in parsed]

        # Partition the line list into per-function ranges; with no
        # FuncDef lines the whole input is treated as one function.
        func_starts = [i for i, line in enumerate(parsed) if isinstance(line, FuncDef)]
        if not func_starts:
            func_starts = [0]
        ranges = [
            (
                start,
                func_starts[idx + 1]
                if idx + 1 < len(func_starts)
                else len(lines),
            )
            for idx, start in enumerate(func_starts)
        ]

        functions = tuple(
            _PeepholeFunction(
                start=start,
                end=end,
                blocks=_build_blocks(parsed, label_to_line, start, end),
            )
            for start, end in ranges
        )
        function_end_by_line = [len(lines)] * len(lines)
        for start, end in ranges:
            for i in range(start, end):
                function_end_by_line[i] = end

        return cls(
            lines=lines,
            parsed=parsed,
            c_depth=c_depth,
            label_to_line=label_to_line,
            effects=effects,
            functions=functions,
            function_starts=frozenset(func_starts),
            function_end_by_line=function_end_by_line,
        )

    def reg_dead_after(self, start: int, reg_idx: int) -> bool:
        """Control-flow-aware deadness query using structured line effects.

        Walks forward from ``start`` along all reachable paths within the
        enclosing function. The register is dead iff no path reads it (or
        partially writes it) before fully overwriting it.
        """
        if start >= len(self.lines):
            return True
        if reg_idx not in _IDX_TO_ALL_NAMES:
            return False

        func_end = self.function_end_by_line[start]
        start_depth = self.c_depth[start]
        visited: set[int] = set()
        worklist = [start]

        while worklist:
            pos = worklist.pop()
            while pos < func_end:
                if pos in visited:
                    break
                visited.add(pos)

                effect = self.effects[pos]
                # A read (or a partial write, which preserves the upper
                # bits) means the old value is still needed → live.
                if reg_idx in effect.reads or reg_idx in effect.partial_writes:
                    return False
                # A full write kills the old value on this path.
                if reg_idx in effect.full_writes:
                    break

                successors = self.successors(pos, start_depth)
                if not successors:
                    break
                fallthrough = pos + 1
                if len(successors) == 1 and successors[0] == fallthrough:
                    pos = fallthrough
                    continue
                # Queue the extra successors; keep walking the first one.
                for succ in successors[1:]:
                    if succ < func_end:
                        worklist.append(succ)
                pos = successors[0]

        return True

    def successors(self, pos: int, start_depth: int) -> tuple[int, ...]:
        """Reachable successors from one line in a deadness query.

        When the line sits inside a deeper inline-C block than the query
        origin (``in_c_block``), the surrounding C condition may skip it,
        so the plain fall-through is conservatively kept as a successor
        even after jmp/ret.
        """
        if pos >= len(self.parsed):
            return ()
        line = self.parsed[pos]
        func_end = self.function_end_by_line[pos]
        next_pos = pos + 1
        in_c_block = self.c_depth[pos] > start_depth

        match line:
            case Asm(mnemonic="jmp", target=target):
                if target and target.startswith("=>"):
                    target_idx = self.label_to_line.get(target[2:])
                    if target_idx is not None:
                        # +1 skips the label definition line itself.
                        if in_c_block and next_pos < func_end:
                            return (next_pos, target_idx + 1)
                        return (target_idx + 1,)
                if in_c_block and next_pos < func_end:
                    return (next_pos,)
                return ()
            case Asm(mnemonic="ret"):
                if in_c_block and next_pos < func_end:
                    return (next_pos,)
                return ()
            case Asm() if is_branch(line):
                succs: list[int] = []
                if next_pos < func_end:
                    succs.append(next_pos)
                target = branch_target(line)
                if target and target.startswith("=>"):
                    target_idx = self.label_to_line.get(target[2:])
                    if target_idx is not None:
                        succs.append(target_idx + 1)
                return tuple(succs)
            case _:
                if next_pos < func_end:
                    return (next_pos,)
                return ()


def fmt_op(op: Op) -> str:
    """Format a typed operand back to assembly text."""
    match op:
        case Reg(name=n):
            return n
        case Mem(size=s, expr=e):
            return f"{s} {e}" if s else e
        case Imm(text=t) if t:
            return t
        case Imm(value=v):
            return str(v)


# Operand size in bits based on register name (for cast selection)
def _reg_bits(name: str) -> int:
    """Return the operand size in bits for a register name."""
    name = name.lower()
    # 8-bit: classic byte registers plus r8b-r15b (".endswith('b')").
    if name in ("al", "ah", "bl", "bh", "cl", "ch", "dl", "dh",
                "spl", "bpl", "sil", "dil") or name.endswith("b"):
        return 8
    # 16-bit: classic word registers plus r8w-r15w.
    if name in ("ax", "bx", "cx", "dx", "si", "di", "bp", "sp") \
            or name.endswith("w"):
        return 16
    # 32-bit: eax-edi plus r8d-r15d.
    if name.startswith("e") or (name.startswith("r") and name.endswith("d")):
        return 32
    return 64


# ── Peephole optimizer — pattern-based architecture ────────────────────
#
# Inspired by _optimizers.py's clean separation of concerns, the peephole
# operates as a *pattern registry* with a simple driver loop. Each pattern
# is a small, self-contained function that examines lines at position ``i``
# and returns a ``_Match`` (consumed lines + output lines) or ``None``.
#
# Two categories of patterns:
#
# 1. **emit_mov_imm chain patterns** — These fire when the current line
#    is ``emit_mov_imm(Dst, REG, EXPR);`` and try to fold subsequent
#    instructions into the immediate expression. They compose: e.g.
#    mov_imm → movzx → shl chains. Handled by ``_fold_mov_imm()``.
#
# 2. **Standalone patterns** — Independent patterns that operate on raw
#    DynASM assembly lines (e.g. store-reload elimination). Each is a
#    function registered in ``_STANDALONE_PATTERNS``.
#
# Adding a new pattern: write a function matching the ``_PatternFn``
# signature, add it to ``_STANDALONE_PATTERNS``, done.


@dataclasses.dataclass
class _Match:
    """Result of a successful pattern match."""
    consumed: int       # number of input lines consumed
    output: list[str]   # replacement lines to emit


@dataclasses.dataclass
class _FoldCtx:
    """Shared context for emit_mov_imm fold patterns.

    Bundles the parameters that every ``_try_*`` function needs, eliminating
    the 5 different call signatures that previously required a dispatch table
    in ``_fold_mov_imm``. All terminal ``_try_*`` functions now take a single
    ``ctx: _FoldCtx`` argument and return ``_Match | None``.

    The ``parsed`` list contains typed Line objects (Asm, CCall, etc.)
    parallel to ``lines``. Pattern functions use ``ctx.cur`` to get the
    current typed instruction and ``match``/``case`` for destructuring.
    """
    program: _PeepholeProgram  # structured program for effects/CFG queries
    lines: list[str]           # all input lines (raw text)
    parsed: list[Line]         # typed Line objects (parallel to lines)
    i: int                     # current look-ahead position
    src_idx: int               # register index (0=RAX, 1=RCX, etc.)
    src_name: str              # "JREG_RAX" etc.
    indent: str                # indentation from the emit_mov_imm
    expr: str                  # current expression (mutated by modifier phases)

    @property
    def cur(self) -> Line:
        """The parsed Line at the current look-ahead position."""
        return self.parsed[self.i]


@dataclasses.dataclass
class _PeepholeState:
    """Mutable state carried across a single peephole pass.

    Reset at function boundaries (``static void emit_...``) to avoid
    cross-stencil interference.
    """

    # Merge labels where we eliminated a stackpointer reload;
    # cold-path jumps back to these need a reload inserted.
    need_reload_before_jmp: set[str] = dataclasses.field(default_factory=set)

    def reset(self) -> None:
        """Reset per-function state at stencil boundaries."""
        self.need_reload_before_jmp.clear()


# Signature of one pipeline step: (program, position, output-so-far,
# state) → lines consumed, or None when the pass does not fire.
_PeepholePassFn = typing.Callable[
    [_PeepholeProgram, int, list[str], _PeepholeState],
    int | None,
]


@dataclasses.dataclass(frozen=True, slots=True)
class _PeepholePass:
    """One pass step in the peephole pipeline."""

    name: str
    apply: _PeepholePassFn


# ── Peephole statistics ────────────────────────────────────────────────
#
# Each pattern increments its counter when it fires. The stats are
# printed at the end of stencil conversion if PEEPHOLE_STATS is True.

PEEPHOLE_STATS = False  # set True or use --peephole-stats to see counts

_peephole_counts: dict[str, int] = {
    "P6_indexed_mem": 0,
    "P8_alu_imm_fold": 0,
    "P12_store_imm": 0,
    "P13_dead_null_check": 0,
    "P14_test_memory_fold": 0,
    "P15_shift_fold": 0,
    "P16_two_mov_add": 0,
    "P17_lea_fold": 0,
    "P18_dead_frame_anchor": 0,
    "P19_inverse_mov_restore": 0,
    "SP0_preserve_flags_mov_imm": 0,
    "SP1_store_reload_elim": 0,
    "SP2_cold_reload_insert": 0,
    "SP3_inverted_store_reload": 0,
    "dead_label_elim": 0,
    "LLVM_fold_marker": 0,
}


def _stat(name: str) -> None:
    """Increment a peephole pattern counter."""
    _peephole_counts[name] = _peephole_counts.get(name, 0) + 1


def get_peephole_stats() -> dict[str, int]:
    """Return current peephole statistics (pattern name → fire count)."""
    return dict(_peephole_counts)


def reset_peephole_stats() -> None:
    """Reset all peephole counters to zero."""
    for key in _peephole_counts:
        _peephole_counts[key] = 0


def print_peephole_stats() -> None:
    """Print peephole statistics to stderr."""
    import sys

    total = sum(_peephole_counts.values())
    if total == 0:
        return
    print(f"\nPeephole optimization statistics ({total} total):", file=sys.stderr)
    for name, count in sorted(_peephole_counts.items()):
        if count > 0:
            print(f"  {name:30s}: {count:5d}", file=sys.stderr)


# ── x86-64 specific imports ────────────────────────────────────────────
# Import architecture-specific peephole patterns, instruction effects,
# and calling convention constants from the x86-64 module.
# These are re-exported for backward compatibility with existing consumers
# (build.py, _dasc_writer.py, _targets.py, test_peephole.py).
#
# This import is placed here (after all generic definitions) because the
# amd64 module imports types and infrastructure from this file. By this
# point all those symbols are defined, so the circular import resolves
# cleanly.
from _asm_to_dasc_amd64 import (  # noqa: E402
    # Calling convention
    REG_FRAME, REG_STACK_PTR, REG_TSTATE, REG_EXECUTOR,
    FRAME_IP_OFFSET, FRAME_STACKPOINTER_OFFSET,
    _C, _SP_STORE, _SP_RELOAD, _SP_RELOAD_LINE,
    # JREG mapping
    _JREG_TO_IDX, _IDX_TO_JREG, _parse_jreg, _jreg_arg_index,
    # Instruction effects
    _line_effect, _reg_write_sets,
    # Pattern helpers
    _reg_dead_after,
    uses_reg, is_store_sp, is_reload_sp,
    _is_flag_writer, _is_flag_consumer,
    _preserve_flags_mov_imm, _parse_emit_mov_imm_call,
    # Pass registry
    _PEEPHOLE_PASSES,
)


# ── Driver ─────────────────────────────────────────────────────────────


def _peephole_pass(lines: list[str]) -> tuple[list[str], bool]:
    """Single pass of peephole optimization. Returns (result, changed).

    Parses all input lines once into a structured program with typed lines,
    helper-call arguments, block boundaries, and line effects. Registered
    passes then match against that program rather than raw-text regex state.
    """
    changed = False
    result: list[str] = []
    state = _PeepholeState()
    program = _PeepholeProgram.from_lines(lines)
    i = 0
    while i < len(lines):
        # Reset per-function state at stencil boundaries
        if i in program.function_starts:
            state.reset()

        matched = False
        for peephole_pass in _PEEPHOLE_PASSES:
            advance = peephole_pass.apply(program, i, result, state)
            if advance is not None:
                # Pass consumed `advance` input lines (output already
                # appended to `result` by the pass itself).
                i += advance
                changed = True
                matched = True
                break
        if matched:
            continue

        # No pattern matched — pass through
        result.append(lines[i])
        i += 1
    return result, changed


def _peephole_optimize(lines: list[str]) -> list[str]:
    """Apply peephole optimizations until fixpoint (max 5 passes).

    Multi-pass iteration enables chained optimizations: e.g. Pattern 10
    creates a new emit_mov_imm that Pattern 6 can then consume.
    """
    for _pass in range(5):
        result, changed = _peephole_pass(lines)
        if not changed:
            break
        lines = result
    else:
        lines = result
    return _eliminate_dead_labels(lines)


def _eliminate_dead_labels(lines: list[str]) -> list[str]:
    """Remove labels that are defined but never jumped to.

    Many stencils contain structural labels (e.g. |=>L(2):) that serve as
    fall-through targets but are never referenced by any jump instruction.
    These dead labels clutter the assembly output and make it harder to read.

    Uses ``parse_lines()`` to identify labels by type, avoiding ad-hoc regex
    duplication.
    """
    parsed = parse_lines(lines)
    # Collect all L(N) references (non-definition uses).
    # A reference is any occurrence of =>L(N) that is NOT a label def.
    # Join with newlines so the end of one line cannot fuse with the start
    # of the next: with plain concatenation a line ending in "=>L(N)"
    # followed by a line starting with ":" would wrongly be classified as
    # a definition by the (?!:) lookahead and the label dropped.
    text = "\n".join(lines)
    referenced = set(re.findall(r"=>L\((\d+)\)(?!:)", text))
    result: list[str] = []
    removed = 0
    for line_obj in parsed:
        # Only remove L(N) labels; leave uop_label, continue_label etc.
        if isinstance(line_obj, Label):
            m = re.match(r"L\((\d+)\)", line_obj.name)
            if m and m.group(1) not in referenced:
                removed += 1
                continue
        result.append(line_obj.raw)
    for _ in range(removed):
        _stat("dead_label_elim")
    return result


# ── Data structures ────────────────────────────────────────────────────


@dataclasses.dataclass
class DataItem:
    """A data blob from a .rodata section (e.g. an assert filename string)."""

    label: str
    data: bytearray = dataclasses.field(default_factory=bytearray)


@dataclasses.dataclass
class ConvertedStencil:
    """Result of converting one stencil to DynASM lines."""

    opname: str
    lines: list[str]
    # Number of internal PC labels needed by this stencil (excludes GOT)
    num_internal_labels: int
    # Data items from .rodata sections
    data_items: list[DataItem]
    # Stack frame size requested by the stencil's function prologue
    # (push rbp; [mov rbp, rsp;] sub rsp, N), or 0 if the stencil does not
    # use a standard entry frame. The runtime allocates the maximum such
    # frame once in the shim and the per-stencil prologue/epilogue is then
    # stripped from the emitted DynASM.
    frame_size: int = 0


# ── Helpers ─────────────────────────────────────────────────────────────

# x86 register names used to distinguish scale*reg from scale*immediate
_X86_REGS = frozenset([
    "rax", "rcx", "rdx", "rbx", "rsp", "rbp", "rsi", "rdi",
    "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
    "eax", "ecx", "edx", "ebx", "esp", "ebp", "esi", "edi",
    "r8d", "r9d", "r10d", "r11d", "r12d", "r13d", "r14d", "r15d",
    "ax", "cx", "dx", "bx", "sp", "bp", "si", "di",
])

def _swap_scale_reg(m: re.Match) -> str:
    """Swap scale*register to register*scale in SIB memory operands."""
    scale, name = m.group(1), m.group(2)
    if name.lower() in _X86_REGS:
        return f"{name}*{scale}"
    # Not a register name (scale*immediate) — leave the text untouched.
    return m.group(0)


def _fix_syntax(text: str) -> str:
    """Apply DynASM syntax adjustments to an instruction."""
    # Convert tabs to spaces
    text = text.replace("\t", " ")
    # Strip inline comments (# ...)
    text = re.sub(r"\s*#.*$", "", text)
    # Convert xmmword to oword (DynASM uses oword for 128-bit)
    text = re.sub(r"\bxmmword\b", "oword", text)
    # Remove 'ptr' from size specifiers
    text = re.sub(
        r"\b(byte|word|dword|qword|tword|oword)\s+ptr\b", r"\1", text
    )
    # REX-prefix byte register names (sil, dil, bpl, spl, r8b-r15b) are
    # handled by .define directives in the .dasc header — keep them as-is
    # for readability.
    # Remove instruction suffixes (callq→call etc.)
    text = re.sub(r"\b(call|ret|push|pop|jmp)q\b", r"\1", text)
    # Convert negative byte immediates to unsigned (test byte [...], -128 → 0x80)
    text = _fix_negative_byte_imm(text)
    # DynASM requires explicit shift count: shr reg → shr reg, 1
    text = re.sub(
        r"\b(shr|shl|sar|sal|ror|rol|rcr|rcl)\s+(\w+)\s*$", r"\1 \2, 1", text
    )
    # DynASM requires register*scale, not scale*register in memory operands
    # e.g.
[8*rcx] → [rcx*8], [rdi + 8*rax + 80] → [rdi + rax*8 + 80] + text = re.sub(r"(\d+)\*(\w+)", _swap_scale_reg, text) + # DynASM can't encode SIB-only addressing [reg*scale] without a base register. + # Add explicit +0 displacement: [reg*N] → [reg*N+0] + text = re.sub(r"\[(\w+\*\d+)\]", r"[\1+0]", text) + # DynASM uses "movd" for both 32-bit and 64-bit GPR<->XMM transfers. + # When the GPR is 64-bit (rax, rcx, etc.), DynASM infers REX.W from the + # register name, producing the correct movq encoding. + # E.g. "movq rax, xmm1" → "movd rax, xmm1" + text = re.sub(r"\bmovq\s+(r\w+),\s*(xmm\d+)", r"movd \1, \2", text) + text = re.sub(r"\bmovq\s+(xmm\d+),\s*(r\w+)", r"movd \1, \2", text) + # Normalize whitespace + text = " ".join(text.split()) + return text + + +def _fix_negative_byte_imm(text: str) -> str: + """Convert negative immediates in byte operations to unsigned form. + + DynASM requires unsigned immediates for byte-sized operations. + E.g. ``test byte [...], -128`` → ``test byte [...], 128`` + """ + m = re.match(r"^(.+\bbyte\b.+,\s*)(-\d+)\s*$", text) + if m: + val = int(m.group(2)) + if -128 <= val < 0: + return f"{m.group(1)}{val & 0xFF}" + return text + + +def _jit_expr(symbol: str, offset: int = 0) -> str: + """C expression for a _JIT_* symbol value.""" + base = _JIT_SYMBOL_EXPR.get(symbol) + if base is None: + raise ValueError(f"Unknown _JIT_ symbol: {symbol}") + if offset: + return f"((uintptr_t)({base}) + {offset})" + return f"(uintptr_t)({base})" + + +# ── Main conversion ──────────────────────────────────────────────────── + + +def convert_stencil(opname: str, assembly: str, *, is_shim: bool = False) -> ConvertedStencil: + """Convert one stencil's optimized Intel-syntax .s to DynASM lines. + + Internal branch targets use PC labels relative to ``label_base``. 
+ """ + lines: list[str] = [] + # Map local label name → internal index (0-based) + local_map: dict[str, int] = {} + data_items: list[DataItem] = [] + counter = 0 + + in_rodata = False + cur_data: DataItem | None = None + + def _local(name: str) -> int: + nonlocal counter + if name not in local_map: + local_map[name] = counter + counter += 1 + return local_map[name] + + def _flush_data(): + nonlocal cur_data + if cur_data and cur_data.data: + data_items.append(cur_data) + cur_data = None + + # Special labels we handle separately (not local) + _SPECIAL_LABELS = { + "_JIT_ENTRY", "_JIT_CONTINUE", ".L_JIT_CONTINUE", + } + + # First pass: discover all label definitions and branch/call targets + # to determine which ones are local (internal to this stencil) + all_label_defs: set[str] = set() + all_branch_targets: set[str] = set() + for raw in assembly.splitlines(): + if m := _RE_ANY_LABEL_DEF.match(raw): + label = m.group(1) + if label not in _SPECIAL_LABELS: + all_label_defs.add(label) + # Find branch targets (labels only, not register names) + m_br = re.match(r"^\s*(j\w+|call)\s+([\w.]+)\s*(?:#.*)?$", raw) + if m_br: + target = m_br.group(2) + if target not in _SPECIAL_LABELS and not target.startswith("_JIT_"): + all_branch_targets.add(target) + + # Local labels are those with actual definitions in this stencil. + # Only include branch targets that have matching definitions — + # targets without definitions are register-indirect branches + # (e.g., "call rax", "jmp r11") and not real label references. 
+ local_labels = all_label_defs + for label in sorted(local_labels): + _local(label) + + # Second pass: emit DynASM lines + _FRAME_ANCHOR_MARKER = " // __JIT_FRAME_ANCHOR__" + cur_section = "code" # track current section for data entry restoration + for raw in assembly.splitlines(): + if _RE_BLANK.match(raw): + continue + if _RE_SKIP.match(raw): + continue + + # ── rodata collection ── + if _RE_RODATA_SECTION.match(raw): + in_rodata = True + continue + + if in_rodata: + if _RE_TEXT_SECTION.match(raw) or _RE_COLD_SECTION.match(raw): + _flush_data() + in_rodata = False + # fall through to handle section switch + elif m := _RE_DATA_LABEL.match(raw): + _flush_data() + cur_data = DataItem(label=m.group(1)) + continue + elif m := _RE_ASCIZ.match(raw): + if cur_data is not None: + s = ( + m.group(1) + .encode("raw_unicode_escape") + .decode("unicode_escape") + ) + cur_data.data.extend(s.encode("utf-8")) + cur_data.data.append(0) + continue + elif m := _RE_BYTE_DATA.match(raw): + if cur_data is not None: + kind = m.group(1) + for v in m.group(2).split(","): + # Strip inline comments (# ...) 
+ v = re.sub(r"\s*#.*$", "", v).strip() + if not v: + continue + n = int(v, 0) + sz = {"byte": 1, "short": 2, "long": 4, "quad": 8}[kind] + cur_data.data.extend( + n.to_bytes(sz, "little", signed=(n < 0)) + ) + continue + else: + continue # skip unknown rodata lines + + # ── section switches ── + if _RE_COLD_SECTION.match(raw): + cur_section = "cold" + lines.append("") + lines.append(" // ---- cold path ----") + lines.append(" |.cold") + continue + if _RE_TEXT_SECTION.match(raw): + cur_section = "code" + lines.append(" |.code") + continue + + # ── special labels ── + if _RE_ENTRY.match(raw): + lines.append(" |=>uop_label:") + continue + if _RE_CONTINUE_LABEL.match(raw): + continue # handled by caller via continue_label + + # ── alignment ── + if m := _RE_ALIGN.match(raw): + lines.append(f" | .align {1 << int(m.group(1))}") + continue + + # ── local label definitions (any label we discovered in pass 1) ── + if m := _RE_ANY_LABEL_DEF.match(raw): + label = m.group(1) + if label in local_map: + idx = local_map[label] + lines.append(f" |=>L({idx}):") + continue + # Skip other label defs we don't recognize + continue + + # ── LLVM JIT fold pass markers ── + # Inline asm markers injected by jit_fold_pass.so: + # nop # @@JIT_MOV_IMM %rax, @@ + # The register name has a % prefix from AT&T syntax in inline asm. + if m := _RE_JIT_MARKER.match(raw): + kind, reg, expr = m.groups() + # Strip AT&T % prefix from register name. 
+ if reg is not None: + reg = reg.lstrip("%") + reg_name = _REG_IDX_NAME.get(reg.lower()) if reg else None + if kind == "JIT_MOV_IMM": + if reg_name is not None: + lines.append( + f" emit_mov_imm_preserve_flags(Dst, {reg_name}, {expr});" + ) + else: + lines.append(f" | mov64 {reg}, {expr}") + elif kind == "JIT_TEST": + if reg_name is not None: + lines.append( + f" emit_test_reg_imm(Dst, {reg_name}, JREG_RAX, {expr});" + ) + elif kind == "JIT_CMP": + if reg_name is not None: + lines.append( + f" emit_cmp_reg_imm(Dst, {reg_name}, JREG_RAX, {expr});" + ) + elif kind == "JIT_FRAME_ANCHOR": + if reg is not None and lines: + anchor_re = re.compile( + rf"^\s*\|\s*(?:" + rf"lea\s+{re.escape(reg)},\s*\[(?:rbp|rsp)(?:\s*[+-]\s*\d+)?\]" + rf"|mov\s+{re.escape(reg)},\s*(?:rbp|rsp)" + rf")\s*$" + ) + if anchor_re.match(lines[-1]): + lines.pop() + lines.append(_FRAME_ANCHOR_MARKER) + _stat("LLVM_fold_marker") + continue + + # ── movabs REG, offset SYMBOL[+N] ── + if m := _RE_MOVABS.match(raw): + reg, sym, off_s = m.groups() + off = int(off_s) if off_s else 0 + if sym.startswith("_JIT_"): + expr = _jit_expr(sym, off) + reg_name = _REG_IDX_NAME.get(reg.lower()) + if reg_name is not None: + # Use emit_mov_imm which picks optimal encoding at + # JIT compile time (xor for 0, mov32 for ≤UINT32_MAX, + # mov64 otherwise). + lines.append(f" emit_mov_imm(Dst, {reg_name}, {expr});") + else: + lines.append(f" | mov64 {reg}, {expr}") + elif sym.startswith(".L"): + safe = sym.replace(".", "_") + reg_name = _REG_IDX_NAME.get(reg.lower()) + if reg_name is not None: + lines.append( + f" emit_mov_imm(Dst, {reg_name}, (uintptr_t)jit_data_{opname}_{safe});" + ) + else: + lines.append( + f" | mov64 {reg}, (uintptr_t)jit_data_{opname}_{safe}" + ) + else: + # External symbol: load address via emit_mov_imm which + # picks the optimal encoding at JIT time (xor/mov32/lea/mov64). 
+ expr = ( + f"((uintptr_t)&{sym} + {off})" + if off + else f"(uintptr_t)&{sym}" + ) + reg_name = _REG_IDX_NAME.get(reg.lower()) + if reg_name is not None: + lines.append(f" emit_mov_imm(Dst, {reg_name}, {expr});") + else: + lines.append(f" | mov64 {reg}, {expr}") + continue + + # ── movabs REG, IMM (plain integer) ── + if m := _RE_MOVABS_IMM.match(raw): + reg, imm = m.groups() + # Use unsigned form for mov64 + val = int(imm) + if val < 0: + val = val & 0xFFFFFFFFFFFFFFFF + lines.append(f" | mov64 {reg}, {val}ULL") + continue + + # ── call/jmp via GOTPCREL ── + if m := _RE_GOTPCREL_CALL.match(raw): + instr, sym = m.groups() + if sym.startswith("_JIT_"): + # _JIT_* symbols are runtime values — emit optimal mov then + # indirect call/jmp through rax. + expr = _jit_expr(sym) + lines.append(f" emit_mov_imm(Dst, JREG_RAX, {expr});") + lines.append(f" | {instr} rax") + else: + # External function: emit a direct relative call using + # DynASM's &addr syntax (5-byte E8 rel32). Like Pyston's + # emit_call_ext_func. Falls back to mov64+call for + # targets beyond ±2GB. + lines.append(f" emit_call_ext(Dst, (void *)&{sym});") + if instr == "jmp": + # Tail call: after the callee returns, we need to exit + # the trace. Emit a ret which the epilogue rewriting + # will convert to jmp =>cleanup_ret_label. + lines.append(" | ret") + continue + + # ── generic instruction with GOTPCREL memory operand ── + if m := _RE_GOTPCREL_MEM.match(raw): + prefix, before, size, sym, after = m.groups() + instr = prefix.strip().split()[0] + dest = before.strip().rstrip(",").strip() + + if instr == "mov" and not after.strip(): + # mov REG, qword ptr [rip + SYM@GOTPCREL] + # Load the symbol address directly via emit_mov_imm. 
+ if sym.startswith("_JIT_"): + expr = _jit_expr(sym) + else: + expr = f"(uintptr_t)&{sym}" + reg_name = _ANY_REG_TO_NAME.get(dest.lower()) + if reg_name is not None: + lines.append(f" emit_mov_imm(Dst, {reg_name}, {expr});") + else: + lines.append(f" | mov64 {dest}, {expr}") + continue + + if instr == "movzx" and not after.strip(): + # movzx REG, word/byte ptr [rip + SYM@GOTPCREL] + # Load the symbol value with appropriate truncation. + if sym.startswith("_JIT_"): + expr = _jit_expr(sym) + else: + expr = f"(uintptr_t)&{sym}" + _SIZE_CAST = {"word": "uint16_t", "byte": "uint8_t"} + cast = _SIZE_CAST.get(size, None) + if cast: + expr = f"({cast})({expr})" + reg_name = _ANY_REG_TO_NAME.get(dest.lower()) + if reg_name is not None: + lines.append(f" emit_mov_imm(Dst, {reg_name}, {expr});") + else: + lines.append(f" | mov64 {dest}, {expr}") + continue + + if instr == "call": + # External function: direct relative call via emit_call_ext. + if not sym.startswith("_JIT_"): + lines.append(f" emit_call_ext(Dst, (void *)&{sym});") + else: + expr = _jit_expr(sym) + lines.append(f" emit_mov_imm(Dst, JREG_RAX, {expr});") + lines.append(f" | call rax") + continue + + # Other instructions (cmp, test, etc.) with GOTPCREL memory. + # Emit the symbol address into a per-instruction data section + # entry, then reference it with a RIP-relative memory operand. 
+ if sym.startswith("_JIT_"): + expr = _jit_expr(sym) + else: + expr = f"(uintptr_t)&{sym}" + data_label = counter + counter += 1 + lines.append(f" |.data") + lines.append(f" |=>L({data_label}):") + lines.append(f" | .qword {expr}") + lines.append(f" |.{cur_section}") + new_line = f"{prefix}{before}qword [=>L({data_label})]{after}" + new_line = _fix_syntax(new_line.strip()) + lines.append(f" | {new_line}") + continue + + # ── JIT branch targets ── + if m := _RE_JIT_BRANCH.match(raw): + instr, target = m.groups() + field = ( + "jump_target" + if target == "_JIT_JUMP_TARGET" + else "error_target" + ) + lines.append(f" | {instr} =>instruction->{field}") + continue + + # ── JIT continue ── + if m := _RE_JIT_CONTINUE.match(raw): + lines.append(f" | {m.group(1)} =>continue_label") + continue + + # ── local branches / calls ── + # Match jmp/jcc/call to a label we discovered in pass 1 + m_br = re.match(r"^\s*(j\w+|call)\s+([\w.]+)\s*(?:#.*)?$", raw) + if m_br: + instr, label = m_br.groups() + if label in local_map: + idx = local_map[label] + if instr == "call": + lines.append(f" | call =>L({idx})") + else: + lines.append(f" | {instr} =>L({idx})") + continue + + # ── default: plain instruction ── + stripped = raw.strip() + if not stripped or stripped.startswith(".section"): + continue + + fixed = _fix_syntax(stripped) + if "@" in fixed: + raise ValueError( + f"Unhandled @ symbol in stencil {opname}: {raw!r}" + ) + lines.append(f" | {fixed}") + + _flush_data() + + # Strip the frame-anchor marker emitted by template.c. + # With the simplified approach (no ForceFrameCall), the marker stands alone. + stripped_lines: list[str] = [] + for line in lines: + if line == _FRAME_ANCHOR_MARKER: + continue + stripped_lines.append(line) + lines = stripped_lines + + # Find the FIRST |.cold (the hot-cold boundary). There may be multiple + # |.cold directives (e.g. when the optimizer appends cold blocks after + # LLVM's own cold section), but the first one marks the end of hot code. 
+ hot_end = len(lines) + for i in range(len(lines)): + stripped = lines[i].strip() + if stripped == "|.cold": + hot_end = i + break + + # Strip the stencil's outer function prologue/epilogue and record its + # requested frame size. The shared JIT shim recreates one canonical + # frame before calling into the trace, so individual stencils no longer + # need their function entry/exit stack manipulation. + frame_size = 0 + _RE_PUSH_RBP = re.compile(r"^\s*\|\s*push rbp\s*$") + _RE_MOV_RBP_RSP = re.compile(r"^\s*\|\s*mov rbp, rsp\s*$") + _RE_SUB_RSP = re.compile(r"^\s*\|\s*sub rsp,\s*(\d+)\s*$") + _RE_ADD_RSP = re.compile(r"^\s*\|\s*add rsp,\s*(\d+)\s*$") + _RE_LEA_RSP = re.compile(r"^\s*\|\s*lea rsp, \[rsp \+ (\d+)\]\s*$") + _RE_POP_RBP = re.compile(r"^\s*\|\s*pop rbp\s*$") + + # Detect entry prologue: |=>uop_label: then | push rbp, an optional + # | mov rbp, rsp, and finally | sub rsp, N. + prologue_push_idx = -1 + prologue_mov_idx = -1 + prologue_sub_idx = -1 + for i in range(len(lines)): + if lines[i].strip() == "|=>uop_label:": + if i + 2 < len(lines) and _RE_PUSH_RBP.match(lines[i + 1]): + prologue_push_idx = i + 1 + sub_idx = i + 2 + if sub_idx < len(lines) and _RE_MOV_RBP_RSP.match( + lines[sub_idx] + ): + prologue_mov_idx = sub_idx + sub_idx += 1 + if sub_idx < len(lines) and ( + m_sub := _RE_SUB_RSP.match(lines[sub_idx]) + ): + frame_size = int(m_sub.group(1)) + prologue_sub_idx = sub_idx + break + + total_push_rbp = sum(1 for line in lines if _RE_PUSH_RBP.match(line)) + + if frame_size > 0 and prologue_push_idx >= 0 and prologue_sub_idx >= 0: + lines[prologue_push_idx] = "" + if prologue_mov_idx >= 0: + lines[prologue_mov_idx] = "" + lines[prologue_sub_idx] = "" + + # When the entry prologue is the only push rbp in the stencil, every + # pop rbp belongs to the outer function epilogue even if the compiler + # hoists the shared add rsp, N before a branch. Strip all such outer + # unwinds and let the shim own the frame instead. 
+ if total_push_rbp == 1: + for i in range(len(lines)): + if _RE_POP_RBP.match(lines[i]): + lines[i] = "" + continue + m_add = _RE_ADD_RSP.match(lines[i]) or _RE_LEA_RSP.match( + lines[i] + ) + if m_add and int(m_add.group(1)) == frame_size: + lines[i] = "" + else: + # Fall back to the conservative adjacent add/lea + pop pattern + # when the stencil contains other push rbp saves of its own. + stripped_any_epilogue = False + for i in range(len(lines) - 1): + if not _RE_POP_RBP.match(lines[i + 1]): + continue + m_add = _RE_ADD_RSP.match(lines[i]) or _RE_LEA_RSP.match( + lines[i] + ) + if m_add and int(m_add.group(1)) == frame_size: + stripped_any_epilogue = True + lines[i] = "" + lines[i + 1] = "" + if not stripped_any_epilogue: + frame_size = 0 + else: + frame_size = 0 + + # Inline trace exit sequences: tear down the shared trace frame + # (mov rsp, rbp; pop rbp) and then exit via the original instruction. + # The shim has its own prologue/epilogue and must not be rewritten. + if not is_shim: + _TRACE_EXIT_REWRITES = { + "| ret": [ + " | mov rsp, rbp", + " | pop rbp", + " | ret", + ], + "| jmp rax": [ + " | mov rsp, rbp", + " | pop rbp", + " | jmp rax", + ], + "| jmp rcx": [ + " | mov rsp, rbp", + " | pop rbp", + " | jmp rcx", + ], + "| jmp qword [rax + 48]": [ + " | mov rsp, rbp", + " | pop rbp", + " | jmp qword [rax + 48]", + ], + } + has_internal_push_rbp = total_push_rbp > 1 and frame_size > 0 + # In stencils with internal push/pop rbp pairs (subroutines like + # _Py_Dealloc called via `call =>L(N)`), exit instructions reachable + # after a `pop rbp` within the same basic block are internal subroutine + # exits, not trace exits. Track per-block state to skip rewriting. 
+ _RE_LABEL_LINE = re.compile(r"^\s*\|=>") + in_internal_exit_block = False + new_lines = [] + for line in lines: + stripped = line.strip() + if has_internal_push_rbp: + if _RE_LABEL_LINE.match(stripped): + in_internal_exit_block = False + if _RE_POP_RBP.match(line): + in_internal_exit_block = True + if in_internal_exit_block and stripped in _TRACE_EXIT_REWRITES: + new_lines.append(line) + continue + replacement = _TRACE_EXIT_REWRITES.get(stripped) + if replacement is not None: + new_lines.extend(replacement) + else: + new_lines.append(line) + lines = new_lines + + # Apply peephole optimizations (fuse emit_mov_imm + movzx/or) + lines = _peephole_optimize(lines) + + # Eliminate trampoline jumps: when a label contains only a single + # `jmp` to another target, rewrite all branches to that label to + # jump directly to the final target. + _RE_LABEL_DEF = re.compile(r"^\s*\|=>L\((\d+)\):\s*$") + _RE_TRAMPOLINE_JMP = re.compile( + r"^\s*\|\s*jmp\s+=>(instruction->(?:jump_target|error_target)" + r"|L\(\d+\)|continue_label)\s*$" + ) + _RE_BRANCH_TO_LABEL = re.compile( + r"^(\s*\|\s*j\w+)\s+=>(L\(\d+\))\s*$" + ) + # Pass 1: find trampoline labels (label followed immediately by lone jmp) + trampoline_targets: dict[str, str] = {} # "L(N)" → target + for i in range(len(lines) - 1): + m_label = _RE_LABEL_DEF.match(lines[i]) + if not m_label: + continue + # Skip blank/comment/section lines to find the next real instruction + j = i + 1 + while j < len(lines) and ( + not lines[j].strip() + or lines[j].strip().startswith("//") + ): + j += 1 + if j >= len(lines): + continue + m_jmp = _RE_TRAMPOLINE_JMP.match(lines[j]) + if not m_jmp: + continue + label_name = f"L({m_label.group(1)})" + trampoline_targets[label_name] = m_jmp.group(1) + + # Resolve chains: L(8) → L(1) → instruction->jump_target + changed = True + while changed: + changed = False + for src, dst in list(trampoline_targets.items()): + if dst in trampoline_targets: + trampoline_targets[src] = trampoline_targets[dst] + 
changed = True + + if trampoline_targets: + # Collect line indices of trampoline labels and their jmp targets + dead_lines: set[int] = set() + for i in range(len(lines)): + m_label = _RE_LABEL_DEF.match(lines[i]) + if not m_label: + continue + label_name = f"L({m_label.group(1)})" + if label_name not in trampoline_targets: + continue + dead_lines.add(i) + # Find and mark the jmp line (skipping blanks/comments) + for j in range(i + 1, len(lines)): + if _RE_TRAMPOLINE_JMP.match(lines[j]): + dead_lines.add(j) + break + if lines[j].strip() and not lines[j].strip().startswith("//"): + break + + new_lines = [] + for i, line in enumerate(lines): + if i in dead_lines: + continue + # Rewrite branches to trampolines → direct branches + m_branch = _RE_BRANCH_TO_LABEL.match(line) + if m_branch: + target_label = m_branch.group(2) + if target_label in trampoline_targets: + branch_instr = m_branch.group(1) + new_lines.append( + f"{branch_instr} =>{trampoline_targets[target_label]}" + ) + continue + new_lines.append(line) + lines = new_lines + + # Remove blank lines and trailing whitespace from the output (clang emits + # blank lines between basic blocks; some peephole patterns add trailing \n). + lines = [l.rstrip() for l in lines if l.strip()] + + return ConvertedStencil( + opname=opname, + lines=lines, + num_internal_labels=counter, + data_items=data_items, + frame_size=frame_size, + ) diff --git a/Tools/jit/_asm_to_dasc_amd64.py b/Tools/jit/_asm_to_dasc_amd64.py new file mode 100644 index 00000000000000..8062c99457abee --- /dev/null +++ b/Tools/jit/_asm_to_dasc_amd64.py @@ -0,0 +1,1464 @@ +"""x86-64 specific peephole optimizations for the DynASM JIT backend. 

This module contains all x86-64 architecture-specific code for the peephole
optimizer:

- JIT calling convention (register roles, frame offsets)
- Instruction effect analysis (which registers/flags each instruction touches)
- Peephole optimization patterns (23 patterns, organized by category)
- Pattern registry

Architecture-generic infrastructure (types, parsing, pass management, dead
label elimination) lives in ``_asm_to_dasc.py``. A future ARM64 backend would
create ``_asm_to_dasc_aarch64.py`` following the same structure.
"""

from __future__ import annotations

import typing

from _asm_to_dasc import (
    # Operand types
    Reg, Mem, Imm, Op,
    # Line types
    Asm, CCall, CCallKind, Label, Section, FuncDef, Blank, CCode, Line,
    # Infrastructure
    _Match, _FoldCtx, _PeepholeState, _PeepholePass,
    _PeepholeProgram, _LineEffect,
    # Functions
    parse_line, parse_lines,
    is_call, is_branch, is_jump, branch_target, line_raw, fmt_op,
    _stat,
    # Data
    _ANY_REG_TO_IDX, _ANY_REG_TO_NAME, _IDX_TO_ALL_NAMES,
    _REG_IDX_NAME, _REG64_TO_REG32, _BRANCH_MNEMONICS, _RE_EMIT_MOV_IMM,
    # Helpers
    _operand_regs, _mem_uses_reg, _reg_bits,
)

# ── JIT calling convention: register roles ──────────────────────────────
# The preserve_none calling convention assigns these fixed roles:
#   r13 = frame pointer (_PyInterpreterFrame *)
#   r14 = cached stack pointer (frame->stackpointer)
#   r15 = thread state (PyThreadState *)
#   r12 = current executor (_PyExecutorObject *)
# These constants are used by the peephole optimizer to recognize
# store/reload patterns without hardcoding register names.
REG_FRAME = "r13"      # _PyInterpreterFrame *frame
REG_STACK_PTR = "r14"  # _PyStackRef *stack_pointer (cached)
REG_TSTATE = "r15"     # PyThreadState *tstate
REG_EXECUTOR = "r12"   # _PyExecutorObject *executor

# Frame struct field offsets (from _PyInterpreterFrame in
# Include/internal/pycore_interpframe_structs.h).
# Used by the peephole optimizer to match store/reload patterns.
FRAME_IP_OFFSET = 56            # offsetof(_PyInterpreterFrame, instr_ptr)
FRAME_STACKPOINTER_OFFSET = 64  # offsetof(_PyInterpreterFrame, stackpointer)


class _C:
    """Constants usable as value patterns in match/case statements.

    In Python structural pattern matching, only dotted names (like ``_C.SP``)
    are treated as value patterns that compare against the constant. Simple
    names (like ``REG_STACK_PTR``) would be treated as capture patterns.
    """
    FRAME = REG_FRAME                        # "r13"
    SP = REG_STACK_PTR                       # "r14"
    TSTATE = REG_TSTATE                      # "r15"
    EXECUTOR = REG_EXECUTOR                  # "r12"
    FRAME_SP_OFS = FRAME_STACKPOINTER_OFFSET # 64
    FRAME_IP_OFS = FRAME_IP_OFFSET           # 56

# Derived patterns used by store/reload elimination.
# Store:  mov qword [r13 + 64], r14   (frame->stackpointer = stack_pointer)
# Reload: mov r14, qword [r13 + 64]   (stack_pointer = frame->stackpointer)
_SP_STORE = f"| mov qword [{REG_FRAME} + {FRAME_STACKPOINTER_OFFSET}], {REG_STACK_PTR}"
_SP_RELOAD = f"| mov {REG_STACK_PTR}, qword [{REG_FRAME} + {FRAME_STACKPOINTER_OFFSET}]"
# NOTE(review): the leading whitespace inside this literal may have been wider
# in the original source (it was mangled in transit) — confirm against VCS.
_SP_RELOAD_LINE = f" {_SP_RELOAD}\n"

# ── JREG name ↔ index mapping (used by emit_mov_imm patterns) ─────────

_JREG_TO_IDX: dict[str, int] = {
    "JREG_RAX": 0, "JREG_RCX": 1, "JREG_RDX": 2, "JREG_RBX": 3,
    "JREG_RSP": 4, "JREG_RBP": 5, "JREG_RSI": 6, "JREG_RDI": 7,
    "JREG_R8": 8, "JREG_R9": 9, "JREG_R10": 10, "JREG_R11": 11,
    "JREG_R12": 12, "JREG_R13": 13, "JREG_R14": 14, "JREG_R15": 15,
}
_IDX_TO_JREG: dict[int, str] = {v: k for k, v in _JREG_TO_IDX.items()}

# ── SysV ABI register classification ──────────────────────────────────
# Used by _is_dead_before_any_call to determine how opaque calls
# (emit_call_ext / | call) interact with each register.

# Callee-saved registers: preserved across function calls.
_CALLEE_SAVED_REGS = frozenset({3, 5, 12, 13, 14, 15})  # rbx rbp r12 r13 r14 r15

# SysV integer argument registers: may be read by function calls.
_SYSV_ARGUMENT_REGS = frozenset({7, 6, 2, 1, 8, 9})  # rdi rsi rdx rcx r8 r9

# Caller-saved registers: clobbered by function calls (SysV ABI).
_CALLER_SAVED_REGS = frozenset({0, 1, 2, 6, 7, 8, 9, 10, 11})
# rax rcx rdx rsi rdi r8 r9 r10 r11


def _parse_jreg(token: str) -> tuple[int, str]:
    """Parse a JREG_* name or integer → (index, name)."""
    if token in _JREG_TO_IDX:
        return _JREG_TO_IDX[token], token
    # Not a symbolic name: treat the token as a decimal register index.
    idx = int(token)
    return idx, _IDX_TO_JREG.get(idx, str(idx))


def _jreg_arg_index(token: str) -> int | None:
    """Best-effort parse of a helper argument naming a JIT register.

    Returns the register index, or None when the argument is not a
    JREG_* name or integer literal (e.g. an arbitrary C expression).
    """
    token = token.strip()
    if token in _JREG_TO_IDX:
        return _JREG_TO_IDX[token]
    try:
        return int(token, 0)
    except ValueError:
        return None


def _reg_write_sets(reg: Reg) -> tuple[frozenset[int], frozenset[int]]:
    """Return (full_writes, partial_writes) for a register destination.

    A 32-bit (or wider) destination counts as a full write; narrower
    destinations only partially overwrite the 64-bit register.
    """
    if reg.idx is None:
        return frozenset(), frozenset()
    if reg.bits >= 32:
        return frozenset({reg.idx}), frozenset()
    return frozenset(), frozenset({reg.idx})


def _line_effect(line: Line) -> _LineEffect:
    """Summarize register and flags effects for one parsed line."""
    match line:
        case Blank() | Label() | Section() | FuncDef():
            return _LineEffect()
        case CCode():
            return _LineEffect()
        case CCall(kind=CCallKind.MOV_IMM, argv=argv):
            idx = _jreg_arg_index(argv[0]) if argv else None
            return _LineEffect(
                full_writes=frozenset({idx})
                if idx is not None
                else frozenset()
            )
        case CCall(kind=CCallKind.CALL_EXT):
            # Model the emitted call as clobbering all caller-saved regs
            # per SysV ABI. We intentionally do NOT model argument registers
            # as reads: the peephole patterns never eliminate register writes
            # that feed call arguments (they only fold emit_mov_imm into
            # memory addressing or ALU patterns), so the risk is negligible.
            return _LineEffect(
                full_writes=_CALLER_SAVED_REGS,
                writes_flags=True,
            )
        case CCall(kind=CCallKind.CMP_REG_IMM, argv=argv):
            reads = frozenset({_jreg_arg_index(argv[0])}) if len(argv) >= 1 else frozenset()
            scratch = _jreg_arg_index(argv[1]) if len(argv) >= 2 else None
            return _LineEffect(
                reads=frozenset(idx for idx in reads if idx is not None),
                full_writes=frozenset({scratch}) if scratch is not None else frozenset(),
                writes_flags=True,
            )
        case CCall(kind=CCallKind.CMP_MEM64_IMM, argv=argv):
            base = _jreg_arg_index(argv[0]) if len(argv) >= 1 else None
            scratch = _jreg_arg_index(argv[2]) if len(argv) >= 3 else None
            return _LineEffect(
                reads=frozenset({base}) if base is not None else frozenset(),
                full_writes=frozenset({scratch}) if scratch is not None else frozenset(),
                writes_flags=True,
            )
        case CCall(kind=CCallKind.ALU_REG_IMM, helper=helper, argv=argv):
            reg_idx = _jreg_arg_index(argv[0]) if len(argv) >= 1 else None
            scratch = _jreg_arg_index(argv[1]) if len(argv) >= 2 else None
            reads = frozenset({reg_idx}) if reg_idx is not None else frozenset()
            full_writes = set()
            if scratch is not None:
                full_writes.add(scratch)
            # These helpers write their result back into the first register;
            # cmp/test-style helpers only set flags.
            if helper in {"emit_and_reg_imm", "emit_or_reg_imm", "emit_xor_reg_imm",
                          "emit_add_reg_imm", "emit_sub_reg_imm"}:
                if reg_idx is not None:
                    full_writes.add(reg_idx)
            return _LineEffect(
                reads=reads,
                full_writes=frozenset(full_writes),
                writes_flags=True,
            )
        case Asm(mnemonic=mnemonic, dst=dst, src=src):
            reads = set(_operand_regs(src))
            full_writes: set[int] = set()
            partial_writes: set[int] = set()
            uses_flags = False
            writes_flags = False

            match mnemonic:
                case "jmp" | "ret":
                    return _LineEffect()
                case _ if is_branch(line):
                    return _LineEffect(uses_flags=True)
                case "call":
                    # Reads the target operand (address computation),
                    # clobbers all caller-saved registers per SysV ABI.
                    reads |= set(_operand_regs(dst))
                    return _LineEffect(
                        reads=frozenset(reads),
                        full_writes=_CALLER_SAVED_REGS,
                        writes_flags=True,
                    )
                case _ if mnemonic.startswith("set"):
                    uses_flags = True
                    if isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                case _ if mnemonic.startswith("cmov"):
                    # cmov conditionally writes, so the old value of dst is
                    # also a read (it may survive).
                    uses_flags = True
                    reads |= set(_operand_regs(dst))
                    reads |= set(_operand_regs(src))
                    if isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                case "mov":
                    if isinstance(dst, Mem):
                        reads |= set(_operand_regs(dst))
                    elif isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                case "movzx" | "movsxd" | "lea":
                    if isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                    reads |= set(_operand_regs(src))
                case "cmp" | "test" | "bt" | "ucomisd":
                    reads |= set(_operand_regs(dst))
                    reads |= set(_operand_regs(src))
                    writes_flags = True
                case "pop":
                    # pop writes to the destination register (from stack),
                    # it does NOT read the register.
                    if isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                case "push":
                    # push reads the source register to put it on the stack,
                    # but does not write to any GP register.
                    reads |= set(_operand_regs(dst))
                case "neg" | "not" | "inc" | "dec":
                    reads |= set(_operand_regs(dst))
                    if isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                    writes_flags = True
                case "xor" if (
                    isinstance(dst, Reg)
                    and isinstance(src, Reg)
                    and (dst.idx is not None and dst.idx == src.idx)
                ):
                    # xor reg, reg is a zeroing idiom — no read dependency.
                    reads.discard(dst.idx)
                    full, partial = _reg_write_sets(dst)
                    full_writes |= full
                    partial_writes |= partial
                    writes_flags = True
                case _:
                    # Conservative default: read both operands, write a
                    # register destination, and flag the common ALU ops.
                    reads |= set(_operand_regs(dst))
                    reads |= set(_operand_regs(src))
                    if isinstance(dst, Reg):
                        full, partial = _reg_write_sets(dst)
                        full_writes |= full
                        partial_writes |= partial
                    writes_flags = mnemonic in {
                        "add", "and", "or", "sub", "xor", "shl", "shr",
                        "sar", "sal", "rol", "ror", "rcl", "rcr",
                    }

            return _LineEffect(
                reads=frozenset(reads),
                full_writes=frozenset(full_writes),
                partial_writes=frozenset(partial_writes),
                uses_flags=uses_flags,
                writes_flags=writes_flags,
            )
    return _LineEffect()


def uses_reg(line: Line, reg_idx: int) -> bool:
    """Does this line reference the given register (by index)?

    Uses structured operand/helper effects where possible, then falls back to
    the raw text for unclassified C lines.
    """
    effect = _line_effect(line)
    if (
        reg_idx in effect.reads
        or reg_idx in effect.full_writes
        or reg_idx in effect.partial_writes
    ):
        return True
    # Textual fallback: any alias of the register (rax/eax/ax/al...) in raw.
    names = _IDX_TO_ALL_NAMES.get(reg_idx, set())
    raw = line.raw
    return any(name in raw for name in names)


def is_store_sp(line: Line) -> bool:
    """Matches: ``| mov qword [r13 + 64], r14`` (store stack pointer)."""
    return line.raw.strip() == _SP_STORE


def is_reload_sp(line: Line) -> bool:
    """Matches: ``| mov r14, qword [r13 + 64]`` (reload stack pointer)."""
    return line.raw.strip() == _SP_RELOAD


def _preserve_flags_mov_imm(line: str) -> str:
    """Rewrite emit_mov_imm(...)
    to emit_mov_imm_preserve_flags(...)."""
    return line.replace("emit_mov_imm(", "emit_mov_imm_preserve_flags(", 1)


def _is_flag_writer(line: Line) -> bool:
    """Does this instruction define flags for a later consumer?"""
    return _line_effect(line).writes_flags


def _is_flag_consumer(line: Line) -> bool:
    """Does this instruction consume previously computed flags?"""
    return _line_effect(line).uses_flags


def _parse_emit_mov_imm_call(line: str) -> tuple[int, str, str] | None:
    """Parse an emit_mov_imm* helper call into (reg_idx, reg_name, expr)."""
    match parse_line(line):
        case CCall(kind=CCallKind.MOV_IMM, argv=(reg, expr, *_)):
            reg_idx, reg_name = _parse_jreg(reg)
            return reg_idx, reg_name, expr
    return None


# Map jcc mnemonic → C comparison operator for unsigned cmp folding.
# Given "cmp REG, IMM; jcc label", the branch is taken when the comparison
# "REG <op> IMM" holds for the operator associated with the jcc.
# NOTE(review): the operator in this sentence and the map itself appear to
# have been lost in this revision — confirm against the original source.


def _reg_dead_after(
    program_or_lines: _PeepholeProgram | list[str],
    start: int,
    reg_idx: int,
) -> bool:
    """Control-flow-aware deadness query backed by structured effects."""
    if isinstance(program_or_lines, _PeepholeProgram):
        program = program_or_lines
    else:
        program = _PeepholeProgram.from_lines(program_or_lines)
    return program.reg_dead_after(start, reg_idx)


def _is_dead_before_any_call(
    program: _PeepholeProgram,
    start: int,
    reg_idx: int,
) -> bool:
    """Check that *reg_idx* is dead AND safe to eliminate.

    This performs a CFG-aware scan from *start* verifying that on EVERY
    reachable path, *reg_idx* is fully overwritten before:
      1. an opaque call (``emit_call_ext`` / ``| call``) that might read
         it — only relevant for SysV argument registers (rdi, rsi, rdx,
         rcx, r8, r9), OR
      2. the end of the function or a ``.cold`` section boundary.

    Callee-saved registers (rbx, rbp, r12–r15) are preserved across calls
    by the SysV ABI, so calls are transparent for them. Caller-saved
    non-argument registers (rax, r10, r11) are clobbered by calls, which
    counts as a full write.

    Condition (2) is needed because in the JIT stencil system, registers
    that are live at the end of a stencil function are inter-stencil
    outputs consumed by the next stencil. DynASM ``.cold`` sections are
    placed in a separate memory region, so linear fallthrough across a
    section switch does not exist at runtime.
    """
    if not program.reg_dead_after(start, reg_idx):
        return False

    # Callee-saved registers are preserved across calls.
    callee_saved = reg_idx in _CALLEE_SAVED_REGS
    # Argument registers may be read by calls as function arguments.
    is_argument_reg = reg_idx in _SYSV_ARGUMENT_REGS
    # Other caller-saved regs (rax, r10, r11) are clobbered by calls.

    func_end = program.function_end_by_line[start]
    start_depth = program.c_depth[start]
    parsed = program.parsed
    visited: set[int] = set()
    worklist = [start + 1]

    while worklist:
        pos = worklist.pop()
        while pos < func_end:
            if pos in visited:
                break
            visited.add(pos)
            pj = parsed[pos]

            is_opaque_call = (
                isinstance(pj, CCall) and pj.kind == CCallKind.CALL_EXT
            ) or (isinstance(pj, Asm) and is_call(pj))

            if is_opaque_call:
                if is_argument_reg:
                    # Call might read this register — unsafe.
                    return False
                if not callee_saved:
                    # Caller-saved non-argument reg (rax, r10, r11):
                    # the call clobbers it → effectively a full write.
                    break
                # Callee-saved: call preserves it, continue scanning.

            # .cold section boundary — treat as function end. The
            # register hasn't been written on the hot path so it may
            # be an inter-stencil output. (.code sections are benign;
            # they just confirm we are already in the hot section.)
            if isinstance(pj, Section) and pj.name == "cold":
                return False

            # Register fully overwritten — this path is safe.
            eff = program.effects[pos]
            if reg_idx in eff.full_writes:
                break
            # Follow the same successor logic as reg_dead_after
            successors = program.successors(pos, start_depth)
            if not successors:
                # No successors and no write — the register reaches function
                # end alive, so it may be an inter-stencil output register.
                return False
            fallthrough = pos + 1
            if len(successors) == 1 and successors[0] == fallthrough:
                pos = fallthrough
                continue
            for succ in successors[1:]:
                if succ < func_end:
                    worklist.append(succ)
            pos = successors[0]
        else:
            # Inner while exited because pos >= func_end without a
            # full_write — the register reaches function end alive.
            return False

    return True


# ── emit_mov_imm chain patterns ─────────────────────────────────────
#
# These patterns all start from a parsed emit_mov_imm line and attempt
# to fold subsequent instructions. They share (expr, consumed, src_idx)
# state and compose in sequence: Pattern 1 can modify expr, then Pattern
# 2 refines it further, etc.


def _try_indexed_mem(ctx: _FoldCtx) -> _Match | None:
    """Pattern 6: Fold indexed memory loads/stores with computed index.

    When emit_mov_imm loads a value that's used as an index in a
    memory access [base + REG*scale + disp], precompute the offset
    at JIT compile time. This eliminates the index register and the
    scaled addressing mode, replacing it with a simple [base + const].

    Example — index into PyObject array:
        emit_mov_imm(Dst, JREG_RCX, instruction->oparg);
        | mov rax, qword [rbx + rcx*8 + 48]
    →
        | mov rax, qword [rbx + ((int)(instruction->oparg) * 8 + 48)]

    Multiple consecutive accesses using the same index are all folded:
        emit_mov_imm(Dst, JREG_RCX, instruction->oparg);
        | mov rax, qword [r14 + rcx*8 + 0]
        | mov rdx, qword [r14 + rcx*8 + 8]
    →
        | mov rax, qword [r14 + ((int)(instruction->oparg) * 8 + 0)]
        | mov rdx, qword [r14 + ((int)(instruction->oparg) * 8 + 8)]

    Safety: the index register must be either overwritten by the load
    destination or dead after all folded accesses.
    """
    lines, i, src_idx, expr = ctx.lines, ctx.i, ctx.src_idx, ctx.expr
    parsed = ctx.parsed
    folded: list[str] = []
    scan = i
    while scan < len(lines):
        p = parsed[scan]
        # Match instructions with indexed memory operand containing our reg
        mem_op: Mem | None = None
        is_load = False
        match p:
            # Load: | mov rax, qword [base + idx*scale + disp]
            # Exclude LEA (handled by P7a/P7b).
            case Asm(mnemonic=mn, dst=Reg(), src=Mem(index=idx, scale=sc)) if (
                mn != "lea" and idx and sc and Reg(idx).idx == src_idx
            ):
                mem_op = p.src
                is_load = True
            # Store: | mov qword [base + idx*scale + disp], reg
            case Asm(
                mnemonic="mov", dst=Mem(index=idx, scale=sc), src=Reg()
            ) if idx and sc and Reg(idx).idx == src_idx:
                mem_op = p.dst
                is_load = False
            case _:
                break
        if mem_op is None or mem_op.base is None:
            break
        # Reconstruct with computed offset
        computed = f"(int)({expr}) * {mem_op.scale} + {mem_op.offset}"
        new_mem = f"[{mem_op.base} + ({computed})]"
        if is_load:
            new_line = (
                f" | {p.mnemonic} {fmt_op(p.dst)}, {mem_op.size} {new_mem}"
                if mem_op.size
                else f" | {p.mnemonic} {fmt_op(p.dst)}, {new_mem}"
            )
        else:
            new_line = (
                f" | {p.mnemonic} {mem_op.size} {new_mem}, {fmt_op(p.src)}"
                if mem_op.size
                else f" | {p.mnemonic} {new_mem}, {fmt_op(p.src)}"
            )
        folded.append(new_line)
        scan += 1
    if not folded:
        return None
    # Safety: index reg overwritten by load dest, or dead after
    first = parsed[i]
    dest_idx = (
        first.dst.idx
        if isinstance(first, Asm) and isinstance(first.dst, Reg)
        else None
    )
    if dest_idx == src_idx or _reg_dead_after(ctx.program, scan, src_idx):
        _stat("P6_indexed_mem")
        return _Match(scan - i, folded)
    return None


def _try_two_mov_add(ctx: _FoldCtx) -> _Match | None:
    """Pattern 16: Combine two immediate loads followed by add.

    (The stat key is ``P16_two_mov_add``; pattern 15 is the shift fold.)
    """
    lines, i, parsed, indent, expr = (
        ctx.lines, ctx.i, ctx.parsed, ctx.indent, ctx.expr)
    mov_info = _parse_emit_mov_imm_call(lines[i])
    if mov_info is None or i + 1 >= len(lines):
        return None
    dst_idx, dst_name, rhs_expr = mov_info
    match parsed[i + 1]:
        case Asm(
            mnemonic="add", dst=Reg(name=add_dst), src=Reg(name=add_src)
        ) if Reg(add_dst).idx == dst_idx and Reg(add_src).idx == ctx.src_idx:
            pass
        case _:
            return None
    if not _reg_dead_after(ctx.program, i + 2, ctx.src_idx):
        return None
    bits = _reg_bits(add_dst)
    if bits <= 32:
        # Narrow destinations: wrap to the destination width in C.
        combined = f"(uint{bits}_t)(({rhs_expr}) + ({expr}))"
    else:
        combined = f"({rhs_expr}) + ({expr})"
    _stat("P16_two_mov_add")
    return _Match(
        2,
        [
            f"{indent}emit_mov_imm(Dst, {dst_name}, {combined});",
        ],
    )


def _try_alu_imm(ctx: _FoldCtx) -> _Match | None:
    """Pattern 8: Fold ALU instruction's register operand into immediate.

    When emit_mov_imm loads a value into a register and the next
    instruction uses that register as the second operand of an ALU
    operation (cmp, test, and, or, xor, add, sub), replace the register
    with the immediate value directly. This eliminates the mov entirely.

    Example — compare with runtime constant:
        emit_mov_imm(Dst, JREG_RAX, instruction->operand0);
        | cmp rcx, rax
    → (if value fits in sign-extended imm32)
        | cmp rcx, (int)(instruction->operand0)
    → (if value does NOT fit in imm32, falls back)
        emit_mov_imm(Dst, JREG_RAX, instruction->operand0);
        | cmp rcx, rax

    Example — OR with type tag:
        emit_mov_imm(Dst, JREG_RDX, instruction->operand0);
        | or qword [rbx + 16], rdx
    → (32-bit value fits)
        | or qword [rbx + 16], (int)(instruction->operand0)

    For 64-bit operands, emits a runtime range check (if/else) to use
    the immediate form when possible, falling back to the register form.
    For 32-bit and 16-bit operands, the immediate always fits.

    The source register must be dead after the ALU instruction (since
    we're eliminating the load).
    """
    _ALU_OPS = {"cmp", "test", "and", "or", "xor", "add", "sub"}
    # Only test is safe for commutative swap because it doesn't write to
    # a register. For and/or/xor/add, swapping would change which register
    # receives the result, corrupting program state.
    _COMMUTATIVE_OPS = {"test"}
    lines, i, src_idx, indent, src_name, expr = (
        ctx.lines, ctx.i, ctx.src_idx, ctx.indent, ctx.src_name, ctx.expr)
    cur = ctx.cur
    if not isinstance(cur, Asm) or cur.mnemonic not in _ALU_OPS:
        return None
    alu_op = cur.mnemonic
    alu_reg = None
    dst_op = None
    # Standard order: ALU dst, src — where src is our emit_mov_imm register.
    if isinstance(cur.src, Reg) and Reg(cur.src.name).idx == src_idx:
        alu_reg = cur.src.name
        dst_op = cur.dst
    # Commutative swap: ALU dst, src — where dst is our register.
    elif (
        alu_op in _COMMUTATIVE_OPS
        and isinstance(cur.dst, Reg)
        and Reg(cur.dst.name).idx == src_idx
    ):
        alu_reg = cur.dst.name
        dst_op = cur.src
    else:
        return None
    # Format the first operand text for output
    alu_first = fmt_op(dst_op)
    # Don't fold if first operand is also the same register
    match dst_op:
        case Reg(name=first_reg) if Reg(first_reg).idx == src_idx:
            return None
        case Mem() as mem if _mem_uses_reg(mem, src_idx):
            return None
    if not _reg_dead_after(ctx.program, i + 1, src_idx):
        return None
    _stat("P8_alu_imm_fold")
    bits = Reg(alu_reg).bits
    if bits == 32:
        return _Match(
            1,
            [
                f"{indent}| {alu_op} {alu_first}, (int)({expr})",
            ],
        )
    if bits == 64:
        # For register first operand, use emit_{op}_reg_imm helpers.
        # These emit the shortest encoding: imm32 when it fits, otherwise
        # scratch register + reg-reg form.
        if isinstance(dst_op, Reg):
            first_idx_name = _REG_IDX_NAME.get(dst_op.name.lower())
            if first_idx_name:
                return _Match(
                    1,
                    [
                        f"{indent}emit_{alu_op}_reg_imm(Dst, {first_idx_name}, {src_name}, (uintptr_t)({expr}));",
                    ],
                )
        # Simple qword memory compare: route through a dedicated helper so we
        # do not have to emit a multiline if/else template at each call site.
        if (
            alu_op == "cmp"
            and isinstance(dst_op, Mem)
            and dst_op.size == "qword"
            and dst_op.base is not None
            and dst_op.index is None
        ):
            base_name = Reg(dst_op.base).jreg
            if base_name is not None:
                return _Match(
                    1,
                    [
                        f"{indent}emit_cmp_mem64_imm(Dst, {base_name}, {dst_op.offset}, {src_name}, (uintptr_t)({expr}));",
                    ],
                )
        # Memory first operand: fallback inline runtime range check for other
        # cases that still cannot use a dedicated helper.
        c64 = f"(int64_t)({expr})"
        c32 = f"(int32_t)({expr})"
        return _Match(
            1,
            [
                f"{indent}if ({c64} == {c32}) {{",
                f"{indent}| {alu_op} {alu_first}, (int)({expr})",
                f"{indent}}} else {{",
                f"{indent}emit_mov_imm(Dst, {src_name}, {expr});",
                f"{indent}| {alu_op} {alu_first}, {alu_reg}",
                f"{indent}}}",
            ],
        )
    if bits == 16:
        return _Match(
            1,
            [
                f"{indent}| {alu_op} {alu_first}, (short)({expr})",
            ],
        )
    return None


def _try_store_imm(ctx: _FoldCtx) -> _Match | None:
    """Pattern 12: Fold register store into immediate store to memory.

    When emit_mov_imm loads a value into a register and the next
    instruction stores that register to memory, replace with a direct
    immediate-to-memory store (eliminating the register load entirely).

    Example — store byte:
        emit_mov_imm(Dst, JREG_RAX, instruction->oparg);
        | mov byte [rbx + 42], al
    →
        | mov byte [rbx + 42], (char)(instruction->oparg)

    Example — store qword (needs range check):
        emit_mov_imm(Dst, JREG_RCX, instruction->operand0);
        | mov qword [r14 + 8], rcx
    →
        if ((int64_t)(instruction->operand0) ==
                (int32_t)(instruction->operand0)) {
            | mov qword [r14 + 8], (int)(instruction->operand0)
        } else {
            emit_mov_imm(Dst, JREG_RCX, instruction->operand0);
            | mov qword [r14 + 8], rcx
        }

    For byte/word/dword stores the immediate always fits. For qword
    stores, x86_64 only supports sign-extended imm32, so we emit a
    runtime range check with a fallback.

    The source register must be dead after the store.
    """
    lines, i, src_idx, indent, src_name, expr = (
        ctx.lines, ctx.i, ctx.src_idx, ctx.indent, ctx.src_name, ctx.expr)
    match ctx.cur:
        case Asm(
            mnemonic="mov",
            dst=Mem(size=size, expr=mem_expr),
            src=Reg(name=reg),
        ) if (
            size in ("qword", "dword", "word", "byte")
            and Reg(reg).idx == src_idx
        ):
            pass  # fall through to shared logic
        case _:
            return None
    mem = typing.cast(Mem, ctx.cur.dst)
    if _mem_uses_reg(mem, src_idx):
        return None
    if not _reg_dead_after(ctx.program, i + 1, src_idx):
        return None
    _stat("P12_store_imm")
    _SIZE_CAST = {"byte": "char", "word": "short", "dword": "int"}
    if size in _SIZE_CAST:
        cast = _SIZE_CAST[size]
        return _Match(
            1,
            [
                f"{indent}| mov {size} {mem_expr}, ({cast})({expr})",
            ],
        )
    # qword: use emit_store_mem64_imm for simple [base + offset] forms,
    # fall back to inline if/else for complex addressing modes.
    if mem.base and not mem.index:
        from _asm_to_dasc import Reg as _Reg

        base_reg = _Reg(mem.base)
        base_jreg = base_reg.jreg
        if base_jreg is not None:
            return _Match(
                1,
                [
                    f"{indent}emit_store_mem64_imm(Dst, {base_jreg},"
                    f" {mem.offset}, {src_name}, {expr});",
                ],
            )
    # Complex addressing: inline if/else fallback
    c64 = f"(int64_t)({expr})"
    c32 = f"(int32_t)({expr})"
    return _Match(
        1,
        [
            f"{indent}if ({c64} == {c32}) {{",
            f"{indent}| mov qword {mem_expr}, (int)({expr})",
            f"{indent}}} else {{",
            f"{indent}emit_mov_imm(Dst, {src_name}, {expr});",
            f"{indent}| mov qword {mem_expr}, {reg}",
            f"{indent}}}",
        ],
    )


def _try_shift_fold(ctx: _FoldCtx) -> _Match | None:
    """Fold shift of an emit_mov_imm register into the immediate expression.

    When emit_mov_imm loads a value into a register and the next instruction
    shifts that same register by an immediate amount, absorb the shift into
    the emit_mov_imm expression.

    Example (LOAD_GLOBAL_MODULE — dict key lookup):
        emit_mov_imm(Dst, JREG_RAX, (uint16_t)(instruction->operand1));
        | shl rax, 4
    →
        emit_mov_imm(Dst, JREG_RAX, (uintptr_t)((uint16_t)(instruction->operand1)) << 4);
    """
    match ctx.cur:
        case Asm(
            mnemonic="shl", dst=Reg(name=reg), src=Imm(text=shift_str)
        ) if Reg(reg).idx == ctx.src_idx:
            pass
        case _:
            return None
    _stat("P15_shift_fold")
    return _Match(
        1,
        [
            f"{ctx.indent}emit_mov_imm(Dst, {ctx.src_name},"
            f" (uintptr_t)({ctx.expr}) << {shift_str});",
        ],
    )


def _try_lea_fold(ctx: _FoldCtx) -> _Match | None:
    """Fold emit_mov_imm + lea [base + reg*scale] into lea [base + disp].

    When emit_mov_imm loads a JIT-time value into a register that is then
    only used as a scaled index in a lea, the scaled product can be computed
    at JIT emit time and used as a 32-bit displacement instead.

    Example (stack pointer adjustment):
        emit_mov_imm(Dst, JREG_RBP, (0 - (uint16_t)(instruction->oparg)));
        | lea rdi, [r14 + rbp*8]
    →
        | lea rdi, [r14 + (int)((intptr_t)(0 - (uint16_t)(instruction->oparg)) * 8)]

    Also handles the no-base form [reg*scale+0]:
        emit_mov_imm(Dst, JREG_R15, expr);
        | lea r12, [r15*8+0]
    →
        emit_mov_imm(Dst, JREG_R12, (uintptr_t)(expr) * 8);

    Conditions:
    - The emit_mov_imm register is used as a scaled index in the lea
    - The register is dead after the lea (or overwritten by it)
    - The expression * scale fits in int32_t (guaranteed for oparg-based
      expressions: max |oparg|=65535, max scale=8, product ≤ 524,280)
    """
    import re as _re

    match ctx.cur:
        case Asm(mnemonic="lea", dst=Reg(name=dst_reg), src=Mem() as mem_op):
            pass
        case _:
            return None

    # Check that the emit_mov_imm register is used as a scaled index
    src_reg = _IDX_TO_ALL_NAMES.get(ctx.src_idx, ())
    if not src_reg:
        return None

    mem_text = fmt_op(mem_op)
    reg_alt = "|".join(_re.escape(r) for r in src_reg)

    # Pattern 1: [base + reg*scale]
    m_based = _re.search(
r"\[(\w+)\s*\+\s*(" + reg_alt + r")\*(\d+)\]", mem_text + ) + # Pattern 2: [reg*scale+0] (no base register) + m_nobase = _re.search(r"\[(" + reg_alt + r")\*(\d+)\+0\]", mem_text) + + if m_based: + base_reg = m_based.group(1) + scale = int(m_based.group(3)) + + # Check if the src register's original value is needed after the lea. + # The lea itself reads src_reg (scaled index), but we're replacing that + # with a displacement, so the lea's read doesn't count. + # If the lea's destination overwrites src_reg, the value is dead anyway. + lea_dst_idx = Reg(dst_reg).idx + if lea_dst_idx != ctx.src_idx: + if not _reg_dead_after(ctx.program, ctx.i + 1, ctx.src_idx): + return None + + _stat("P17_lea_fold") + return _Match( + 1, + [ + f"{ctx.indent}| lea {dst_reg}," + f" [{base_reg} + (int)((intptr_t)({ctx.expr}) * {scale})]", + ], + ) + elif m_nobase: + scale = int(m_nobase.group(2)) + + # For [reg*scale+0], the result is just expr * scale. + # We emit a new emit_mov_imm into the lea's destination register. + dst_idx = Reg(dst_reg).idx + dst_name = ( + _IDX_TO_JREG.get(dst_idx, str(dst_idx)) + if dst_idx is not None + else dst_reg + ) + + # Check src register is dead after the lea + lea_dst_idx = Reg(dst_reg).idx + if lea_dst_idx != ctx.src_idx: + if not _reg_dead_after(ctx.program, ctx.i + 1, ctx.src_idx): + return None + + _stat("P17_lea_fold") + return _Match( + 1, + [ + f"{ctx.indent}emit_mov_imm(Dst, {dst_name}," + f" (uintptr_t)({ctx.expr}) * {scale});", + ], + ) + + return None + + +def _fold_mov_imm( + program: _PeepholeProgram, + i: int, + result: list[str], +) -> int | None: + """Try to fold an emit_mov_imm with subsequent instructions. + + Returns the number of lines consumed (advancing past ``i``), or + None if no fold was possible and the caller should try other patterns. 
+ """ + lines = program.lines + parsed = program.parsed + m_mov = _RE_EMIT_MOV_IMM.match(lines[i]) + if not m_mov or i + 1 >= len(lines): + return None + # Guard against re-folding inside "} else {" fallback blocks + if result and result[-1].rstrip().endswith("} else {"): + return None + + indent = m_mov.group(1) + src_idx, src_name = _parse_jreg(m_mov.group(2)) + expr = m_mov.group(3) + consumed = 1 + + # Build shared context for all _try_* functions + ctx = _FoldCtx( + program=program, + lines=lines, + parsed=parsed, + i=i + 1, + src_idx=src_idx, + src_name=src_name, + indent=indent, + expr=expr, + ) + + # Try each fold pattern in priority order (all take ctx: _FoldCtx) + ctx.i = i + consumed + if ctx.i < len(lines): + for try_fn in ( + _try_two_mov_add, + _try_indexed_mem, + _try_alu_imm, + _try_store_imm, + _try_shift_fold, + _try_lea_fold, + ): + match = try_fn(ctx) + if match: + result.extend(match.output) + return consumed + match.consumed + + return None + + +# ── Standalone patterns ──────────────────────────────────────────────── +# +# Each function: (lines, i, result, state) → int | None +# Returns lines consumed on match, or None. May append to ``result`` +# and mutate ``state``. 
+ + +def _pattern_preserve_flags_mov_imm( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Preserve flags across immediate loads inserted before setcc/cmov/jcc.""" + del state # unused + lines = program.lines + if "emit_mov_imm(" not in lines[i] or i == 0: + return None + prev = program.parsed[i - 1] + if isinstance(prev, CCall) and prev.kind == CCallKind.MOV_IMM: + return None + if not _is_flag_writer(prev): + return None + j = i + output: list[str] = [] + while j < len(lines) and "emit_mov_imm(" in lines[j]: + cur = program.parsed[j] + if not (isinstance(cur, CCall) and cur.kind == CCallKind.MOV_IMM): + break + output.append(_preserve_flags_mov_imm(lines[j])) + j += 1 + if not output or j >= len(lines): + return None + if not _is_flag_consumer(program.parsed[j]): + return None + _stat("SP0_preserve_flags_mov_imm") + result.extend(output) + return len(output) + + +def _pattern_store_reload_elim( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Eliminate redundant stackpointer reloads on the hot path. + + Matches: + | mov qword [r13 + 64], r14 (store) + | test/cmp REG, IMM + | jcc =>COLD_LABEL + |=>MERGE_LABEL: + | mov r14, qword [r13 + 64] (reload — eliminated) + + The hot path never modifies r14 or [r13+64]. The cold dealloc path + may modify [r13+64], so we insert a reload there before jumping back. + + Uses structural pattern matching on typed ``Line`` objects for dispatch. 
+ """ + lines = program.lines + if i + 4 >= len(lines): + return None + window = [program.parsed[i + k] for k in range(5)] + match window: + case [ + # mov qword [r13 + FRAME_STACKPOINTER_OFFSET], r14 + Asm( + mnemonic="mov", + dst=Mem(base=_C.FRAME, offset=_C.FRAME_SP_OFS), + src=Reg(name=_C.SP), + ), + # test/cmp REG, IMM + Asm(mnemonic=op), + # jcc =>LABEL + Asm(target=branch_tgt), + # =>MERGE_LABEL: + Label(name=merge_name), + # mov r14, qword [r13 + FRAME_STACKPOINTER_OFFSET] + Asm( + mnemonic="mov", + dst=Reg(name=_C.SP), + src=Mem(base=_C.FRAME, offset=_C.FRAME_SP_OFS), + ), + ] if op in ("test", "cmp") and branch_tgt: + merge_lbl = f"=>{merge_name}" + for k in range(4): + result.append(lines[i + k]) + state.need_reload_before_jmp.add(merge_lbl) + _stat("SP1_store_reload_elim") + return 5 + return None + + +def _pattern_cold_reload_insert( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Insert stackpointer reload before cold-path jump back to merge label. + + After store-reload elimination, the dealloc cold path needs to reload + r14 from [r13+64] before jumping back (since _Py_Dealloc may have + modified [r13+64] via __del__). 
+ """ + if not state.need_reload_before_jmp: + return None + lines = program.lines + cur = program.parsed[i] + # Must be a jump or conditional branch + match cur: + case Asm(mnemonic=m, target=target_lbl) if target_lbl and ( + m == "jmp" or m in _BRANCH_MNEMONICS + ): + pass # fall through + case _: + return None + if target_lbl not in state.need_reload_before_jmp: + return None + # Only insert reload if we're after a call (dealloc path) + for prev in reversed(result): + prev_line = parse_line(prev) + match prev_line: + case Blank(): + continue + case CCall() | Asm(mnemonic="call"): + _stat("SP2_cold_reload_insert") + result.append(_SP_RELOAD_LINE) + case Asm(mnemonic="add", dst=Reg(name="rsp")): + _stat("SP2_cold_reload_insert") + result.append(_SP_RELOAD_LINE) + break + # Don't consume extra lines — just insert before the current line + return None + + +def _pattern_inverted_store_reload( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Defer stackpointer store to cold path when the hot branch skips it. + + Some stencils (e.g. _POP_TOP_r10) have an "inverted" pattern where the + conditional branch jumps to the merge point (hot path) and the + fallthrough goes to the cold path. The store/reload pair is redundant + on the hot path since r14 is preserved by callee-saved convention + across any C calls on the cold path. + + Matches the pattern: + | mov qword [r13 + 64], r14 ← store (line i) + | test/cmp REG, IMM ← branch condition (line i+1) + | jcc =>L(MERGE) ← hot: jump to merge (line i+2) + | jmp =>L(COLD) ← cold path redirect (line i+3) + ... intermediate comeback code (labels + instructions) ... 
+ |=>L(MERGE): ← merge label + | mov r14, qword [r13 + 64] ← reload (eliminated on hot path) + + Transforms to: + | test/cmp REG, IMM ← condition (moved up past store) + | jcc =>L(MERGE) ← hot: jump past store+reload + | mov qword [r13 + 64], r14 ← store (deferred, cold-only) + | jmp =>L(COLD) ← cold path redirect + ... intermediate comeback code ... + | mov r14, qword [r13 + 64] ← reload (moved before merge label) + |=>L(MERGE): ← merge point (hot enters here) + + Hot path saves 14 bytes (7-byte store + 7-byte reload) and 2 memory + accesses. The cold comeback path still gets the reload. + + Uses structural pattern matching for robust store_sp, reload_sp, branch, + and jump detection. + """ + lines = program.lines + if i + 5 >= len(lines): + return None + + # First 4 lines: store, test/cmp, jcc, jmp + w = [program.parsed[i + k] for k in range(4)] + match w: + case [ + Asm( + mnemonic="mov", + dst=Mem(base=_C.FRAME, offset=_C.FRAME_SP_OFS), + src=Reg(name=_C.SP), + ), + Asm(mnemonic=cmp_op), + Asm(target=merge_target), + Asm(mnemonic="jmp"), + ] if cmp_op in ("test", "cmp") and merge_target: + pass # fall through + case _: + return None + + # Scan forward to find the merge label and reload + merge_label_str = ( + f"|=>{merge_target[2:]}:" if merge_target.startswith("=>") else None + ) + if not merge_label_str: + return None + merge_idx = None + for j in range(i + 4, min(i + 20, len(lines))): + stripped = lines[j].strip() + if stripped.replace(" ", "").startswith( + merge_label_str.replace(" ", "") + ): + merge_idx = j + break + if merge_idx is None or merge_idx + 1 >= len(lines): + return None + match program.parsed[merge_idx + 1]: + case Asm( + mnemonic="mov", + dst=Reg(name=_C.SP), + src=Mem(base=_C.FRAME, offset=_C.FRAME_SP_OFS), + ): + pass # confirmed reload + case _: + return None + + # Pattern matched! Build the transformed output. 
+ result.append(lines[i + 1]) # test/cmp + result.append(lines[i + 2]) # jcc =>L(MERGE) + result.append(lines[i]) # store (now cold-only) + result.append(lines[i + 3]) # jmp =>L(COLD) + for k in range(i + 4, merge_idx): + result.append(lines[k]) + result.append(lines[merge_idx + 1]) # reload (before merge label) + result.append(lines[merge_idx]) # |=>L(MERGE): + _stat("SP3_inverted_store_reload") + return (merge_idx + 1) - i + 1 + + +def _pattern_test_memory_fold( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Pattern 14: Fold mov+test into test [mem] when loaded register is dead. + + When a mov loads a value from memory solely to test a bit/byte + and the register is dead after the conditional branch, we can + test the memory location directly, eliminating the mov. + + Matches (3 consecutive lines): + | mov REG, qword/dword [MEM] + | test REG_LOW, REG_LOW (or: test REG_LOW, IMM) + | jcc =>TARGET + + When REG is dead after the jcc, transforms to: + | cmp byte [MEM], 0 (for test REG, REG form) + | test byte [MEM], IMM (for test REG, IMM form) + | jcc =>TARGET + + Example (TIER2_RESUME_CHECK — very hot path): + Before: + | mov rax, qword [r15 + 24] + | test al, al + | jne =>instruction->jump_target + After: + | cmp byte [r15 + 24], 0 + | jne =>instruction->jump_target + """ + del state # unused + lines = program.lines + if i + 2 >= len(lines): + return None + window = [program.parsed[i + k] for k in range(3)] + match window: + # mov REG, qword/dword [MEM]; test REG_LOW, REG_LOW; jcc + case [ + Asm(mnemonic="mov", dst=Reg() as mov_dst, src=Mem() as mov_src), + Asm(mnemonic="test", dst=Reg() as test_dst, src=Reg() as test_src), + Asm(target=branch_target), + ] if ( + branch_target + and mov_dst.idx is not None + and test_dst.idx == mov_dst.idx + and test_src.idx == mov_dst.idx + and mov_src.size in ("qword", "dword") + ): + # Start deadness check from the jcc (i+2), not i+3, so + # both the fall-through AND 
the branch target are checked. + if not _reg_dead_after(program, i + 2, mov_dst.idx): + return None + mem_expr = mov_src.expr + _stat("P14_test_memory_fold") + result.append(f" | cmp byte {mem_expr}, 0\n") + result.append(lines[i + 2]) + return 3 + + # mov REG, qword/dword [MEM]; test REG_LOW, IMM; jcc + case [ + Asm(mnemonic="mov", dst=Reg() as mov_dst, src=Mem() as mov_src), + Asm(mnemonic="test", dst=Reg() as test_dst, src=Imm() as test_imm), + Asm(target=branch_target), + ] if ( + branch_target + and mov_dst.idx is not None + and test_dst.idx == mov_dst.idx + and mov_src.size in ("qword", "dword") + ): + if not _reg_dead_after(program, i + 2, mov_dst.idx): + return None + mem_expr = mov_src.expr + _stat("P14_test_memory_fold") + result.append(f" | test byte {mem_expr}, {test_imm.text}\n") + result.append(lines[i + 2]) + return 3 + return None + + +def _pattern_dead_null_check( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Pattern 13: Remove dead NULL check after PyStackRef tag creation. + + When creating a tagged _PyStackRef from a raw pointer, Clang emits a + NULL check that is actually dead code: the preceding ``movzx`` already + dereferences the pointer (at offset 6), so a NULL pointer would + segfault before the check could ever fire. + + Matches (5 consecutive lines): + | movzx edi, word [rax + 6] ← dereferences rax (proves non-NULL) + | and edi, 1 ← extract ob_flags deferred bit + | or rdi, rax ← create tagged ref: ptr | flag + | cmp rdi, 1 ← dead: rax!=NULL so rdi!=1 + | je =>L(N) ← dead branch (error path) + + Emits only the first 3 lines (tag creation), removing the dead check. 
+ + Example (from _BINARY_OP_MULTIPLY_FLOAT after freelist allocation): + Before: + | movzx edi, word [rax + 6] + | and edi, 1 + | or rdi, rax + | cmp rdi, 1 + | je =>L(3) ← removed (dead) + After: + | movzx edi, word [rax + 6] + | and edi, 1 + | or rdi, rax + + Uses structural pattern matching for robust operand checking. + """ + lines = program.lines + if i + 4 >= len(lines): + return None + window = [program.parsed[i + k] for k in range(5)] + match window: + case [ + # movzx edi, word [REG + 6] — dereferences REG (proves non-NULL) + Asm(mnemonic="movzx", src=Mem(size="word", base=deref_reg)), + # and edi, 1 — extract ob_flags deferred bit + Asm(mnemonic="and", src=Imm(value=1)), + # or rdi, REG — create tagged ref: ptr | flag + Asm(mnemonic="or", src=Reg(name=tagged_reg)), + # cmp rdi, 1 — dead NULL check + Asm(mnemonic="cmp", src=Imm(value=1)), + # je =>L(N) — dead branch (error path) + Asm(mnemonic="je", target=branch_target), + ] if deref_reg == tagged_reg.lower() and branch_target: + # Emit only the tag creation (first 3 lines), skip cmp + je + for k in range(3): + result.append(lines[i + k]) + _stat("P13_dead_null_check") + return 5 + return None + + +def _pattern_dead_frame_anchor( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Remove dead lea anchors introduced to force canonical stack frames. + + The JIT template intentionally forces Clang to materialize a fixed stack + frame. That can leave behind dead instructions like ``lea rax, [rbp-144]`` + whose only purpose was to keep the frame allocated. Those writes must not + leak into the stitched trace, since they clobber live cross-stencil + registers such as ``rax``. 
+ """ + del state + match program.parsed[i]: + case Asm( + mnemonic="lea", + dst=Reg() as dst, + src=Mem(base="rbp", index=None, scale=None, offset=offset), + ): + if offset >= 0 or dst.idx is None: + return None + if i + 1 < len(program.parsed): + next_effect = _line_effect(program.parsed[i + 1]) + if dst.idx in next_effect.reads: + return None + if not _is_dead_before_any_call(program, i, dst.idx): + return None + _stat("P18_dead_frame_anchor") + return 1 + return None + + +def _pattern_inverse_mov_restore( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Drop the redundant second move in ``mov A, B`` / ``mov B, A`` pairs. + + The first move already preserves the original value of ``B`` in ``A`` while + leaving ``B`` unchanged, so the immediate inverse move is a no-op. + """ + del state + if i + 1 >= len(program.parsed): + return None + match program.parsed[i], program.parsed[i + 1]: + case ( + Asm(mnemonic="mov", dst=Reg() as dst1, src=Reg() as src1), + Asm(mnemonic="mov", dst=Reg() as dst2, src=Reg() as src2), + ): + if ( + dst1.idx is None + or src1.idx is None + or dst2.idx is None + or src2.idx is None + ): + return None + if ( + dst1.idx != src2.idx + or src1.idx != dst2.idx + or dst1.bits != src1.bits + or dst2.bits != src2.bits + or dst1.bits != dst2.bits + ): + return None + result.append(program.lines[i]) + _stat("P19_inverse_mov_restore") + return 2 + return None + + +def _pass_fold_mov_imm( + program: _PeepholeProgram, + i: int, + result: list[str], + state: _PeepholeState, +) -> int | None: + """Driver wrapper for the emit_mov_imm fold family.""" + del state + if ( + isinstance(program.parsed[i], CCall) + and program.parsed[i].kind == CCallKind.MOV_IMM + ): + return _fold_mov_imm(program, i, result) + return None + + +# Pass registry — order matters for priority +_PEEPHOLE_PASSES = ( + _PeepholePass("mov_imm_folds", _pass_fold_mov_imm), + _PeepholePass( + "SP0_preserve_flags_mov_imm", 
_pattern_preserve_flags_mov_imm + ), + _PeepholePass("P13_dead_null_check", _pattern_dead_null_check), + _PeepholePass("P18_dead_frame_anchor", _pattern_dead_frame_anchor), + _PeepholePass("P19_inverse_mov_restore", _pattern_inverse_mov_restore), + _PeepholePass("P14_test_memory_fold", _pattern_test_memory_fold), + _PeepholePass("SP1_store_reload_elim", _pattern_store_reload_elim), + _PeepholePass("SP2_cold_reload_insert", _pattern_cold_reload_insert), + _PeepholePass("SP3_inverted_store_reload", _pattern_inverted_store_reload), +) diff --git a/Tools/jit/_dasc_writer.py b/Tools/jit/_dasc_writer.py new file mode 100644 index 00000000000000..93e613f7ae3344 --- /dev/null +++ b/Tools/jit/_dasc_writer.py @@ -0,0 +1,448 @@ +"""Generate jit_stencils.h via DynASM from converted stencil assembly. + +This module replaces _writer.py for the DynASM-based JIT pipeline. +It generates a .dasc file from converted stencils, runs the DynASM Lua +preprocessor on it, and produces a complete jit_stencils.h header. +""" + +from __future__ import annotations + +import pathlib +import re +import subprocess +import typing + +import _asm_to_dasc + + +# Path to the DynASM Lua preprocessor +_DYNASM_DIR = pathlib.Path(__file__).resolve().parent / "LuaJIT" / "dynasm" +_DYNASM_LUA = _DYNASM_DIR / "dynasm.lua" + + +def _generate_dasc_content( + stencils: dict[str, _asm_to_dasc.ConvertedStencil], + shim: _asm_to_dasc.ConvertedStencil | None = None, +) -> typing.Iterator[str]: + """Generate the contents of the .dasc file. + + This produces a C file with embedded DynASM directives that, when + processed by dynasm.lua, yields a header with action lists and + dasm_put() calls for each stencil. + """ + max_frame_size = max( + (stencil.frame_size for stencil in stencils.values()), default=0 + ) + + # Deduplicate static data blobs: group identical content under a + # single declaration, mapping per-stencil names → shared names. 
+ _data_by_content: dict[bytes, str] = {} # content → shared name + _data_name_map: dict[str, str] = {} # old per-stencil name → shared name + all_stencils = list(stencils.values()) + if shim: + all_stencils.append(shim) + for stencil in all_stencils: + for item in stencil.data_items: + safe = item.label.replace(".", "_") + old_name = f"jit_data_{stencil.opname}_{safe}" + content = bytes(item.data) + if content not in _data_by_content: + shared_name = f"jit_data_{len(_data_by_content)}" + _data_by_content[content] = shared_name + _data_name_map[old_name] = _data_by_content[content] + + # Simple identifier-based regex for data name substitution. + # Matches "jit_data_" followed by an identifier and looks it up + # in the map — much faster than a regex with 11K alternatives. + _DATA_RE = re.compile(r"\bjit_data_\w+") if _data_name_map else None + + def _resolve_data_names(lines: list[str]) -> list[str]: + """Replace per-stencil jit_data_OPNAME_LABEL with shared jit_data_N.""" + if _DATA_RE is None: + return lines + resolved = [] + for line in lines: + if "jit_data_" in line: + line = _DATA_RE.sub( + lambda m: _data_name_map.get(m.group(0), m.group(0)), + line, + ) + resolved.append(line) + return resolved + + yield "// Auto-generated by Tools/jit/_dasc_writer.py — DO NOT EDIT" + yield "// This file is processed by DynASM (dynasm.lua -D X64)" + yield "" + + # DynASM architecture and section definitions + yield "|.arch x64" + yield "|.section code, cold, data" + yield "|.actionlist jit_actionlist" + yield "" + + # Teach DynASM the standard x86-64 byte register names for registers + # 4-7 (spl, bpl, sil, dil). DynASM auto-generates r4b-r7b for these + # but the standard names are more readable. r8b-r15b are already + # known to DynASM natively. + yield "|.define spl, Rb(4)" + yield "|.define bpl, Rb(5)" + yield "|.define sil, Rb(6)" + yield "|.define dil, Rb(7)" + yield "" + + # Shorthand macro for label references — keeps emitted code compact + # and human-readable. 
L(n) references internal stencil labels. + yield "#define L(n) (label_base + (n))" + yield "" + + # Named register indices for emit_mov_imm() — human-readable + # alternative to raw integer indices. Prefixed with JREG_ to avoid + # collisions with system headers (e.g. ucontext.h REG_R8). + yield "#define JREG_RAX 0" + yield "#define JREG_RCX 1" + yield "#define JREG_RDX 2" + yield "#define JREG_RBX 3" + yield "#define JREG_RSP 4" + yield "#define JREG_RBP 5" + yield "#define JREG_RSI 6" + yield "#define JREG_RDI 7" + yield "#define JREG_R8 8" + yield "#define JREG_R9 9" + yield "#define JREG_R10 10" + yield "#define JREG_R11 11" + yield "#define JREG_R12 12" + yield "#define JREG_R13 13" + yield "#define JREG_R14 14" + yield "#define JREG_R15 15" + yield "" + + # jit_code_base is set by _PyJIT_Compile to the real allocation address. + # Always valid — jit_alloc() places code near CPython text via mmap hints. + yield "static uintptr_t jit_code_base;" + yield "" + + # Cross-stencil untag reuse: GUARD_TOS/NOS_FLOAT stencils compute + # rax = src_reg & -2 (pointer untagging). If the immediately following + # BINARY_OP_*_FLOAT stencil needs the same untag, it can reuse rax + yield "" + + # Runtime-optimal immediate load: picks the shortest encoding based + # on the actual value at JIT compile time. Like Pyston's emit_mov_imm. + # + # Encoding priority (shortest first): + # val == 0: xor Rd, Rd (2 bytes) + # val <= UINT32_MAX: mov Rd, imm32 (5 bytes) + # val near JIT code: lea Rq, [rip+disp32] (7 bytes) + # otherwise: mov64 Rq, imm64 (10 bytes) + # + # jit_code_base is always the real allocation address, so the LEA + # path is used whenever the value is within ±2GB of any point in the JIT code. 
+ yield "static void emit_mov_imm(dasm_State **Dst, int r, uintptr_t val) {" + yield " if (val == 0) {" + yield " | xor Rd(r), Rd(r)" + yield " } else if (val <= UINT32_MAX) {" + yield " | mov Rd(r), (unsigned int)val" + yield " } else {" + yield " intptr_t delta = (intptr_t)(val - jit_code_base);" + yield " intptr_t safe_radius = 0x7FFFFFFFLL - PY_MAX_JIT_CODE_SIZE - 15;" + yield " if (delta >= -safe_radius && delta <= safe_radius) {" + yield " | lea Rq(r), [&((void*)(uintptr_t)val)]" + yield " } else {" + yield " | mov64 Rq(r), (unsigned long)val" + yield " }" + yield " }" + yield "}" + yield "" + + # Flag-preserving immediate load. Same codegen as emit_mov_imm() except + # that the zero-immediate case must use mov instead of xor so condition + # codes remain intact for a following setcc/cmov/jcc. + yield "static void emit_mov_imm_preserve_flags(dasm_State **Dst, int r, uintptr_t val) {" + yield " if (val == 0) {" + yield " | mov Rd(r), 0" + yield " } else {" + yield " emit_mov_imm(Dst, r, val);" + yield " }" + yield "}" + yield "" + + # Direct relative call to an external function. Uses DynASM's &addr + # syntax which emits a 5-byte E8 rel32 instruction. + # + # Falls back to mov+call for targets beyond ±2GB (e.g. shared + # library functions when JIT code is mapped far from them). + yield "static void emit_call_ext(dasm_State **Dst, void *addr) {" + yield " intptr_t delta = (intptr_t)((uintptr_t)addr - (uintptr_t)jit_code_base);" + yield " intptr_t safe_radius = 0x7FFFFFFFLL - PY_MAX_JIT_CODE_SIZE - 15;" + yield " if (delta >= -safe_radius && delta <= safe_radius) {" + yield " | call qword &addr // 5-byte E8 rel32" + yield " } else {" + yield " emit_mov_imm(Dst, JREG_RAX, (unsigned long)(uintptr_t)addr);" + yield " | call rax" + yield " }" + yield "}" + yield "" + + # Generalized ALU register-vs-immediate helper. Handles all commutative + # and comparison ALU operations: cmp, test, and, or, xor. 
+ # When the immediate fits in sign-extended imm32, uses the direct form; + # otherwise loads into scratch register first. + for alu_op in ("cmp", "test", "and", "or", "xor"): + func_name = f"emit_{alu_op}_reg_imm" + yield f"__attribute__((unused)) // may be unused; suppresses a compiler warning" + yield f"static void {func_name}(dasm_State **Dst, int r, int scratch, uintptr_t val) {{" + yield " if ((int64_t)val == (int32_t)val) {" + yield f" | {alu_op} Rq(r), (int)val" + yield " } else {" + yield " emit_mov_imm(Dst, scratch, val);" + yield f" | {alu_op} Rq(r), Rq(scratch)" + yield " }" + yield "}" + yield "" + + # 64-bit memory-vs-immediate compare helper for simple [base + offset] + # operands. This replaces the old multiline inline if/else fallback in + # _asm_to_dasc.py with one helper call at the use site. + yield "__attribute__((unused)) // may be unused; suppresses a compiler warning" + yield "static void emit_cmp_mem64_imm(" + yield " dasm_State **Dst, int r_mem, long offset, int scratch, uintptr_t val" + yield ") {" + yield " if ((int64_t)val == (int32_t)val) {" + yield " | cmp qword [Rq(r_mem)+ offset], (int)val" + yield " } else {" + yield " emit_mov_imm(Dst, scratch, val);" + yield " | cmp qword [Rq(r_mem)+ offset], Rq(scratch)" + yield " }" + yield "}" + yield "" + + # 64-bit memory store with immediate value. When the value fits in + # sign-extended imm32, uses a direct mov qword [mem], imm32. + # Otherwise loads into scratch register first.
+ yield "__attribute__((unused)) // may be unused; suppresses a compiler warning" + yield "static void emit_store_mem64_imm(" + yield " dasm_State **Dst, int r_mem, long offset, int scratch, uintptr_t val" + yield ") {" + yield " if ((int64_t)val == (int32_t)val) {" + yield " | mov qword [Rq(r_mem)+ offset], (int)val" + yield " } else {" + yield " emit_mov_imm(Dst, scratch, val);" + yield " | mov qword [Rq(r_mem)+ offset], Rq(scratch)" + yield " }" + yield "}" + yield "" + + # _SET_IP delta helper: replace movabs+store (14 bytes) with + # add qword [frame+IP_OFFSET], delta (8 bytes) for subsequent _SET_IP ops. + yield "static void emit_set_ip_delta(dasm_State **Dst, int uop_label, intptr_t delta) {" + yield " |.code" + yield " |=>uop_label:" + yield ( + " | add qword [r13 + " + + str(_asm_to_dasc.FRAME_IP_OFFSET) + + "], (int)(delta)" + ) + yield "}" + yield "" + + yield "static int jit_max_frame_size(void) {" + yield f" return {max_frame_size};" + yield "}" + yield "" + + yield "static void emit_trace_entry_frame(dasm_State **Dst) {" + yield " |.code" + yield " | push rbp" + yield " | mov rbp, rsp" + yield " | sub rsp, jit_max_frame_size()" + yield "}" + yield "" + + # Deduplicated static data blobs (assert strings, etc.)
+ for content, shared_name in sorted( + _data_by_content.items(), key=lambda x: x[1] + ): + vals = ", ".join(str(b) for b in content) + yield f"static const char {shared_name}[] = {{{vals}}};" + yield "" + + # Emit function for each stencil + for opname, stencil in sorted(stencils.items()): + yield f"static void emit_{opname}(" + yield " dasm_State **Dst," + yield " const _PyUOpInstruction *instruction," + yield " int uop_label," + yield " int continue_label," + yield " int label_base" + yield ") {" + stencil_lines = _resolve_data_names(stencil.lines) + for line in stencil_lines: + yield line + yield "}" + yield "" + + # Shim emit function (not rewritten — has its own prologue/epilogue) + if shim: + yield "static void emit_shim(dasm_State **Dst, int uop_label, int label_base) {" + for line in _resolve_data_names(shim.lines): + yield line + yield "}" + yield "" + + # Shim internal-label-count helper + if shim: + yield f"static int jit_internal_label_count_shim(void) {{ return {shim.num_internal_labels}; }}" + else: + yield "static int jit_internal_label_count_shim(void) { return 0; }" + yield "" + + # Emit function type + yield "typedef void (*jit_emit_fn)(" + yield " dasm_State **Dst," + yield " const _PyUOpInstruction *instruction," + yield " int uop_label," + yield " int continue_label," + yield " int label_base" + yield ");" + yield "" + + # Stencil descriptor table: function pointer + label count + whether the + # stencil invalidates the tracked frame->ip value on the hot path. + yield "static const struct {" + yield " jit_emit_fn emit;" + yield " int label_count;" + yield " int invalidates_ip;" # stencil writes r13 or frame->ip on hot path + yield f"}} jit_stencil_table[MAX_UOP_REGS_ID + 1] = {{" + for opname, stencil in sorted(stencils.items()): + # Detect if the stencil invalidates our tracked frame->ip value. + # Two cases (both checked on the hot path only): + # 1. Writes to r13 directly (frame pointer changes) + # 2. 
Writes to [r13 + ] (frame->ip modified by the stencil) + invalidates_ip = 0 + _ip_mem = f"r13 + {_asm_to_dasc.FRAME_IP_OFFSET}" + for line in stencil.lines: + if ".cold" in line: + break + stripped = line.strip() + # Case 1: "mov r13, " (not "mov qword [r13+...]") + if "mov r13," in stripped: + before_r13 = stripped.split("r13,")[0] + if "[" not in before_r13: + invalidates_ip = 1 + break + # Case 2: write to [r13 + ] (frame->ip) + if _ip_mem in stripped: + for op in ( + f"mov qword [{_ip_mem}]", + f"add qword [{_ip_mem}]", + f"sub qword [{_ip_mem}]", + ): + if op in stripped: + invalidates_ip = 1 + break + if invalidates_ip: + break + yield ( + f" [{opname}] = {{ emit_{opname}, " + f"{stencil.num_internal_labels}, {invalidates_ip} }}," + ) + yield "};" + yield "" + + # Thin wrappers used by jit.c + yield "static int jit_internal_label_count(int opcode) {" + yield " return jit_stencil_table[opcode].label_count;" + yield "}" + yield "" + yield "static int jit_invalidates_ip(int opcode) {" + yield " return jit_stencil_table[opcode].invalidates_ip;" + yield "}" + yield "" + yield "static void jit_emit_one(" + yield " dasm_State **Dst," + yield " int opcode," + yield " const _PyUOpInstruction *instruction," + yield " int uop_label," + yield " int continue_label," + yield " int label_base" + yield ") {" + yield " jit_stencil_table[opcode].emit(Dst, instruction, uop_label, continue_label, label_base);" + yield "}" + + +def write_dasc( + dasc_path: pathlib.Path, + stencils: dict[str, _asm_to_dasc.ConvertedStencil], + shim: _asm_to_dasc.ConvertedStencil | None = None, +) -> None: + """Write the .dasc file to disk.""" + with dasc_path.open("w") as f: + for line in _generate_dasc_content(stencils, shim): + f.write(line) + f.write("\n") + + +def run_dynasm( + dasc_path: pathlib.Path, + output_path: pathlib.Path, + *, + luajit: str = "luajit", +) -> None: + """Run the DynASM Lua preprocessor to generate the C header. 
+ + Args: + dasc_path: Path to the .dasc file + output_path: Path for the generated .h output + luajit: Path to luajit binary + """ + if not _DYNASM_LUA.exists(): + raise FileNotFoundError( + f"DynASM preprocessor not found at {_DYNASM_LUA}.\n" + f"Ensure the LuaJIT submodule is initialized:\n" + f" git submodule update --init Tools/jit/LuaJIT" + ) + cmd = [ + luajit, + str(_DYNASM_LUA), + "-D", + "X64", + "-o", + str(output_path), + str(dasc_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"DynASM preprocessing failed:\n" + f"stdout: {result.stdout}\n" + f"stderr: {result.stderr}\n" + f"Command: {' '.join(cmd)}" + ) + + +def dump_header( + stencils: dict[str, _asm_to_dasc.ConvertedStencil], + shim: _asm_to_dasc.ConvertedStencil | None = None, + *, + dasc_path: pathlib.Path, + luajit: str = "luajit", +) -> str: + """Generate jit_stencils.h content via DynASM. + + 1. Writes a .dasc file + 2. Runs dynasm.lua to produce a .h file + 3. 
Returns the .h file content + + Args: + stencils: Dict mapping opname to ConvertedStencil + shim: Optional shim stencil + dasc_path: Path for the intermediate .dasc file + luajit: Path to luajit binary + + Returns: + The generated header content as a string + """ + output_path = dasc_path.with_suffix(".h") + write_dasc(dasc_path, stencils, shim) + run_dynasm(dasc_path, output_path, luajit=luajit) + return output_path.read_text() diff --git a/Tools/jit/_optimizers.py b/Tools/jit/_optimizers.py index 83c878d8fe205b..0561ef824cce05 100644 --- a/Tools/jit/_optimizers.py +++ b/Tools/jit/_optimizers.py @@ -63,42 +63,40 @@ "hi": "ls", "ls": "hi", } -# MyPy doesn't understand that a invariant variable can be initialized by a covariant value -CUSTOM_AARCH64_BRANCH19: str | None = "CUSTOM_AARCH64_BRANCH19" - _AARCH64_SHORT_BRANCHES = { "tbz": "tbnz", "tbnz": "tbz", } -# Branches are either b.{cond}, bc.{cond}, cbz, cbnz, tbz or tbnz +# Branches are either b.{cond}, bc.{cond}, cbz, cbnz, tbz or tbnz. +# Second tuple element unused (was for relocation fixup, now handled by DynASM). _AARCH64_BRANCHES: dict[str, tuple[str | None, str | None]] = ( { - "b." + cond: (("b." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19) + "b." + cond: (("b." + inverse if inverse else None), None) for (cond, inverse) in _AARCH64_COND_CODES.items() } | { - "bc." + cond: (("bc." + inverse if inverse else None), CUSTOM_AARCH64_BRANCH19) + "bc." + cond: (("bc." 
+ inverse if inverse else None), None) for (cond, inverse) in _AARCH64_COND_CODES.items() } | { - "cbz": ("cbnz", CUSTOM_AARCH64_BRANCH19), - "cbnz": ("cbz", CUSTOM_AARCH64_BRANCH19), + "cbz": ("cbnz", None), + "cbnz": ("cbz", None), + } + | { + cond: (inverse, None) + for (cond, inverse) in _AARCH64_SHORT_BRANCHES.items() } - | {cond: (inverse, None) for (cond, inverse) in _AARCH64_SHORT_BRANCHES.items()} ) @enum.unique class InstructionKind(enum.Enum): - JUMP = enum.auto() LONG_BRANCH = enum.auto() SHORT_BRANCH = enum.auto() CALL = enum.auto() RETURN = enum.auto() - SMALL_CONST_1 = enum.auto() - SMALL_CONST_2 = enum.auto() OTHER = enum.auto() @@ -110,12 +108,18 @@ class Instruction: target: str | None def is_branch(self) -> bool: - return self.kind in (InstructionKind.LONG_BRANCH, InstructionKind.SHORT_BRANCH) + return self.kind in ( + InstructionKind.LONG_BRANCH, + InstructionKind.SHORT_BRANCH, + ) def update_target(self, target: str) -> "Instruction": assert self.target is not None return Instruction( - self.kind, self.name, self.text.replace(self.target, target), target + self.kind, + self.name, + self.text.replace(self.target, target), + target, ) def update_name_and_target(self, name: str, target: str) -> "Instruction": @@ -164,7 +168,9 @@ class Optimizer: re_global: re.Pattern[str] # The first block in the linked list: _root: _Block = dataclasses.field(init=False, default_factory=_Block) - _labels: dict[str, _Block] = dataclasses.field(init=False, default_factory=dict) + _labels: dict[str, _Block] = dataclasses.field( + init=False, default_factory=dict + ) # No groups: _re_noninstructions: typing.ClassVar[re.Pattern[str]] = re.compile( r"\s*(?:\.|#|//|;|$)" @@ -174,8 +180,6 @@ class Optimizer: r'\s*(?P