Преглед изворни кода

LightningJit: Reduce stack usage for Arm32 code (#6245)

* Write/read guest state to context for sync points, stop reserving stack for them

* Fix UsedGprsMask not being updated when allocating with preferencing

* POP should be also considered a return
gdkchan пре 2 година
родитељ
комит
ea07328aea

+ 5 - 0
src/Ryujinx.Cpu/LightningJit/Arm32/Block.cs

@@ -10,6 +10,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
         public readonly List<InstInfo> Instructions;
         public readonly List<InstInfo> Instructions;
         public readonly bool EndsWithBranch;
         public readonly bool EndsWithBranch;
         public readonly bool HasHostCall;
         public readonly bool HasHostCall;
+        public readonly bool HasHostCallSkipContext;
         public readonly bool IsTruncated;
         public readonly bool IsTruncated;
         public readonly bool IsLoopEnd;
         public readonly bool IsLoopEnd;
         public readonly bool IsThumb;
         public readonly bool IsThumb;
@@ -20,6 +21,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             List<InstInfo> instructions,
             List<InstInfo> instructions,
             bool endsWithBranch,
             bool endsWithBranch,
             bool hasHostCall,
             bool hasHostCall,
+            bool hasHostCallSkipContext,
             bool isTruncated,
             bool isTruncated,
             bool isLoopEnd,
             bool isLoopEnd,
             bool isThumb)
             bool isThumb)
@@ -31,6 +33,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             Instructions = instructions;
             Instructions = instructions;
             EndsWithBranch = endsWithBranch;
             EndsWithBranch = endsWithBranch;
             HasHostCall = hasHostCall;
             HasHostCall = hasHostCall;
+            HasHostCallSkipContext = hasHostCallSkipContext;
             IsTruncated = isTruncated;
             IsTruncated = isTruncated;
             IsLoopEnd = isLoopEnd;
             IsLoopEnd = isLoopEnd;
             IsThumb = isThumb;
             IsThumb = isThumb;
@@ -57,6 +60,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 Instructions.GetRange(0, splitIndex),
                 Instructions.GetRange(0, splitIndex),
                 false,
                 false,
                 HasHostCall,
                 HasHostCall,
+                HasHostCallSkipContext,
                 false,
                 false,
                 false,
                 false,
                 IsThumb);
                 IsThumb);
@@ -67,6 +71,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 Instructions.GetRange(splitIndex, splitCount),
                 Instructions.GetRange(splitIndex, splitCount),
                 EndsWithBranch,
                 EndsWithBranch,
                 HasHostCall,
                 HasHostCall,
+                HasHostCallSkipContext,
                 IsTruncated,
                 IsTruncated,
                 IsLoopEnd,
                 IsLoopEnd,
                 IsThumb);
                 IsThumb);

+ 13 - 3
src/Ryujinx.Cpu/LightningJit/Arm32/Decoder.cs

@@ -208,6 +208,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             InstMeta meta;
             InstMeta meta;
             InstFlags extraFlags = InstFlags.None;
             InstFlags extraFlags = InstFlags.None;
             bool hasHostCall = false;
             bool hasHostCall = false;
+            bool hasHostCallSkipContext = false;
             bool isTruncated = false;
             bool isTruncated = false;
 
 
             do
             do
@@ -246,9 +247,17 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                     meta = InstTableA32<T>.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features);
                     meta = InstTableA32<T>.GetMeta(encoding, cpuPreset.Version, cpuPreset.Features);
                 }
                 }
 
 
-                if (meta.Name.IsSystemOrCall() && !hasHostCall)
+                if (meta.Name.IsSystemOrCall())
                 {
                 {
-                    hasHostCall = meta.Name.IsCall() || InstEmitSystem.NeedsCall(meta.Name);
+                    if (!hasHostCall)
+                    {
+                        hasHostCall = InstEmitSystem.NeedsCall(meta.Name);
+                    }
+
+                    if (!hasHostCallSkipContext)
+                    {
+                        hasHostCallSkipContext = meta.Name.IsCall() || InstEmitSystem.NeedsCallSkipContext(meta.Name);
+                    }
                 }
                 }
 
 
                 insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags));
                 insts.Add(new(encoding, meta.Name, meta.EmitFunc, meta.Flags | extraFlags));
@@ -259,8 +268,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
 
 
             if (!isTruncated && IsBackwardsBranch(meta.Name, encoding))
             if (!isTruncated && IsBackwardsBranch(meta.Name, encoding))
             {
             {
-                hasHostCall = true;
                 isLoopEnd = true;
                 isLoopEnd = true;
+                hasHostCallSkipContext = true;
             }
             }
 
 
             return new(
             return new(
@@ -269,6 +278,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 insts,
                 insts,
                 !isTruncated,
                 !isTruncated,
                 hasHostCall,
                 hasHostCall,
+                hasHostCallSkipContext,
                 isTruncated,
                 isTruncated,
                 isLoopEnd,
                 isLoopEnd,
                 isThumb);
                 isThumb);

+ 3 - 0
src/Ryujinx.Cpu/LightningJit/Arm32/MultiBlock.cs

@@ -6,6 +6,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
     {
     {
         public readonly List<Block> Blocks;
         public readonly List<Block> Blocks;
         public readonly bool HasHostCall;
         public readonly bool HasHostCall;
+        public readonly bool HasHostCallSkipContext;
         public readonly bool IsTruncated;
         public readonly bool IsTruncated;
 
 
         public MultiBlock(List<Block> blocks)
         public MultiBlock(List<Block> blocks)
@@ -15,12 +16,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
             Block block = blocks[0];
             Block block = blocks[0];
 
 
             HasHostCall = block.HasHostCall;
             HasHostCall = block.HasHostCall;
+            HasHostCallSkipContext = block.HasHostCallSkipContext;
 
 
             for (int index = 1; index < blocks.Count; index++)
             for (int index = 1; index < blocks.Count; index++)
             {
             {
                 block = blocks[index];
                 block = blocks[index];
 
 
                 HasHostCall |= block.HasHostCall;
                 HasHostCall |= block.HasHostCall;
+                HasHostCallSkipContext |= block.HasHostCallSkipContext;
             }
             }
 
 
             block = blocks[^1];
             block = blocks[^1];

+ 1 - 0
src/Ryujinx.Cpu/LightningJit/Arm32/RegisterAllocator.cs

@@ -106,6 +106,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32
                 if ((regMask & AbiConstants.ReservedRegsMask) == 0)
                 if ((regMask & AbiConstants.ReservedRegsMask) == 0)
                 {
                 {
                     _gprMask |= regMask;
                     _gprMask |= regMask;
+                    UsedGprsMask |= regMask;
 
 
                     return firstCalleeSaved;
                     return firstCalleeSaved;
                 }
                 }

+ 23 - 4
src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/Compiler.cs

@@ -305,12 +305,23 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
                 ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp);
                 ForceConditionalEnd(cgContext, ref lastCondition, lastConditionIp);
             }
             }
 
 
+            int reservedStackSize = 0;
+
+            if (multiBlock.HasHostCall)
+            {
+                reservedStackSize = CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask);
+            }
+            else if (multiBlock.HasHostCallSkipContext)
+            {
+                reservedStackSize = 2 * sizeof(ulong); // Context and page table pointers.
+            }
+
             RegisterSaveRestore rsr = new(
             RegisterSaveRestore rsr = new(
                 regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask,
                 regAlloc.UsedGprsMask & AbiConstants.GprCalleeSavedRegsMask,
                 regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask,
                 regAlloc.UsedFpSimdMask & AbiConstants.FpSimdCalleeSavedRegsMask,
                 OperandType.FP64,
                 OperandType.FP64,
-                multiBlock.HasHostCall,
-                multiBlock.HasHostCall ? CalculateStackSizeForCallSpill(regAlloc.UsedGprsMask, regAlloc.UsedFpSimdMask, UsablePStateMask) : 0);
+                multiBlock.HasHostCall || multiBlock.HasHostCallSkipContext,
+                reservedStackSize);
 
 
             TailMerger tailMerger = new();
             TailMerger tailMerger = new();
 
 
@@ -596,7 +607,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
                 name == InstName.Ldm ||
                 name == InstName.Ldm ||
                 name == InstName.Ldmda ||
                 name == InstName.Ldmda ||
                 name == InstName.Ldmdb ||
                 name == InstName.Ldmdb ||
-                name == InstName.Ldmib)
+                name == InstName.Ldmib ||
+                name == InstName.Pop)
             {
             {
                 // Arm32 does not have a return instruction, instead returns are implemented
                 // Arm32 does not have a return instruction, instead returns are implemented
                 // either using BX LR (for leaf functions), or POP { ... PC }.
                 // either using BX LR (for leaf functions), or POP { ... PC }.
@@ -711,7 +723,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             switch (type)
             switch (type)
             {
             {
                 case BranchType.SyncPoint:
                 case BranchType.SyncPoint:
-                    InstEmitSystem.WriteSyncPoint(context.Writer, context.RegisterAllocator, context.TailMerger, context.GetReservedStackOffset());
+                    InstEmitSystem.WriteSyncPoint(
+                        context.Writer,
+                        ref asm,
+                        context.RegisterAllocator,
+                        context.TailMerger,
+                        context.GetReservedStackOffset(),
+                        context.StoreToContext,
+                        context.LoadFromContext);
                     break;
                     break;
                 case BranchType.SoftwareInterrupt:
                 case BranchType.SoftwareInterrupt:
                     context.StoreToContext();
                     context.StoreToContext();

+ 2 - 2
src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitFlow.cs

@@ -199,12 +199,12 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             }
             }
         }
         }
 
 
-        private static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
+        public static void WriteSpillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
         {
         {
             WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true);
             WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: true);
         }
         }
 
 
-        private static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
+        public static void WriteFillSkipContext(ref Assembler asm, RegisterAllocator regAlloc, int spillOffset)
         {
         {
             WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false);
             WriteSpillOrFillSkipContext(ref asm, regAlloc, spillOffset, spill: false);
         }
         }

+ 39 - 27
src/Ryujinx.Cpu/LightningJit/Arm32/Target/Arm64/InstEmitSystem.cs

@@ -354,11 +354,18 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             // All instructions that might do a host call should be included here.
             // All instructions that might do a host call should be included here.
             // That is required to reserve space on the stack for caller saved registers.
             // That is required to reserve space on the stack for caller saved registers.
 
 
+            return name == InstName.Mrrc;
+        }
+
+        public static bool NeedsCallSkipContext(InstName name)
+        {
+            // All instructions that might do a host call should be included here.
+            // That is required to reserve space on the stack for caller saved registers.
+
             switch (name)
             switch (name)
             {
             {
                 case InstName.Mcr:
                 case InstName.Mcr:
                 case InstName.Mrc:
                 case InstName.Mrc:
-                case InstName.Mrrc:
                 case InstName.Svc:
                 case InstName.Svc:
                 case InstName.Udf:
                 case InstName.Udf:
                     return true;
                     return true;
@@ -372,7 +379,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             Assembler asm = new(writer);
             Assembler asm = new(writer);
 
 
             WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
             WriteCall(ref asm, regAlloc, GetBkptHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
+            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
         }
         }
 
 
         public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId)
         public static void WriteSvc(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint svcId)
@@ -380,7 +387,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             Assembler asm = new(writer);
             Assembler asm = new(writer);
 
 
             WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId);
             WriteCall(ref asm, regAlloc, GetSvcHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, svcId);
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
+            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
         }
         }
 
 
         public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm)
         public static void WriteUdf(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset, uint pc, uint imm)
@@ -388,7 +395,7 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             Assembler asm = new(writer);
             Assembler asm = new(writer);
 
 
             WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
             WriteCall(ref asm, regAlloc, GetUdfHandlerPtr(), skipContext: true, spillBaseOffset, null, pc, imm);
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: true, spillBaseOffset);
+            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, spillBaseOffset);
         }
         }
 
 
         public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2)
         public static void WriteReadCntpct(CodeWriter writer, RegisterAllocator regAlloc, int spillBaseOffset, int rt, int rt2)
@@ -422,14 +429,14 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister);
             WriteFill(ref asm, regAlloc, resultMask, skipContext: false, spillBaseOffset, tempRegister);
         }
         }
 
 
-        public static void WriteSyncPoint(CodeWriter writer, RegisterAllocator regAlloc, TailMerger tailMerger, int spillBaseOffset)
-        {
-            Assembler asm = new(writer);
-
-            WriteSyncPoint(writer, ref asm, regAlloc, tailMerger, skipContext: false, spillBaseOffset);
-        }
-
-        private static void WriteSyncPoint(CodeWriter writer, ref Assembler asm, RegisterAllocator regAlloc, TailMerger tailMerger, bool skipContext, int spillBaseOffset)
+        public static void WriteSyncPoint(
+            CodeWriter writer,
+            ref Assembler asm,
+            RegisterAllocator regAlloc,
+            TailMerger tailMerger,
+            int spillBaseOffset,
+            Action storeToContext = null,
+            Action loadFromContext = null)
         {
         {
             int tempRegister = regAlloc.AllocateTempGprRegister();
             int tempRegister = regAlloc.AllocateTempGprRegister();
 
 
@@ -440,7 +447,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
             int branchIndex = writer.InstructionPointer;
             int branchIndex = writer.InstructionPointer;
             asm.Cbnz(rt, 0);
             asm.Cbnz(rt, 0);
 
 
-            WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister);
+            storeToContext?.Invoke();
+            WriteSpill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
 
 
             Operand rn = Register(tempRegister == 0 ? 1 : 0);
             Operand rn = Register(tempRegister == 0 ? 1 : 0);
 
 
@@ -449,7 +457,8 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
 
 
             tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32));
             tailMerger.AddConditionalZeroReturn(writer, asm, Register(0, OperandType.I32));
 
 
-            WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext, spillBaseOffset, tempRegister);
+            WriteFill(ref asm, regAlloc, 1u << tempRegister, skipContext: true, spillBaseOffset, tempRegister);
+            loadFromContext?.Invoke();
 
 
             asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset);
             asm.LdrRiUn(rt, Register(regAlloc.FixedContextRegister), NativeContextOffsets.CounterOffset);
 
 
@@ -514,18 +523,31 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
 
 
         private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
         private static void WriteSpill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
         {
         {
-            WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: true);
+            if (skipContext)
+            {
+                InstEmitFlow.WriteSpillSkipContext(ref asm, regAlloc, spillOffset);
+            }
+            else
+            {
+                WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: true);
+            }
         }
         }
 
 
         private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
         private static void WriteFill(ref Assembler asm, RegisterAllocator regAlloc, uint exceptMask, bool skipContext, int spillOffset, int tempRegister)
         {
         {
-            WriteSpillOrFill(ref asm, regAlloc, skipContext, exceptMask, spillOffset, tempRegister, spill: false);
+            if (skipContext)
+            {
+                InstEmitFlow.WriteFillSkipContext(ref asm, regAlloc, spillOffset);
+            }
+            else
+            {
+                WriteSpillOrFill(ref asm, regAlloc, exceptMask, spillOffset, tempRegister, spill: false);
+            }
         }
         }
 
 
         private static void WriteSpillOrFill(
         private static void WriteSpillOrFill(
             ref Assembler asm,
             ref Assembler asm,
             RegisterAllocator regAlloc,
             RegisterAllocator regAlloc,
-            bool skipContext,
             uint exceptMask,
             uint exceptMask,
             int spillOffset,
             int spillOffset,
             int tempRegister,
             int tempRegister,
@@ -533,11 +555,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
         {
         {
             uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask);
             uint gprMask = regAlloc.UsedGprsMask & ~(AbiConstants.GprCalleeSavedRegsMask | exceptMask);
 
 
-            if (skipContext)
-            {
-                gprMask &= ~Compiler.UsableGprsMask;
-            }
-
             if (!spill)
             if (!spill)
             {
             {
                 // We must reload the status register before reloading the GPRs,
                 // We must reload the status register before reloading the GPRs,
@@ -600,11 +617,6 @@ namespace Ryujinx.Cpu.LightningJit.Arm32.Target.Arm64
 
 
             uint fpSimdMask = regAlloc.UsedFpSimdMask;
             uint fpSimdMask = regAlloc.UsedFpSimdMask;
 
 
-            if (skipContext)
-            {
-                fpSimdMask &= ~Compiler.UsableFpSimdMask;
-            }
-
             while (fpSimdMask != 0)
             while (fpSimdMask != 0)
             {
             {
                 int reg = BitOperations.TrailingZeroCount(fpSimdMask);
                 int reg = BitOperations.TrailingZeroCount(fpSimdMask);