NativeSignalHandler.cs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. using ARMeilleure.IntermediateRepresentation;
  2. using ARMeilleure.Translation;
  3. using System;
  4. using System.Runtime.CompilerServices;
  5. using System.Runtime.InteropServices;
  6. using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
  7. namespace ARMeilleure.Signal
  8. {
  /// <summary>
  /// A single address range watched by the native fault handler.
  /// The field order and the 1-byte packing are a binary contract: the JIT-generated handler reads
  /// these fields at fixed byte offsets (IsActive +0, RangeAddress +4, RangeEndAddress +12, ActionPointer +20).
  /// </summary>
  [StructLayout(LayoutKind.Sequential, Pack = 1)]
  internal struct SignalHandlerRange
  {
      /// <summary>Nonzero while this slot holds a range that the handler should check.</summary>
      public int IsActive;

      /// <summary>Start (inclusive) and end (exclusive) of the watched address range.</summary>
      public nuint RangeAddress, RangeEndAddress;

      /// <summary>Native tracking function invoked for faults inside the range; may be zero.</summary>
      public IntPtr ActionPointer;
  }

  /// <summary>
  /// Unmanaged configuration block shared with the JIT-generated fault handler.
  /// The handler addresses every field by a hard-coded byte offset, so the field order, the field types
  /// and the 1-byte packing must not change.
  /// </summary>
  [StructLayout(LayoutKind.Sequential, Pack = 1)]
  internal struct SignalHandlerConfig
  {
      /// <summary>
      /// The byte offset of the faulting address in the SigInfo or ExceptionRecord struct.
      /// </summary>
      public int StructAddressOffset;

      /// <summary>
      /// The byte offset of the write flag in the SigInfo or ExceptionRecord struct.
      /// </summary>
      public int StructWriteOffset;

      /// <summary>
      /// The sigaction handler that was registered before this one. (unix only)
      /// </summary>
      public nuint UnixOldSigaction;

      /// <summary>
      /// Nonzero when the previously registered sigaction uses the 3 argument variant. (unix only)
      /// </summary>
      public int UnixOldSigaction3Arg;

      /// <summary>
      /// The tracked range slots. They are declared back to back so that native code can index them
      /// as a contiguous array starting at <see cref="Range0"/>.
      /// </summary>
      public SignalHandlerRange Range0, Range1, Range2, Range3, Range4, Range5, Range6, Range7;
  }
  /// <summary>
  /// Installs a JIT-compiled native fault handler and routes faults that land inside registered address
  /// ranges to per-range tracking actions. On Linux/macOS the handler is a POSIX signal handler that chains
  /// to the previously registered sigaction; on every other platform it is a Windows vectored exception handler.
  /// </summary>
  public static class NativeSignalHandler
  {
      private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext);

      [UnmanagedFunctionPointer(CallingConvention.Winapi)]
      private delegate int VectoredExceptionHandler(IntPtr exceptionInfo);

      private const int MaxTrackedRanges = 8;

      // Byte offsets of the SignalHandlerConfig fields (Pack = 1, 64-bit layout).
      private const int StructAddressOffset = 0;
      private const int StructWriteOffset = 4;
      private const int UnixOldSigaction = 8;
      private const int UnixOldSigaction3Arg = 16;
      private const int RangeOffset = 20;

      // Byte offsets of the SignalHandlerRange fields, relative to the start of each range slot.
      private const int RangeAddressFieldOffset = 4;
      private const int RangeEndAddressFieldOffset = 12;
      private const int RangeActionPointerFieldOffset = 20;

      private const int EXCEPTION_CONTINUE_SEARCH = 0;
      private const int EXCEPTION_CONTINUE_EXECUTION = -1;

      private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;

      // SA_SIGINFO is platform specific: 0x4 on Linux, 0x40 on macOS (where 0x4 is SA_RESETHAND).
      private const int SA_SIGINFO_LINUX = 0x4;
      private const int SA_SIGINFO_MACOS = 0x40;

      private const ulong PageSize = 0x1000;
      private const ulong PageMask = PageSize - 1;

      private static IntPtr _handlerConfig;
      private static IntPtr _signalHandlerPtr;
      private static IntPtr _signalHandlerHandle;

      // Marshal.GetFunctionPointerForDelegate requires the caller to keep the delegate reachable
      // for as long as the native pointer is in use; the handler stays installed for the process lifetime.
      private static Delegate _signalHandlerDelegate;

      private static readonly object _lock = new object();

      // volatile so the unlocked fast-path check in InitializeSignalHandler observes the write made under the lock.
      private static volatile bool _initialized;

      static NativeSignalHandler()
      {
          // The config block lives in unmanaged memory so the generated handler can address it by absolute pointer.
          _handlerConfig = Marshal.AllocHGlobal(Unsafe.SizeOf<SignalHandlerConfig>());
          ref SignalHandlerConfig config = ref GetConfigRef();

          config = new SignalHandlerConfig();
      }

      /// <summary>
      /// Generates and registers the platform fault handler. Only the first call has any effect;
      /// later calls return immediately.
      /// </summary>
      public static void InitializeSignalHandler()
      {
          if (_initialized)
          {
              return;
          }

          lock (_lock)
          {
              if (_initialized)
              {
                  return;
              }

              bool unix = OperatingSystem.IsLinux() || OperatingSystem.IsMacOS();
              ref SignalHandlerConfig config = ref GetConfigRef();

              if (unix)
              {
                  // Unix siginfo struct locations.
                  // NOTE: These are incredibly likely to be different between kernel version and architectures.
                  config.StructAddressOffset = OperatingSystem.IsMacOS() ? 24 : 16; // si_addr
                  config.StructWriteOffset = 8; // si_code

                  UnixExceptionHandler handler = GenerateUnixSignalHandler(_handlerConfig);
                  _signalHandlerDelegate = handler;
                  _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(handler);

                  SigAction old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);

                  // The previous handler takes (sig, info, ucontext) only when it was installed with SA_SIGINFO,
                  // whose value differs between Linux and macOS.
                  int saSigInfo = OperatingSystem.IsMacOS() ? SA_SIGINFO_MACOS : SA_SIGINFO_LINUX;

                  config.UnixOldSigaction = (nuint)(ulong)old.sa_handler;
                  config.UnixOldSigaction3Arg = (old.sa_flags & saSigInfo) != 0 ? 1 : 0;
              }
              else
              {
                  config.StructAddressOffset = 40; // ExceptionInformation1
                  config.StructWriteOffset = 32; // ExceptionInformation0

                  VectoredExceptionHandler handler = GenerateWindowsSignalHandler(_handlerConfig);
                  _signalHandlerDelegate = handler;
                  _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(handler);

                  _signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
              }

              _initialized = true;
          }
      }

      /// <summary>Returns a reference to the unmanaged config block allocated in the static constructor.</summary>
      private static unsafe ref SignalHandlerConfig GetConfigRef()
      {
          return ref Unsafe.AsRef<SignalHandlerConfig>((void*)_handlerConfig);
      }

      /// <summary>
      /// Registers an address range whose faults are forwarded to <paramref name="action"/>.
      /// </summary>
      /// <param name="address">Start of the range (inclusive).</param>
      /// <param name="endAddress">End of the range (exclusive).</param>
      /// <param name="action">Native tracking function; if zero, faults in the range are treated as not handled
      /// (apart from the Windows partial-unmap retry).</param>
      /// <returns><c>true</c> if a free slot was claimed; <c>false</c> if all slots are in use.</returns>
      public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action)
      {
          // NOTE(review): no synchronization here; callers presumably serialize registration.
          var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;

          for (int i = 0; i < MaxTrackedRanges; i++)
          {
              if (ranges[i].IsActive == 0)
              {
                  ranges[i].RangeAddress = address;
                  ranges[i].RangeEndAddress = endAddress;
                  ranges[i].ActionPointer = action;

                  // Written last, presumably so the handler never sees a slot whose bounds are still being filled in.
                  ranges[i].IsActive = 1;

                  return true;
              }
          }

          return false;
      }

      /// <summary>
      /// Deactivates the first active range whose start equals <paramref name="address"/>.
      /// </summary>
      /// <returns><c>true</c> if a matching range was found and deactivated.</returns>
      public static unsafe bool RemoveTrackedRegion(nuint address)
      {
          var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;

          for (int i = 0; i < MaxTrackedRanges; i++)
          {
              if (ranges[i].IsActive == 1 && ranges[i].RangeAddress == address)
              {
                  ranges[i].IsActive = 0;

                  return true;
              }
          }

          return false;
      }

      /// <summary>
      /// Emits IR that checks each active tracked range in slot order. For the first range containing
      /// <paramref name="faultAddress"/>, its action is called with (page-aligned offset into the range, PageSize,
      /// isWrite, 0) and its result becomes the return value. On Windows, a zero result is replaced by the result
      /// of the partial-unmap retry.
      /// </summary>
      /// <returns>An I32 operand that is nonzero when the fault was handled.</returns>
      private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite)
      {
          Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
          context.Copy(inRegionLocal, Const(0));

          Operand endLabel = Label();

          for (int i = 0; i < MaxTrackedRanges; i++)
          {
              ulong rangeBaseOffset = (ulong)(RangeOffset + i * Unsafe.SizeOf<SignalHandlerRange>());

              Operand nextLabel = Label();

              Operand isActive = context.Load(OperandType.I32, Const((ulong)signalStructPtr + rangeBaseOffset));

              context.BranchIfFalse(nextLabel, isActive);

              Operand rangeAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + RangeAddressFieldOffset));
              Operand rangeEndAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + RangeEndAddressFieldOffset));

              // Is the fault address within this tracked region?
              Operand inRange = context.BitwiseAnd(
                  context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
                  context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI));

              // Only call tracking if in range.
              context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);

              Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask));

              // Call the tracking action, with the pointer's relative offset to the base address.
              Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + RangeActionPointerFieldOffset));

              context.Copy(inRegionLocal, Const(0));

              Operand skipActionLabel = Label();

              // Tracking action should be non-null to call it, otherwise assume false return.
              context.BranchIfFalse(skipActionLabel, trackingActionPtr);
              Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
              context.Copy(inRegionLocal, result);

              context.MarkLabel(skipActionLabel);

              // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
              if (OperatingSystem.IsWindows())
              {
                  context.BranchIfTrue(endLabel, inRegionLocal);

                  context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
              }

              // The first matching range decides the outcome; later ranges are not checked.
              context.Branch(endLabel);

              context.MarkLabel(nextLabel);
          }

          context.MarkLabel(endLabel);

          return context.Copy(inRegionLocal);
      }

      /// <summary>
      /// Compiles the Unix signal handler. If the fault is not handled by a tracked range, it chains
      /// to the previously registered sigaction using the 1- or 3-argument convention recorded in the config.
      /// </summary>
      private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr)
      {
          EmitterContext context = new EmitterContext();

          // (int sig, SigInfo* sigInfo, void* ucontext)
          Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);

          // Both offsets are 32-bit config fields; load exactly 32 bits so neighbouring fields are never read.
          Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
          Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));

          Operand faultAddress = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));

          // si_code is a 32-bit int in siginfo_t.
          Operand writeFlag = context.Load(OperandType.I32, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));

          Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0)); // Normalize to 0/1.

          Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);

          Operand endLabel = Label();

          context.BranchIfTrue(endLabel, isInRegion);

          // Not handled here: chain to the previously installed handler.
          // NOTE(review): if the previous disposition was SIG_DFL (0) or SIG_IGN (1), this calls a non-function address.
          Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));

          // UnixOldSigaction3Arg is a 32-bit field directly followed by Range0.IsActive; a 64-bit load
          // would pick up that flag and take the 3-argument path whenever range 0 is active.
          Operand unixOldSigaction3Arg = context.Load(OperandType.I32, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
          Operand threeArgLabel = Label();

          context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);

          context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
          context.Branch(endLabel);

          context.MarkLabel(threeArgLabel);

          context.Call(unixOldSigaction,
              OperandType.None,
              context.LoadArgument(OperandType.I32, 0),
              sigInfoPtr,
              context.LoadArgument(OperandType.I64, 2));

          context.MarkLabel(endLabel);

          context.Return();

          ControlFlowGraph cfg = context.GetControlFlowGraph();

          OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };

          return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<UnixExceptionHandler>();
      }

      /// <summary>
      /// Compiles the Windows vectored exception handler. It only considers access violations, returns
      /// EXCEPTION_CONTINUE_EXECUTION when a tracked range handled the fault, and EXCEPTION_CONTINUE_SEARCH otherwise.
      /// </summary>
      private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr)
      {
          EmitterContext context = new EmitterContext();

          // (ExceptionPointers* exceptionInfo)
          Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
          Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);

          // First thing's first - this catches a number of exceptions, but we only want access violations.
          Operand validExceptionLabel = Label();

          Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);

          context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);

          context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.

          context.MarkLabel(validExceptionLabel);

          // Next, read the address of the invalid access, and whether it is a write or not.
          Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
          Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));

          Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
          Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));

          Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.

          Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);

          Operand endLabel = Label();

          // If the region check result is false, then run the next vectored exception handler.
          context.BranchIfTrue(endLabel, isInRegion);

          context.Return(Const(EXCEPTION_CONTINUE_SEARCH));

          context.MarkLabel(endLabel);

          // Otherwise, return to execution.
          context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));

          // Compile and return the function.
          ControlFlowGraph cfg = context.GetControlFlowGraph();

          OperandType[] argTypes = new OperandType[] { OperandType.I64 };

          return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<VectoredExceptionHandler>();
      }
  }
  240. }