NativeSignalHandler.cs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422
  1. using ARMeilleure.IntermediateRepresentation;
  2. using ARMeilleure.Memory;
  3. using ARMeilleure.Translation;
  4. using ARMeilleure.Translation.Cache;
  5. using System;
  6. using System.Runtime.CompilerServices;
  7. using System.Runtime.InteropServices;
  8. using static ARMeilleure.IntermediateRepresentation.Operand.Factory;
  9. namespace ARMeilleure.Signal
  10. {
  /// <summary>
  /// One entry of the tracked-range table consulted by the generated fault handler.
  /// Packed (Pack = 1) because the handler reads every field at a fixed byte offset.
  /// </summary>
  [StructLayout(LayoutKind.Sequential, Pack = 1)]
  internal struct SignalHandlerRange
  {
      /// <summary>Non-zero while this entry is in use; zero marks a free slot.</summary>
      public int IsActive;

      /// <summary>Inclusive start address of the tracked range.</summary>
      public nuint RangeAddress;

      /// <summary>Exclusive end address of the tracked range.</summary>
      public nuint RangeEndAddress;

      /// <summary>Native tracking callback for faults inside the range; a null pointer means "no action".</summary>
      public IntPtr ActionPointer;
  }
  /// <summary>
  /// Unmanaged configuration block shared with the generated fault handler: where to find the fault
  /// address and write flag, the previously installed Unix handler, and a fixed table of tracked ranges.
  /// Packed (Pack = 1) so that every field sits at a fixed byte offset the handler can hard-code.
  /// </summary>
  [StructLayout(LayoutKind.Sequential, Pack = 1)]
  internal struct SignalHandlerConfig
  {
      /// <summary>
      /// Byte offset, inside the SigInfo or ExceptionRecord structure, of the faulting address.
      /// </summary>
      public int StructAddressOffset;

      /// <summary>
      /// Byte offset, inside the SigInfo or ExceptionRecord structure, of the write flag.
      /// </summary>
      public int StructWriteOffset;

      /// <summary>
      /// Unix only: the sigaction handler that was installed before this one.
      /// </summary>
      public nuint UnixOldSigaction;

      /// <summary>
      /// Unix only: non-zero when the previous handler is the three-argument (sig, info, ucontext) variant.
      /// </summary>
      public int UnixOldSigaction3Arg;

      // Tracked-range table. The entries must remain consecutive and in this order, because callers
      // take the address of Range0 and index the entries as an array.
      public SignalHandlerRange Range0, Range1, Range2, Range3, Range4, Range5, Range6, Range7;
  }
  /// <summary>
  /// Generates and installs a native (JIT-compiled) memory fault handler that forwards faults inside
  /// registered address ranges to per-range tracking callbacks. On Linux and macOS the handler is
  /// registered through <c>UnixSignalHandlerRegistration</c>; otherwise it is registered as a vectored
  /// exception handler through <c>WindowsSignalHandlerRegistration</c>.
  /// </summary>
  /// <remarks>
  /// The generated code reads its configuration directly from the unmanaged <c>SignalHandlerConfig</c>
  /// block using the byte offsets declared below. Those offsets must match the Pack = 1 layout of
  /// <c>SignalHandlerConfig</c> / <c>SignalHandlerRange</c> and assume 8-byte pointers.
  /// </remarks>
  public static class NativeSignalHandler
  {
      private delegate void UnixExceptionHandler(int sig, IntPtr info, IntPtr ucontext);

      [UnmanagedFunctionPointer(CallingConvention.Winapi)]
      private delegate int VectoredExceptionHandler(IntPtr exceptionInfo);

      private const int MaxTrackedRanges = 8;

      // Byte offsets of the SignalHandlerConfig fields.
      private const int StructAddressOffset = 0;
      private const int StructWriteOffset = 4;
      private const int UnixOldSigaction = 8;
      private const int UnixOldSigaction3Arg = 16;
      private const int RangeOffset = 20;

      // Byte offsets of the SignalHandlerRange fields, relative to the start of one range entry.
      private const int RangeIsActiveOffset = 0;
      private const int RangeAddressOffset = 4;
      private const int RangeEndAddressOffset = 12;
      private const int RangeActionPointerOffset = 20;

      // SA_SIGINFO as defined by each platform's <signal.h>. On macOS 0x4 is SA_RESETHAND, so testing
      // sa_flags against 0x4 there does not detect a three-argument handler.
      private const int SA_SIGINFO_LINUX = 0x4;
      private const int SA_SIGINFO_MACOS = 0x40;

      private const int EXCEPTION_CONTINUE_SEARCH = 0;
      private const int EXCEPTION_CONTINUE_EXECUTION = -1;

      private const uint EXCEPTION_ACCESS_VIOLATION = 0xc0000005;

      private static ulong _pageSize;
      private static ulong _pageMask;

      private static IntPtr _handlerConfig;
      private static IntPtr _signalHandlerPtr;
      private static IntPtr _signalHandlerHandle;

      private static readonly object _lock = new object();

      // Volatile: InitializeSignalHandler reads this flag outside the lock, so the store that sets it
      // must not become visible before the state initialized under the lock.
      private static volatile bool _initialized;

      static NativeSignalHandler()
      {
          // Zero-initialized unmanaged block (no ranges active). The generated handlers embed its
          // address as a constant; it is not freed anywhere in this class.
          _handlerConfig = Marshal.AllocHGlobal(Unsafe.SizeOf<SignalHandlerConfig>());
          ref SignalHandlerConfig config = ref GetConfigRef();

          config = new SignalHandlerConfig();
      }

      /// <summary>
      /// Initializes the JIT cache with the given allocator.
      /// </summary>
      /// <param name="allocator">Allocator forwarded to <c>JitCache.Initialize</c>.</param>
      public static void Initialize(IJitMemoryAllocator allocator)
      {
          JitCache.Initialize(allocator);
      }

      /// <summary>
      /// Generates and registers the native fault handler. Only the first call has any effect;
      /// later calls return immediately.
      /// </summary>
      /// <param name="pageSize">
      /// Page granularity baked into the generated code. Assumed to be a power of two, because fault
      /// offsets are aligned with the mask <c>~(pageSize - 1)</c>.
      /// </param>
      /// <param name="customSignalHandlerFactory">
      /// Optional hook that may wrap the generated handler. It receives a handler pointer (on Unix, the
      /// <c>sa_handler</c> from <c>UnixSignalHandlerRegistration.GetSegfaultExceptionHandler()</c>; on
      /// Windows, <see cref="IntPtr.Zero"/>) and the generated handler pointer, and returns the pointer
      /// that is actually registered.
      /// </param>
      public static void InitializeSignalHandler(ulong pageSize, Func<IntPtr, IntPtr, IntPtr> customSignalHandlerFactory = null)
      {
          if (_initialized)
          {
              return;
          }

          lock (_lock)
          {
              if (_initialized)
              {
                  return;
              }

              _pageSize = pageSize;
              _pageMask = pageSize - 1;

              ref SignalHandlerConfig config = ref GetConfigRef();

              if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS())
              {
                  _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateUnixSignalHandler(_handlerConfig));

                  if (customSignalHandlerFactory != null)
                  {
                      _signalHandlerPtr = customSignalHandlerFactory(UnixSignalHandlerRegistration.GetSegfaultExceptionHandler().sa_handler, _signalHandlerPtr);
                  }

                  var old = UnixSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);

                  // NOTE(review): assumes old.sa_flags carries the raw OS flag bits for the current platform.
                  int saSigInfo = OperatingSystem.IsMacOS() ? SA_SIGINFO_MACOS : SA_SIGINFO_LINUX;

                  config.UnixOldSigaction = (nuint)(ulong)old.sa_handler;
                  config.UnixOldSigaction3Arg = (old.sa_flags & saSigInfo) != 0 ? 1 : 0;
              }
              else
              {
                  config.StructAddressOffset = 40; // ExceptionInformation1
                  config.StructWriteOffset = 32; // ExceptionInformation0

                  _signalHandlerPtr = Marshal.GetFunctionPointerForDelegate(GenerateWindowsSignalHandler(_handlerConfig));

                  if (customSignalHandlerFactory != null)
                  {
                      _signalHandlerPtr = customSignalHandlerFactory(IntPtr.Zero, _signalHandlerPtr);
                  }

                  _signalHandlerHandle = WindowsSignalHandlerRegistration.RegisterExceptionHandler(_signalHandlerPtr);
              }

              _initialized = true;
          }
      }

      /// <summary>
      /// Returns a reference to the unmanaged configuration block.
      /// </summary>
      private static unsafe ref SignalHandlerConfig GetConfigRef()
      {
          return ref Unsafe.AsRef<SignalHandlerConfig>((void*)_handlerConfig);
      }

      /// <summary>
      /// Registers the address range [<paramref name="address"/>, <paramref name="endAddress"/>) in the
      /// first free slot of the tracked-range table. A fault inside the range calls
      /// <paramref name="action"/> with (offset from <paramref name="address"/> aligned down to the page
      /// size, page size, isWrite) and treats its non-zero result as "handled".
      /// </summary>
      /// <param name="address">Inclusive start of the range.</param>
      /// <param name="endAddress">Exclusive end of the range.</param>
      /// <param name="action">Native tracking callback; may be <see cref="IntPtr.Zero"/>.</param>
      /// <returns><c>true</c> if a free slot was claimed; <c>false</c> if all slots are in use.</returns>
      /// <remarks>NOTE(review): slot claiming is not synchronized between concurrent callers.</remarks>
      public static unsafe bool AddTrackedRegion(nuint address, nuint endAddress, IntPtr action)
      {
          var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;

          for (int i = 0; i < MaxTrackedRanges; i++)
          {
              if (ranges[i].IsActive == 0)
              {
                  ranges[i].RangeAddress = address;
                  ranges[i].RangeEndAddress = endAddress;
                  ranges[i].ActionPointer = action;

                  // Mark the entry active last, with release semantics, so the range fields are
                  // written before the fault handler can see the entry as active.
                  System.Threading.Volatile.Write(ref ranges[i].IsActive, 1);

                  return true;
              }
          }

          return false;
      }

      /// <summary>
      /// Deactivates the first active tracked range whose start address equals <paramref name="address"/>.
      /// </summary>
      /// <param name="address">Start address the range was registered with.</param>
      /// <returns><c>true</c> if a matching active range was found and deactivated.</returns>
      public static unsafe bool RemoveTrackedRegion(nuint address)
      {
          var ranges = &((SignalHandlerConfig*)_handlerConfig)->Range0;

          for (int i = 0; i < MaxTrackedRanges; i++)
          {
              if (ranges[i].IsActive == 1 && ranges[i].RangeAddress == address)
              {
                  ranges[i].IsActive = 0;

                  return true;
              }
          }

          return false;
      }

      /// <summary>
      /// Emits code that scans the tracked-range table for an active range containing
      /// <paramref name="faultAddress"/> and, on the first hit, calls that range's tracking action.
      /// </summary>
      /// <param name="context">Emitter receiving the generated code.</param>
      /// <param name="signalStructPtr">Address of the unmanaged <c>SignalHandlerConfig</c> block.</param>
      /// <param name="faultAddress">Operand holding the faulting address.</param>
      /// <param name="isWrite">Operand that is 1 for a write access and 0 otherwise.</param>
      /// <returns>
      /// An I32 operand: the tracking action's result for the first matching range (0 when that range has
      /// no action; on Windows a zero result is replaced by the partial-unmap retry result), or 0 when no
      /// active range matched.
      /// </returns>
      private static Operand EmitGenericRegionCheck(EmitterContext context, IntPtr signalStructPtr, Operand faultAddress, Operand isWrite)
      {
          Operand inRegionLocal = context.AllocateLocal(OperandType.I32);
          context.Copy(inRegionLocal, Const(0));

          Operand endLabel = Label();

          for (int i = 0; i < MaxTrackedRanges; i++)
          {
              ulong rangeBaseOffset = (ulong)(RangeOffset + i * Unsafe.SizeOf<SignalHandlerRange>());
              ulong rangeBase = (ulong)signalStructPtr + rangeBaseOffset;

              Operand nextLabel = Label();

              Operand isActive = context.Load(OperandType.I32, Const(rangeBase + RangeIsActiveOffset));

              context.BranchIfFalse(nextLabel, isActive);

              Operand rangeAddress = context.Load(OperandType.I64, Const(rangeBase + RangeAddressOffset));
              Operand rangeEndAddress = context.Load(OperandType.I64, Const(rangeBase + RangeEndAddressOffset));

              // Is the fault address within this tracked region? (start inclusive, end exclusive)
              Operand inRange = context.BitwiseAnd(
                  context.ICompare(faultAddress, rangeAddress, Comparison.GreaterOrEqualUI),
                  context.ICompare(faultAddress, rangeEndAddress, Comparison.LessUI)
                  );

              // Only call tracking if in range.
              context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold);

              Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~_pageMask));

              // Call the tracking action, with the pointer's relative offset to the base address.
              Operand trackingActionPtr = context.Load(OperandType.I64, Const(rangeBase + RangeActionPointerOffset));

              context.Copy(inRegionLocal, Const(0));

              Operand skipActionLabel = Label();

              // Tracking action should be non-null to call it, otherwise assume false return.
              context.BranchIfFalse(skipActionLabel, trackingActionPtr);
              Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(_pageSize), isWrite);
              context.Copy(inRegionLocal, result);

              context.MarkLabel(skipActionLabel);

              // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows.
              if (OperatingSystem.IsWindows())
              {
                  context.BranchIfTrue(endLabel, inRegionLocal);

                  context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context));
              }

              context.Branch(endLabel);

              context.MarkLabel(nextLabel);
          }

          context.MarkLabel(endLabel);

          return context.Copy(inRegionLocal);
      }

      /// <summary>
      /// Emits a load of <c>si_addr</c> from the siginfo structure (offset 24 on macOS, 16 otherwise).
      /// </summary>
      private static Operand GenerateUnixFaultAddress(EmitterContext context, Operand sigInfoPtr)
      {
          ulong structAddressOffset = OperatingSystem.IsMacOS() ? 24ul : 16ul; // si_addr
          return context.Load(OperandType.I64, context.Add(sigInfoPtr, Const(structAddressOffset)));
      }

      /// <summary>
      /// Emits code producing an I64 value that is non-zero when the fault was a write. On arm64 it tests
      /// bit 0x40 of the exception syndrome register (WnR for data aborts); on x64 it tests bit 0x2 of the
      /// page-fault error code.
      /// </summary>
      /// <exception cref="PlatformNotSupportedException">The OS/architecture combination is not handled.</exception>
      private static Operand GenerateUnixWriteFlag(EmitterContext context, Operand ucontextPtr)
      {
          if (OperatingSystem.IsMacOS())
          {
              const ulong mcontextOffset = 48; // uc_mcontext
              Operand ctxPtr = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(mcontextOffset)));

              if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
              {
                  const ulong esrOffset = 8; // __es.__esr
                  Operand esr = context.Load(OperandType.I64, context.Add(ctxPtr, Const(esrOffset)));
                  return context.BitwiseAnd(esr, Const(0x40ul));
              }

              if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
              {
                  const ulong errOffset = 4; // __es.__err
                  Operand err = context.Load(OperandType.I64, context.Add(ctxPtr, Const(errOffset)));
                  return context.BitwiseAnd(err, Const(2ul));
              }
          }
          else if (OperatingSystem.IsLinux())
          {
              if (RuntimeInformation.ProcessArchitecture == Architecture.Arm64)
              {
                  // Walk the _aarch64_ctx records in uc_mcontext.__reserved looking for the ESR record.
                  // The list ends with a record whose magic and size are both zero, and the ESR record is
                  // not guaranteed to be present, so stop there instead of looping forever.
                  Operand auxPtr = context.AllocateLocal(OperandType.I64);
                  Operand writeFlag = context.AllocateLocal(OperandType.I64);

                  Operand loopLabel = Label();
                  Operand successLabel = Label();
                  Operand notFoundLabel = Label();
                  Operand doneLabel = Label();

                  const ulong auxOffset = 464; // uc_mcontext.__reserved
                  const uint esrMagic = 0x45535201;

                  context.Copy(auxPtr, context.Add(ucontextPtr, Const(auxOffset)));

                  context.MarkLabel(loopLabel);

                  // _aarch64_ctx::magic
                  Operand magic = context.Load(OperandType.I32, auxPtr);
                  // _aarch64_ctx::size
                  Operand size = context.Load(OperandType.I32, context.Add(auxPtr, Const(4ul)));

                  context.BranchIf(successLabel, magic, Const(esrMagic), Comparison.Equal);

                  // Terminator record (or a zero-sized record, which would never advance): no ESR available.
                  context.BranchIfFalse(notFoundLabel, magic);
                  context.BranchIfFalse(notFoundLabel, size);

                  context.Copy(auxPtr, context.Add(auxPtr, context.ZeroExtend32(OperandType.I64, size)));

                  context.Branch(loopLabel);

                  context.MarkLabel(successLabel);

                  // esr_context::esr
                  Operand esr = context.Load(OperandType.I64, context.Add(auxPtr, Const(8ul)));
                  context.Copy(writeFlag, context.BitwiseAnd(esr, Const(0x40ul)));
                  context.Branch(doneLabel);

                  context.MarkLabel(notFoundLabel);

                  // Without an ESR record, report the access as a read.
                  context.Copy(writeFlag, Const(0ul));

                  context.MarkLabel(doneLabel);

                  return context.Copy(writeFlag);
              }

              if (RuntimeInformation.ProcessArchitecture == Architecture.X64)
              {
                  const int errOffset = 192; // uc_mcontext.gregs[REG_ERR]
                  Operand err = context.Load(OperandType.I64, context.Add(ucontextPtr, Const(errOffset)));
                  return context.BitwiseAnd(err, Const(2ul));
              }
          }

          throw new PlatformNotSupportedException();
      }

      /// <summary>
      /// Generates the Unix signal handler. It runs the tracked-range check and, when that returns zero,
      /// chains to the previously installed handler, passing either (sig) or (sig, info, ucontext)
      /// depending on <c>UnixOldSigaction3Arg</c>.
      /// </summary>
      /// <param name="signalStructPtr">Address of the unmanaged <c>SignalHandlerConfig</c> block.</param>
      private static UnixExceptionHandler GenerateUnixSignalHandler(IntPtr signalStructPtr)
      {
          EmitterContext context = new EmitterContext();

          // (int sig, SigInfo* sigInfo, void* ucontext)
          Operand sigInfoPtr = context.LoadArgument(OperandType.I64, 1);
          Operand ucontextPtr = context.LoadArgument(OperandType.I64, 2);

          Operand faultAddress = GenerateUnixFaultAddress(context, sigInfoPtr);
          Operand writeFlag = GenerateUnixWriteFlag(context, ucontextPtr);

          Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.

          Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);

          Operand endLabel = Label();

          context.BranchIfTrue(endLabel, isInRegion);

          Operand unixOldSigaction = context.Load(OperandType.I64, Const((ulong)signalStructPtr + UnixOldSigaction));

          // UnixOldSigaction3Arg is a 4-byte int immediately followed by Range0.IsActive, so it must be read
          // as I32; an I64 read would also pick up Range0.IsActive.
          Operand unixOldSigaction3Arg = context.Load(OperandType.I32, Const((ulong)signalStructPtr + UnixOldSigaction3Arg));
          Operand threeArgLabel = Label();

          context.BranchIfTrue(threeArgLabel, unixOldSigaction3Arg);

          context.Call(unixOldSigaction, OperandType.None, context.LoadArgument(OperandType.I32, 0));
          context.Branch(endLabel);

          context.MarkLabel(threeArgLabel);

          context.Call(unixOldSigaction,
              OperandType.None,
              context.LoadArgument(OperandType.I32, 0),
              sigInfoPtr,
              context.LoadArgument(OperandType.I64, 2)
              );

          context.MarkLabel(endLabel);

          context.Return();

          ControlFlowGraph cfg = context.GetControlFlowGraph();

          OperandType[] argTypes = new OperandType[] { OperandType.I32, OperandType.I64, OperandType.I64 };

          return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<UnixExceptionHandler>();
      }

      /// <summary>
      /// Generates the Windows vectored exception handler. Non-access-violation exceptions, and access
      /// violations the tracked-range check does not claim, return EXCEPTION_CONTINUE_SEARCH; claimed
      /// faults return EXCEPTION_CONTINUE_EXECUTION.
      /// </summary>
      /// <param name="signalStructPtr">Address of the unmanaged <c>SignalHandlerConfig</c> block.</param>
      private static VectoredExceptionHandler GenerateWindowsSignalHandler(IntPtr signalStructPtr)
      {
          EmitterContext context = new EmitterContext();

          // (ExceptionPointers* exceptionInfo)
          Operand exceptionInfoPtr = context.LoadArgument(OperandType.I64, 0);
          Operand exceptionRecordPtr = context.Load(OperandType.I64, exceptionInfoPtr);

          // First thing's first - this catches a number of exceptions, but we only want access violations.
          Operand validExceptionLabel = Label();

          Operand exceptionCode = context.Load(OperandType.I32, exceptionRecordPtr);

          context.BranchIf(validExceptionLabel, exceptionCode, Const(EXCEPTION_ACCESS_VIOLATION), Comparison.Equal);

          context.Return(Const(EXCEPTION_CONTINUE_SEARCH)); // Don't handle this one.

          context.MarkLabel(validExceptionLabel);

          // Next, read the address of the invalid access, and whether it is a write or not.
          Operand structAddressOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructAddressOffset));
          Operand structWriteOffset = context.Load(OperandType.I32, Const((ulong)signalStructPtr + StructWriteOffset));

          Operand faultAddress = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structAddressOffset)));
          Operand writeFlag = context.Load(OperandType.I64, context.Add(exceptionRecordPtr, context.ZeroExtend32(OperandType.I64, structWriteOffset)));

          Operand isWrite = context.ICompareNotEqual(writeFlag, Const(0L)); // Normalize to 0/1.

          Operand isInRegion = EmitGenericRegionCheck(context, signalStructPtr, faultAddress, isWrite);

          Operand endLabel = Label();

          // If the region check result is false, then run the next vectored exception handler.
          context.BranchIfTrue(endLabel, isInRegion);

          context.Return(Const(EXCEPTION_CONTINUE_SEARCH));

          context.MarkLabel(endLabel);

          // Otherwise, return to execution.
          context.Return(Const(EXCEPTION_CONTINUE_EXECUTION));

          // Compile and return the function.
          ControlFlowGraph cfg = context.GetControlFlowGraph();

          OperandType[] argTypes = new OperandType[] { OperandType.I64 };

          return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq, RuntimeInformation.ProcessArchitecture).Map<VectoredExceptionHandler>();
      }
  }
  306. }