NativeSignalHandler.cs 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. using ARMeilleure.IntermediateRepresentation;
  2. using ARMeilleure.Translation;
  3. using System;
  4. using System.Runtime.CompilerServices;
  5. using System.Runtime.InteropServices;
  6. using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
  7. namespace ARMeilleure.Signal
  8. {
  9. [StructLayout(LayoutKind.Sequential, Pack = 1)]
  10. struct SignalHandlerRange
  11. {
  12. public int IsActive;
  13. public nuint RangeAddress;
  14. public nuint RangeEndAddress;
  15. public IntPtr ActionPointer;
  16. }
  17. [StructLayout(LayoutKind.Sequential, Pack = 1)]
  18. struct SignalHandlerConfig
  19. {
  20. /// <summary>
  21. /// The byte offset of the faulting address in the SigInfo or ExceptionRecord struct.
  22. /// </summary>
  23. public int StructAddressOffset;
  24. /// <summary>
  25. /// The byte offset of the write flag in the SigInfo or ExceptionRecord struct.
  26. /// </summary>
  27. public int StructWriteOffset;
  28. /// <summary>
  29. /// The sigaction handler that was registered before this one. (unix only)
  30. /// </summary>
  31. public nuint UnixOldSigaction;
  32. /// <summary>
  33. /// The type of the previous sigaction. True for the 3 argument variant. (unix only)
  34. /// </summary>
  35. public int UnixOldSigaction3Arg;
  36. public SignalHandlerRange Range0;
  37. public SignalHandlerRange Range1;
  38. public SignalHandlerRange Range2;
  39. public SignalHandlerRange Range3;
  40. public SignalHandlerRange Range4;
  41. public SignalHandlerRange Range5;
  42. public SignalHandlerRange Range6;
  43. public SignalHandlerRange Range7;
  44. }
  45. public static class NativeSignalHandler
  46. {
  47. private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext);
  48. [UnmanagedFunctionPointer(CallingConvention.Winapi)]
  49. private delegate int VectoredExceptionHandler(IntPtr exceptionInfo);
  50. private const int MaxTrackedRanges = 8;
  51. private const int StructAddressOffset = 0;
  52. private const int StructWriteOffset = 4;
  53. private const int UnixOldSigaction = 8;
  54. private const int UnixOldSigaction3Arg = 16;
  55. private const int RangeOffset = 20;
  56. private const int EXCEPTION_CONTINUE_SEARCH = 0;
  57. private const int EXCEPTION_CONTINUE_EXECUTION = -1;
  58. private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;
  59. private const ulong PageSize = 0x1000;
  60. private const ulong PageMask = PageSize - 1;
  61. private static IntPtr _handlerConfig;
  62. private static IntPtr _signalHandlerPtr;
  63. private static IntPtr _signalHandlerHandle;
  64. private static readonly object _lock = new object();
  65. private static bool _initialized;
  66. static NativeSignalHandler()
  67. {
  68. _handlerConfig = Marshal.AllocHGlobal(Unsafe.SizeOf<SignalHandlerConfig>());
  69. ref SignalHandlerConfig config = ref GetConfigRef();
  70. config = new SignalHandlerConfig();
  71. }
  72. public static void InitializeSignalHandler()
  73. {
  74. if (_initialized) return;
  75. lock (_lock)
  76. {
  77. if (_initialized) return;
  78. bool unix = OperatingSystem.IsLinux() || OperatingSystem.IsMacOS();
  79. ref SignalHandlerConfig config = ref GetConfigRef();
  80. if (unix)
  81. {
  82. // Unix siginfo struct locations.
  83. // NOTE: These are incredibly likely to be different between kernel version and architectures.
  84. config.StructAddressOffset = OperatingSystem.IsMacOS() ? 24 : 16; // si_addr
  85. config.StructWriteOffset = 8; // si_code
  86. _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateUnixSignalHandler(_handlerConfig));
  87. SigAction old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
  88. config.UnixOldSigaction = (nuint)(ulong)old.sa_handler;
  89. config.UnixOldSigaction3Arg = old.sa_flags & 4;
  90. }
  91. else
  92. {
  93. config.StructAddressOffset = 40; // ExceptionInformation1
  94. config.StructWriteOffset = 32; // ExceptionInformation0
  95. _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateWindowsSignalHandler(_handlerConfig));
  96. _signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
  97. }
  98. _initialized = true;
  99. }
  100. }
  101. private static unsafe ref SignalHandlerConfig GetConfigRef()
  102. {
  103. return ref Unsafe.AsRef<SignalHandlerConfig>((void*)_handlerConfig);
  104. }
  105. public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action)
  106. {
  107. var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;
  108. for (int i = 0; i < MaxTrackedRanges; i++)
  109. {
  110. if (ranges[i].IsActive == 0)
  111. {
  112. ranges[i].RangeAddress = address;
  113. ranges[i].RangeEndAddress = endAddress;
  114. ranges[i].ActionPointer = action;
  115. ranges[i].IsActive = 1;
  116. return true;
  117. }
  118. }
  119. return false;
  120. }
  121. public static unsafe bool RemoveTrackedRegion(nuint address)
  122. {
  123. var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;
  124. for (int i = 0; i < MaxTrackedRanges; i++)
  125. {
  126. if (ranges[i].IsActive == 1 && ranges[i].RangeAddress == address)
  127. {
  128. ranges[i].IsActive = 0;
  129. return true;
  130. }
  131. }
  132. return false;
  133. }
  134. private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite)
  135. {
  136. Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
  137. context.Copy(inRegionLocal, Const(0));
  138. Operand endLabel = Label();
  139. for (int i = 0; i < MaxTrackedRanges; i++)
  140. {
  141. ulong rangeBaseOffset = (ulong)(RangeOffset + i * Unsafe.SizeOf<SignalHandlerRange>());
  142. Operand nextLabel = Label();
  143. Operand isActive = context.Load(OperandType.I32, Const((ulong)signalStructPtr + rangeBaseOffset));
  144. context.BranchIfFalse(nextLabel, isActive);
  145. Operand rangeAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 4));
  146. Operand rangeEndAddress = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 12));
  147. // Is the fault address within this tracked region?
  148. Operand inRange = context.BitwiseAnd(
  149. context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
  150. context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI)
  151. );
  152. // Only call tracking if in range.
  153. context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);
  154. context.Copy(inRegionLocal, Const(1));
  155. Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask));
  156. // Call the tracking action, with the pointer's relative offset to the base address.
  157. Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20));
  158. context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0));
  159. context.Branch(endLabel);
  160. context.MarkLabel(nextLabel);
  161. }
  162. context.MarkLabel(endLabel);
  163. return context.Copy(inRegionLocal);
  164. }
  165. private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr)
  166. {
  167. EmitterContext context = new EmitterContext();
  168. // (int sig, SigInfo* sigInfo, void* ucontext)
  169. Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);
  170. Operand structAddressOffset = context.Load(OperandType.I64, Const((ulong)signalStructPtr + StructAddressOffset));
  171. Operand structWriteOffset = context.Load(OperandType.I64, Const((ulong)signalStructPtr + StructWriteOffset));
  172. Operand faultAddress = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
  173. Operand writeFlag = context.Load(OperandType.I64, context.Add(sigInfoPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));
  174. Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
  175. Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);
  176. Operand endLabel = Label();
  177. context.BranchIfTrue(endLabel, isInRegion);
  178. Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));
  179. Operand unixOldSigaction3Arg = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
  180. Operand threeArgLabel = Label();
  181. context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);
  182. context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
  183. context.Branch(endLabel);
  184. context.MarkLabel(threeArgLabel);
  185. context.Call(unixOldSigaction,
  186. OperandType.None,
  187. context.LoadArgument(OperandType.I32, 0),
  188. sigInfoPtr,
  189. context.LoadArgument(OperandType.I64, 2)
  190. );
  191. context.MarkLabel(endLabel);
  192. context.Return();
  193. ControlFlowGraph cfg = context.GetControlFlowGraph();
  194. OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };
  195. return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map<UnixExceptionHandler>();
  196. }
  197. private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr)
  198. {
  199. EmitterContext context = new EmitterContext();
  200. // (ExceptionPointers* exceptionInfo)
  201. Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
  202. Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);
  203. // First thing's first - this catches a number of exceptions, but we only want access violations.
  204. Operand validExceptionLabel = Label();
  205. Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);
  206. context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);
  207. context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.
  208. context.MarkLabel(validExceptionLabel);
  209. // Next, read the address of the invalid access, and whether it is a write or not.
  210. Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
  211. Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));
  212. Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
  213. Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));
  214. Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.
  215. Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);
  216. Operand endLabel = Label();
  217. // If the region check result is false, then run the next vectored exception handler.
  218. context.BranchIfTrue(endLabel, isInRegion);
  219. context.Return(Const(EXCEPTION_CONTINUE_SEARCH));
  220. context.MarkLabel(endLabel);
  221. // Otherwise, return to execution.
  222. context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));
  223. // Compile and return the function.
  224. ControlFlowGraph cfg = context.GetControlFlowGraph();
  225. OperandType[] argTypes = new OperandType[] { OperandType.I64 };
  226. return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map<VectoredExceptionHandler>();
  227. }
  228. }
  229. }