NativeSignalHandler.cs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. using ARMeilleure.IntermediateRepresentation;
  2. using ARMeilleure.Memory;
  3. using ARMeilleure.Translation;
  4. using ARMeilleure.Translation.Cache;
  5. using System;
  6. using System.Runtime.CompilerServices;
  7. using System.Runtime.InteropServices;
  8. using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
  9. namespace ARMeilleure.Signal
  10. {
  /// <summary>
  /// One slot of the tracked-region table that the generated native fault handler scans.
  /// </summary>
  /// <remarks>
  /// The struct is packed to 1 byte because the generated code reads its fields at fixed byte
  /// offsets (0, 4, 12 and 20 when pointers are 8 bytes wide). Do not reorder or resize fields.
  /// </remarks>
  [StructLayout(LayoutKind.Sequential, Pack = 1)]
  struct SignalHandlerRange
  {
      /// <summary>
      /// Non-zero when this slot describes a live region; zero when the slot is free.
      /// </summary>
      public int IsActive;

      /// <summary>
      /// First address covered by the region (inclusive).
      /// </summary>
      public nuint RangeAddress;

      /// <summary>
      /// End address of the region (exclusive).
      /// </summary>
      public nuint RangeEndAddress;

      /// <summary>
      /// Native pointer to the tracking action called for faults inside the region; may be zero.
      /// </summary>
      public nint ActionPointer;
  }
  /// <summary>
  /// Configuration block shared between managed code and the generated native fault handler.
  /// </summary>
  /// <remarks>
  /// Packed to 1 byte; the generated handler reads these fields at fixed byte offsets, so the
  /// field order and sizes must not change. The eight range slots are laid out back to back and
  /// are indexed as an array starting at <see cref="Range0"/>.
  /// </remarks>
  [StructLayout(LayoutKind.Sequential, Pack = 1)]
  struct SignalHandlerConfig
  {
      /// <summary>
      /// The byte offset of the faulting address in the SigInfo or ExceptionRecord struct.
      /// </summary>
      public int StructAddressOffset;

      /// <summary>
      /// The byte offset of the write flag in the SigInfo or ExceptionRecord struct.
      /// </summary>
      public int StructWriteOffset;

      /// <summary>
      /// The sigaction handler that was registered before this one. (unix only)
      /// </summary>
      public nuint UnixOldSigaction;

      /// <summary>
      /// The type of the previous sigaction. True for the 3 argument variant. (unix only)
      /// </summary>
      public int UnixOldSigaction3Arg;

      // Tracked region slots 0-7. Keep them contiguous and in order.
      public SignalHandlerRange Range0;
      public SignalHandlerRange Range1;
      public SignalHandlerRange Range2;
      public SignalHandlerRange Range3;
      public SignalHandlerRange Range4;
      public SignalHandlerRange Range5;
      public SignalHandlerRange Range6;
      public SignalHandlerRange Range7;
  }
  47. public static class NativeSignalHandler
  48. {
  // Managed signature of the generated Unix handler: (int sig, SigInfo* info, void* ucontext).
  private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext);

  // Managed signature of the generated Windows vectored exception handler.
  [UnmanagedFunctionPointer(CallingConvention.Winapi)]
  private delegate int VectoredExceptionHandler(IntPtr exceptionInfo);

  // Number of SignalHandlerRange slots in SignalHandlerConfig (Range0..Range7).
  private const int MaxTrackedRanges = 8;

  // Byte offsets of the SignalHandlerConfig fields as read by the generated code.
  // They match the Pack = 1 layout when nuint is 8 bytes wide (64-bit process).
  private const int StructAddressOffset = 0;
  private const int StructWriteOffset = 4;
  private const int UnixOldSigaction = 8;
  private const int UnixOldSigaction3Arg = 16;
  private const int RangeOffset = 20;

  // Return values of a Windows vectored exception handler.
  private const int EXCEPTION_CONTINUE_SEARCH = 0;
  private const int EXCEPTION_CONTINUE_EXECUTION = -1;

  // Exception code the Windows handler reacts to; everything else is passed on.
  private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;

  // Page size and (pageSize - 1) mask, set once by InitializeSignalHandler.
  private static ulong _pageSize;
  private static ulong _pageMask;

  // Unmanaged SignalHandlerConfig block; allocated once in the static constructor.
  private static readonly IntPtr _handlerConfig;

  // Pointer to the handler that was registered with the OS.
  private static IntPtr _signalHandlerPtr;

  // Value returned by the Windows handler registration.
  private static IntPtr _signalHandlerHandle;

  private static readonly object _lock = new object();

  // Read outside the lock by InitializeSignalHandler, so it must be volatile for that
  // unlocked check to observe the write made under the lock.
  private static volatile bool _initialized;
  /// <summary>
  /// Allocates the unmanaged block that backs the handler configuration and resets it to defaults.
  /// </summary>
  static NativeSignalHandler()
  {
      int configSize = Unsafe.SizeOf<SignalHandlerConfig>();
      _handlerConfig = Marshal.AllocHGlobal(configSize);

      // Overwrite the freshly allocated block so every field (and every range slot) starts at zero.
      GetConfigRef() = default;
  }
  /// <summary>
  /// Initializes the JIT cache with the given allocator.
  /// </summary>
  /// <param name="allocator">Allocator forwarded to <c>JitCache.Initialize</c>.</param>
  public static void Initialize(IJitMemoryAllocator allocator) => JitCache.Initialize(allocator);
  /// <summary>
  /// Generates the native fault handler and registers it with the operating system.
  /// Only the first call does any work; later calls return immediately.
  /// </summary>
  /// <param name="pageSize">
  /// Page size used to align fault offsets. The mask is computed as <c>pageSize - 1</c>,
  /// so the value is expected to be a power of two.
  /// </param>
  /// <param name="customSignalHandlerFactory">
  /// Optional hook that can wrap the generated handler. It receives the currently registered
  /// segfault handler (<see cref="IntPtr.Zero"/> on Windows) and the generated handler pointer,
  /// and returns the pointer that is actually registered.
  /// </param>
  public static void InitializeSignalHandler(ulong pageSize, Func<IntPtr, IntPtr, IntPtr> customSignalHandlerFactory = null)
  {
      if (_initialized)
      {
          return;
      }

      lock (_lock)
      {
          if (_initialized)
          {
              return;
          }

          _pageSize = pageSize;
          _pageMask = pageSize - 1;

          ref SignalHandlerConfig config = ref GetConfigRef();

          if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
          {
              bool isMacOS = OperatingSystem.IsMacOS();

              // Unix siginfo struct locations.
              // NOTE: These are incredibly likely to be different between kernel version and architectures.
              config.StructAddressOffset = isMacOS ? 24 : 16; // si_addr
              config.StructWriteOffset = 8; // si_code

              _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateUnixSignalHandler(_handlerConfig));

              if (customSignalHandlerFactory != null)
              {
                  _signalHandlerPtr = customSignalHandlerFactory(UnixSignalHandlerRegistration.GetSegfaultExceptionHandler().sa_handler, _signalHandlerPtr);
              }

              var old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);

              config.UnixOldSigaction = (nuint)(ulong)old.sa_handler;

              // SA_SIGINFO marks the 3 argument handler form. Its value is platform specific:
              // 0x4 on Linux, but 0x40 on macOS (where 0x4 is SA_RESETHAND).
              int saSigInfo = isMacOS ? 0x40 : 0x4;
              config.UnixOldSigaction3Arg = old.sa_flags & saSigInfo;
          }
          else
          {
              config.StructAddressOffset = 40; // ExceptionInformation1
              config.StructWriteOffset = 32; // ExceptionInformation0

              _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateWindowsSignalHandler(_handlerConfig));

              if (customSignalHandlerFactory != null)
              {
                  _signalHandlerPtr = customSignalHandlerFactory(IntPtr.Zero, _signalHandlerPtr);
              }

              _signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
          }

          _initialized = true;
      }
  }
  /// <summary>
  /// Reinterprets the unmanaged configuration block as a <see cref="SignalHandlerConfig"/> reference.
  /// </summary>
  /// <returns>A reference directly into the unmanaged block; writes through it modify that memory.</returns>
  private static unsafe ref SignalHandlerConfig GetConfigRef()
      => ref Unsafe.AsRef<SignalHandlerConfig>(_handlerConfig.ToPointer());
  /// <summary>
  /// Registers a memory region with the fault handler using the first free slot.
  /// </summary>
  /// <param name="address">Start address of the region.</param>
  /// <param name="endAddress">End address of the region.</param>
  /// <param name="action">Native pointer to the tracking action stored for the region.</param>
  /// <returns>True if a free slot was found and filled; false if all slots are in use.</returns>
  public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action)
  {
      SignalHandlerRange* table = &((SignalHandlerConfig*)_handlerConfig)->Range0;

      for (int slot = 0; slot < MaxTrackedRanges; slot++)
      {
          SignalHandlerRange* entry = table + slot;

          if (entry->IsActive != 0)
          {
              continue;
          }

          entry->RangeAddress = address;
          entry->RangeEndAddress = endAddress;
          entry->ActionPointer = action;

          // Written last: the generated handler skips a slot while IsActive is zero.
          entry->IsActive = 1;

          return true;
      }

      return false;
  }
  /// <summary>
  /// Deactivates the first active slot whose start address equals <paramref name="address"/>.
  /// </summary>
  /// <param name="address">Start address the region was registered with.</param>
  /// <returns>True if a matching slot was found and deactivated; otherwise false.</returns>
  public static unsafe bool RemoveTrackedRegion(nuint address)
  {
      SignalHandlerRange* table = &((SignalHandlerConfig*)_handlerConfig)->Range0;

      for (int slot = 0; slot < MaxTrackedRanges; slot++)
      {
          SignalHandlerRange* entry = table + slot;

          if (entry->IsActive != 1 || entry->RangeAddress != address)
          {
              continue;
          }

          // Only the flag is cleared; the remaining fields are overwritten when the slot is reused.
          entry->IsActive = 0;

          return true;
      }

      return false;
  }
  /// <summary>
  /// Emits code that walks every tracked range slot and, for the first active slot containing
  /// the fault address, calls that slot's tracking action.
  /// </summary>
  /// <param name="context">Emitter the code is appended to.</param>
  /// <param name="signalStructPtr">Address of the unmanaged <see cref="SignalHandlerConfig"/> block.</param>
  /// <param name="faultAddress">Operand holding the faulting address.</param>
  /// <param name="isWrite">Operand holding the write flag passed to the tracking action.</param>
  /// <returns>An I32 operand that is non-zero when the fault was handled by a tracked region.</returns>
  private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite)
  {
      // Byte offsets of the fields inside one packed SignalHandlerRange slot (IsActive is at 0).
      const ulong RangeAddressField = 4;
      const ulong RangeEndAddressField = 12;
      const ulong ActionPointerField = 20;

      ulong configBase = (ulong)signalStructPtr;
      int slotSize = Unsafe.SizeOf<SignalHandlerRange>();

      Operand handledLocal = context.AllocateLocal(OperandType.I32);
      context.Copy(handledLocal, Const(0));

      Operand doneLabel = Label();

      for (int slot = 0; slot < MaxTrackedRanges; slot++)
      {
          ulong slotAddress = configBase + (ulong)(RangeOffset + slot * slotSize);

          Operand nextSlotLabel = Label();

          // Skip slots that are not in use.
          Operand active = context.Load(OperandType.I32, Const(slotAddress));
          context.BranchIfFalse(nextSlotLabel, active);

          Operand start = context.Load(OperandType.I64, Const(slotAddress + RangeAddressField));
          Operand end = context.Load(OperandType.I64, Const(slotAddress + RangeEndAddressField));

          // Is the fault address within this tracked region? (start <= fault < end, unsigned)
          Operand atOrAfterStart = context.ICompare(faultAddress, start, Comparison.GreaterOrEqualUI);
          Operand beforeEnd = context.ICompare(faultAddress, end, Comparison.LessUI);
          Operand contained = context.BitwiseAnd(atOrAfterStart, beforeEnd);

          // Only call tracking if in range.
          context.BranchIfFalse(nextSlotLabel, contained, BasicBlockFrequency.Cold);

          // Page-aligned offset of the fault relative to the region base.
          Operand relative = context.Subtract(faultAddress, start);
          Operand pageOffset = context.BitwiseAnd(relative, Const(~_pageMask));

          // Call the tracking action, with the pointer's relative offset to the base address.
          Operand action = context.Load(OperandType.I64, Const(slotAddress + ActionPointerField));

          context.Copy(handledLocal, Const(0));

          Operand noActionLabel = Label();

          // Tracking action should be non-null to call it, otherwise assume false return.
          context.BranchIfFalse(noActionLabel, action);

          Operand actionResult = context.Call(action, OperandType.I32, pageOffset, Const(_pageSize), isWrite, Const(0));
          context.Copy(handledLocal, actionResult);

          context.MarkLabel(noActionLabel);

          // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
          if (OperatingSystem.IsWindows())
          {
              context.BranchIfTrue(doneLabel, handledLocal);

              context.Copy(handledLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
          }

          // A containing slot was found, so the remaining slots are not checked.
          context.Branch(doneLabel);

          context.MarkLabel(nextSlotLabel);
      }

      context.MarkLabel(doneLabel);

      return context.Copy(handledLocal);
  }
  /// <summary>
  /// Compiles the native Unix signal handler. The handler runs the tracked-region check for the
  /// faulting address and, when no region handles the fault, calls the previously registered
  /// sigaction handler.
  /// </summary>
  /// <param name="signalStructPtr">Address of the unmanaged <see cref="SignalHandlerConfig"/> block.</param>
  /// <returns>A delegate mapped onto the compiled handler.</returns>
  private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr)
  {
      EmitterContext context = new EmitterContext();

      // (int sig, SigInfo* sigInfo, void* ucontext)
      Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);

      // Both offsets are 32-bit config fields, so load exactly 32 bits (same as the Windows handler).
      Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
      Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));

      Operand faultAddress = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));

      // NOTE(review): 64 bits are read at the write-flag offset and any non-zero value counts as a write.
      // On Unix that offset is configured as si_code - confirm this is the intended write detection.
      Operand writeFlag = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));

      Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.

      Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);

      Operand endLabel = Label();

      // Handled by a tracked region: nothing else to do.
      context.BranchIfTrue(endLabel, isInRegion);

      Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));

      // The flag is a 32-bit field directly followed by Range0.IsActive. A 64-bit load here would
      // also pick up that neighbouring field and report "3 argument" whenever range 0 is active.
      Operand unixOldSigaction3Arg = context.Load(OperandType.I32, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));

      Operand threeArgLabel = Label();

      context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);

      // Previous handler uses the 1 argument form: handler(sig).
      context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
      context.Branch(endLabel);

      context.MarkLabel(threeArgLabel);

      // Previous handler uses the 3 argument form: handler(sig, info, ucontext).
      context.Call(unixOldSigaction,
          OperandType.None,
          context.LoadArgument(OperandType.I32, 0),
          sigInfoPtr,
          context.LoadArgument(OperandType.I64, 2)
          );

      context.MarkLabel(endLabel);

      context.Return();

      // Compile and return the function.
      ControlFlowGraph cfg = context.GetControlFlowGraph();

      OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };

      return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<UnixExceptionHandler>();
  }
  /// <summary>
  /// Compiles the native Windows vectored exception handler. The handler ignores everything
  /// except access violations, runs the tracked-region check for the faulting address and
  /// reports whether execution should continue or the next handler should be tried.
  /// </summary>
  /// <param name="signalStructPtr">Address of the unmanaged <see cref="SignalHandlerConfig"/> block.</param>
  /// <returns>A delegate mapped onto the compiled handler.</returns>
  private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr)
  {
      ulong configBase = (ulong)signalStructPtr;

      EmitterContext context = new EmitterContext();

      // (ExceptionPointers* exceptionInfo)
      Operand pointersArg = context.LoadArgument(OperandType.I64, 0);
      Operand recordPtr = context.Load(OperandType.I64, pointersArg);

      // This handler sees many kinds of exceptions; only access violations are of interest.
      Operand accessViolationLabel = Label();

      Operand code = context.Load(OperandType.I32, recordPtr);

      context.BranchIf(accessViolationLabel, code, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);

      context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.

      context.MarkLabel(accessViolationLabel);

      // Fetch the configured field offsets, then the faulting address and the write flag.
      Operand addressFieldOffset = context.Load(OperandType.I32, Const(configBase + StructAddressOffset));
      Operand writeFieldOffset = context.Load(OperandType.I32, Const(configBase + StructWriteOffset));

      Operand addressFieldPtr = context.Add(recordPtr, context.ZeroExtend32(OperandType.I64, addressFieldOffset));
      Operand faultAddress = context.Load(OperandType.I64, addressFieldPtr);

      Operand writeFieldPtr = context.Add(recordPtr, context.ZeroExtend32(OperandType.I64, writeFieldOffset));
      Operand rawWriteFlag = context.Load(OperandType.I64, writeFieldPtr);

      Operand isWrite = context.ICompareNotEqual(rawWriteFlag, Const(0L)); // Normalize to 0/1.

      Operand handled = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);

      Operand resumeLabel = Label();

      // Not handled by any tracked region: let the next vectored exception handler run.
      context.BranchIfTrue(resumeLabel, handled);

      context.Return(Const(EXCEPTION_CONTINUE_SEARCH));

      context.MarkLabel(resumeLabel);

      // Handled: resume execution at the faulting instruction.
      context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));

      // Compile and return the function.
      ControlFlowGraph graph = context.GetControlFlowGraph();

      OperandType[] parameterTypes = { OperandType.I64 };

      return Compiler.Compile(graph, parameterTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<VectoredExceptionHandler>();
  }
  254. }
  255. }