// WindowsPartialUnmapHandler.cs
using ARMeilleure.IntermediateRepresentation;
using ARMeilleure.Translation;
using Ryujinx.Common.Memory.PartialUnmaps;
using System;
using System.Runtime.InteropServices;
using static ARMeilleure.IntermediateRepresentation.Operand.Factory;

  6. namespace ARMeilleure.Signal
  7. {
  8. /// <summary>
  9. /// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods.
  10. /// </summary>
  11. internal static partial class WindowsPartialUnmapHandler
  12. {
  13. [LibraryImport("kernel32.dll", SetLastError = true, EntryPoint = "LoadLibraryA")]
  14. private static partial nint LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName);
  15. [LibraryImport("kernel32.dll", SetLastError = true)]
  16. private static partial nint GetProcAddress(nint hModule, [MarshalAs(UnmanagedType.LPStr)] string procName);
  17. private static nint _getCurrentThreadIdPtr;
  18. public static nint GetCurrentThreadIdFunc()
  19. {
  20. if (_getCurrentThreadIdPtr == nint.Zero)
  21. {
  22. nint handle = LoadLibrary("kernel32.dll");
  23. _getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId");
  24. }
  25. return _getCurrentThreadIdPtr;
  26. }
  27. public static Operand EmitRetryFromAccessViolation(EmitterContext context)
  28. {
  29. nint partialRemapStatePtr = PartialUnmapState.GlobalState;
  30. nint localCountsPtr = nint.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset);
  31. // Get the lock first.
  32. EmitNativeReaderLockAcquire(context, nint.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
  33. nint getCurrentThreadId = GetCurrentThreadIdFunc();
  34. Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32);
  35. Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0));
  36. Operand endLabel = Label();
  37. Operand retry = context.AllocateLocal(OperandType.I32);
  38. Operand threadIndexValidLabel = Label();
  39. context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1)));
  40. context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated.
  41. context.Branch(endLabel);
  42. context.MarkLabel(threadIndexValidLabel);
  43. Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex);
  44. Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr);
  45. Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)nint.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset)));
  46. context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount));
  47. Operand noRetryLabel = Label();
  48. context.BranchIfFalse(noRetryLabel, retry);
  49. // if (retry) {
  50. context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount);
  51. context.Branch(endLabel);
  52. context.MarkLabel(noRetryLabel);
  53. // }
  54. context.MarkLabel(endLabel);
  55. // Finally, release the lock and return the retry value.
  56. EmitNativeReaderLockRelease(context, nint.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset));
  57. return retry;
  58. }
  59. public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, nint threadLocalMapPtr, Operand threadId, Operand initialState)
  60. {
  61. Operand idsPtr = Const((ulong)nint.Add(threadLocalMapPtr, ThreadLocalMap<int>.ThreadIdsOffset));
  62. Operand i = context.AllocateLocal(OperandType.I32);
  63. context.Copy(i, Const(0));
  64. // (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate)
  65. Operand endLabel = Label();
  66. Operand loopLabel = Label();
  67. context.MarkLabel(loopLabel);
  68. Operand offset = context.Multiply(i, Const(sizeof(int)));
  69. Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset));
  70. // Check that this slot has the thread ID.
  71. Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId);
  72. // If it was already the thread ID, then we just need to return i.
  73. context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId));
  74. context.Copy(i, context.Add(i, Const(1)));
  75. context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
  76. // (Loop 2) Try take a slot that is 0 with our Thread ID.
  77. context.Copy(i, Const(0)); // Reset i.
  78. Operand loop2Label = Label();
  79. context.MarkLabel(loop2Label);
  80. Operand offset2 = context.Multiply(i, Const(sizeof(int)));
  81. Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2));
  82. // Try and swap in the thread id on top of 0.
  83. Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId);
  84. Operand idNot0Label = Label();
  85. // If it was 0, then we need to initialize the struct entry and return i.
  86. context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0)));
  87. Operand structsPtr = Const((ulong)nint.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
  88. Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2));
  89. context.Store(structPtr, initialState);
  90. context.Branch(endLabel);
  91. context.MarkLabel(idNot0Label);
  92. context.Copy(i, context.Add(i, Const(1)));
  93. context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap<int>.MapSize)));
  94. context.Copy(i, Const(-1)); // Could not place the thread in the list.
  95. context.MarkLabel(endLabel);
  96. return context.Copy(i);
  97. }
  98. private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, nint threadLocalMapPtr, Operand index)
  99. {
  100. Operand offset = context.Multiply(index, Const(sizeof(int)));
  101. Operand structsPtr = Const((ulong)nint.Add(threadLocalMapPtr, ThreadLocalMap<int>.StructsOffset));
  102. return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset));
  103. }
  104. private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive)
  105. {
  106. Operand loop = Label();
  107. context.MarkLabel(loop);
  108. Operand initial = context.Load(OperandType.I32, ptr);
  109. Operand newValue = context.Add(initial, additive);
  110. Operand replaced = context.CompareAndSwap(ptr, initial, newValue);
  111. context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced));
  112. }
  113. private static void EmitNativeReaderLockAcquire(EmitterContext context, nint nativeReaderLockPtr)
  114. {
  115. Operand writeLockPtr = Const((ulong)nint.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset));
  116. // Spin until we can acquire the write lock.
  117. Operand spinLabel = Label();
  118. context.MarkLabel(spinLabel);
  119. // Old value must be 0 to continue (we gained the write lock)
  120. context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1)));
  121. // Increment reader count.
  122. EmitAtomicAddI32(context, Const((ulong)nint.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1));
  123. // Release write lock.
  124. context.CompareAndSwap(writeLockPtr, Const(1), Const(0));
  125. }
  126. private static void EmitNativeReaderLockRelease(EmitterContext context, nint nativeReaderLockPtr)
  127. {
  128. // Decrement reader count.
  129. EmitAtomicAddI32(context, Const((ulong)nint.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1));
  130. }
  131. }
  132. }