GlobalToStorage.cs 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175
  1. using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper;
  7. namespace Ryujinx.Graphics.Shader.Translation.Optimizations
  8. {
  9. static class GlobalToStorage
  10. {
  // NOTE(review): not referenced in the visible part of this file; presumably identifies the
  // constant buffer slot reserved for driver data — confirm against its users before changing.
  private const int DriverReservedCb = 0;

  /// <summary>
  /// Distinguishes local memory from shared memory when keying values that the shader
  /// stored to memory (see <see cref="GtsContext"/>).
  /// </summary>
  private enum LsMemoryType
  {
      Local = 0,
      Shared = 1,
  }
  /// <summary>
  /// State shared across one run of the pass. It caches generated helper functions so that
  /// accesses with the same shape reuse one function. It also remembers which constant buffer
  /// base addresses were written to local or shared memory.
  /// </summary>
  private class GtsContext
  {
      /// <summary>A generated helper function together with the access shape it was generated for.</summary>
      private readonly struct Entry
      {
          public readonly int FunctionId;
          public readonly Instruction Inst;
          public readonly StorageKind StorageKind;
          public readonly bool IsMultiTarget;
          public readonly IReadOnlyList<uint> TargetCbs;

          public Entry(
              int functionId,
              Instruction inst,
              StorageKind storageKind,
              bool isMultiTarget,
              IReadOnlyList<uint> targetCbs)
          {
              (FunctionId, Inst, StorageKind, IsMultiTarget, TargetCbs) =
                  (functionId, inst, storageKind, isMultiTarget, targetCbs);
          }

          /// <summary>
          /// Checks whether this entry was generated for the same instruction, storage kind,
          /// dispatch mode and ordered list of packed target constant buffers.
          /// </summary>
          public bool Matches(Operation baseOp, bool isMultiTarget, IReadOnlyList<uint> targetCbs)
          {
              return Inst == baseOp.Inst &&
                  StorageKind == baseOp.StorageKind &&
                  IsMultiTarget == isMultiTarget &&
                  TargetCbs.SequenceEqual(targetCbs);
          }
      }

      /// <summary>Identifies a stored-to memory location: base operand, constant offset and memory type.</summary>
      private readonly struct LsKey : IEquatable<LsKey>
      {
          public readonly Operand BaseOffset;
          public readonly int ConstOffset;
          public readonly LsMemoryType Type;

          public LsKey(Operand baseOffset, int constOffset, LsMemoryType type)
          {
              BaseOffset = baseOffset;
              ConstOffset = constOffset;
              Type = type;
          }

          public override int GetHashCode() => HashCode.Combine(BaseOffset, ConstOffset, Type);

          public override bool Equals(object obj) => obj is LsKey other && Equals(other);

          public bool Equals(LsKey other) =>
              Type == other.Type && ConstOffset == other.ConstOffset && other.BaseOffset == BaseOffset;
      }

      private readonly List<Entry> _entries = [];
      private readonly Dictionary<LsKey, Dictionary<uint, SearchResult>> _sharedEntries = new();
      private readonly HelperFunctionManager _hfm;

      public GtsContext(HelperFunctionManager hfm)
      {
          _hfm = hfm;
      }

      /// <summary>
      /// Registers <paramref name="function"/> with the helper function manager.
      /// It is also remembered under the given access shape for later reuse.
      /// </summary>
      /// <returns>The id assigned by the helper function manager.</returns>
      public int AddFunction(Operation baseOp, bool isMultiTarget, IReadOnlyList<uint> targetCbs, Function function)
      {
          int functionId = _hfm.AddFunction(function);

          _entries.Add(new Entry(functionId, baseOp.Inst, baseOp.StorageKind, isMultiTarget, targetCbs));

          return functionId;
      }

      /// <summary>
      /// Looks up a helper previously registered for the same access shape.
      /// </summary>
      /// <param name="functionId">The matching function id, or -1 when none was found.</param>
      public bool TryGetFunctionId(Operation baseOp, bool isMultiTarget, IReadOnlyList<uint> targetCbs, out int functionId)
      {
          int index = _entries.FindIndex(entry => entry.Matches(baseOp, isMultiTarget, targetCbs));

          functionId = index >= 0 ? _entries[index].FunctionId : -1;

          return index >= 0;
      }

      /// <summary>
      /// Records that a value traced to constant buffer location <paramref name="targetCb"/>
      /// was stored at the given local/shared memory location.
      /// </summary>
      public void AddMemoryTargetCb(LsMemoryType type, Operand baseOffset, int constOffset, uint targetCb, SearchResult result)
      {
          LsKey key = new(baseOffset, constOffset, type);

          if (!_sharedEntries.TryGetValue(key, out Dictionary<uint, SearchResult> resultsByCb))
          {
              // First store seen at this memory location.
              _sharedEntries.Add(key, new Dictionary<uint, SearchResult> { [targetCb] = result });
              return;
          }

          if (!resultsByCb.TryGetValue(targetCb, out SearchResult existing))
          {
              // This location is known, but not yet for the constant buffer region
              // that the storage buffer base address and size come from.
              resultsByCb.Add(targetCb, result);
              return;
          }

          // Same location and same constant buffer, but a different offset: drop the offset
          // (keep only slot/offset) to mark it as ambiguous. Multi-target accesses do not
          // need the offset, so they can still use this entry.
          bool offsetConflict = existing.Offset != null &&
              (existing.Offset != result.Offset || existing.ConstOffset != result.ConstOffset);

          if (offsetConflict)
          {
              resultsByCb[targetCb] = new SearchResult(result.SbCbSlot, result.SbCbOffset);
          }
      }

      /// <summary>
      /// Retrieves the search result recorded for a local/shared memory location.
      /// This succeeds only when exactly one constant buffer was recorded for it and that result is marked found.
      /// </summary>
      public bool TryGetMemoryTargetCb(LsMemoryType type, Operand baseOffset, int constOffset, out SearchResult result)
      {
          if (_sharedEntries.TryGetValue(new LsKey(baseOffset, constOffset, type), out Dictionary<uint, SearchResult> resultsByCb) &&
              resultsByCb.Count == 1 &&
              resultsByCb.Values.First() is { Found: true } onlyResult)
          {
              result = onlyResult;
              return true;
          }

          result = default;
          return false;
      }
  }
  /// <summary>
  /// Outcome of searching for the constant buffer location that holds a storage buffer's base address.
  /// </summary>
  private readonly struct SearchResult
  {
      /// <summary>Result used when no constant buffer location was found.</summary>
      public static SearchResult NotFound => new(-1, 0);

      /// <summary>Constant buffer slot of the base address, or -1 when not found.</summary>
      public int SbCbSlot { get; }

      /// <summary>Offset within the constant buffer slot where the base address is read.</summary>
      public int SbCbOffset { get; }

      /// <summary>Offset operand into the storage buffer, or null when unknown or not requested.</summary>
      public Operand Offset { get; }

      /// <summary>Extra constant added to <see cref="Offset"/>.</summary>
      public int ConstOffset { get; }

      /// <summary>True unless the slot is the -1 "not found" marker.</summary>
      public bool Found => SbCbSlot != -1;

      public SearchResult(int sbCbSlot, int sbCbOffset) : this(sbCbSlot, sbCbOffset, null)
      {
      }

      public SearchResult(int sbCbSlot, int sbCbOffset, Operand offset, int constOffset = 0)
      {
          SbCbSlot = sbCbSlot;
          SbCbOffset = sbCbOffset;
          Offset = offset;
          ConstOffset = constOffset;
      }
  }
  /// <summary>
  /// Rewrites every global memory access in <paramref name="blocks"/> into a storage buffer access.
  /// It also records base addresses that the shader stores to shared or local memory. Later
  /// accesses whose address comes back from that memory can then still be traced to a constant buffer.
  /// </summary>
  /// <remarks>
  /// When a global access cannot be replaced, the failure is logged. Loads then read 0 and stores are removed.
  /// </remarks>
  public static void RunPass(
      HelperFunctionManager hfm,
      BasicBlock[] blocks,
      ResourceManager resourceManager,
      IGpuAccessor gpuAccessor,
      TargetLanguage targetLanguage)
  {
      GtsContext gtsContext = new(hfm);

      foreach (BasicBlock block in blocks)
      {
          for (LinkedListNode<INode> node = block.Operations.First; node != null;)
          {
              // Capture the successor before any rewrite. If the node is unlinked below
              // (Utils.DeleteNode), a removed LinkedListNode reports Next == null. Advancing
              // through it would silently end the walk of this block and skip later accesses.
              LinkedListNode<INode> nextNode = node.Next;

              if (node.Value is Operation operation)
              {
                  if (IsGlobalMemory(operation.StorageKind))
                  {
                      LinkedListNode<INode> replacement = ReplaceGlobalMemoryWithStorage(
                          gtsContext,
                          resourceManager,
                          gpuAccessor,
                          targetLanguage,
                          block,
                          node);

                      if (replacement != null)
                      {
                          // Continue after the node that now stands in for the original operation.
                          nextNode = replacement.Next;
                      }
                      else
                      {
                          // The returned value being null means that the global memory replacement failed,
                          // so we just make loads read 0 and stores do nothing.
                          gpuAccessor.Log($"Failed to reserve storage buffer for global memory operation \"{operation.Inst}\".");

                          if (operation.Dest != null)
                          {
                              operation.TurnIntoCopy(Const(0));
                          }
                          else
                          {
                              Utils.DeleteNode(node, operation);
                          }
                      }
                  }
                  else if (operation.Inst == Instruction.Store &&
                      (operation.StorageKind == StorageKind.SharedMemory ||
                       operation.StorageKind == StorageKind.LocalMemory))
                  {
                      // The NVIDIA compiler can sometimes use shared or local memory as temporary
                      // storage to place the base address and size on, so we need
                      // to be able to find such information stored in memory too.
                      if (TryGetMemoryOffsets(operation, out LsMemoryType type, out Operand baseOffset, out int constOffset))
                      {
                          Operand value = operation.GetSource(operation.SourcesCount - 1);

                          SearchResult result = FindUniqueBaseAddressCb(gtsContext, block, value, needsOffset: false);

                          if (result.Found)
                          {
                              uint targetCb = PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset);
                              gtsContext.AddMemoryTargetCb(type, baseOffset, constOffset, targetCb, result);
                          }
                      }
                  }
              }

              node = nextNode;
          }
      }
  }
  /// <summary>
  /// Returns true for any of the global memory storage kinds: plain 32-bit, or the signed or
  /// unsigned 8/16-bit variants.
  /// </summary>
  private static bool IsGlobalMemory(StorageKind storageKind)
  {
      return storageKind is StorageKind.GlobalMemory
          or StorageKind.GlobalMemoryS8
          or StorageKind.GlobalMemoryS16
          or StorageKind.GlobalMemoryU8
          or StorageKind.GlobalMemoryU16;
  }
  /// <summary>
  /// Returns true for the sub-word (8- or 16-bit, signed or unsigned) global memory storage kinds.
  /// </summary>
  private static bool IsSmallInt(StorageKind storageKind) => storageKind switch
  {
      StorageKind.GlobalMemoryS8 or
      StorageKind.GlobalMemoryS16 or
      StorageKind.GlobalMemoryU8 or
      StorageKind.GlobalMemoryU16 => true,
      _ => false,
  };
  /// <summary>
  /// Rewrites the global memory operation held by <paramref name="node"/> into a storage buffer access.
  /// When the storage buffer is known, the access is emitted inline or as a call to a single-target
  /// helper. Otherwise a multi-target helper is generated that picks the buffer at runtime.
  /// </summary>
  /// <returns>The node iteration should continue from, or null if the replacement failed.</returns>
  private static LinkedListNode<INode> ReplaceGlobalMemoryWithStorage(
      GtsContext gtsContext,
      ResourceManager resourceManager,
      IGpuAccessor gpuAccessor,
      TargetLanguage targetLanguage,
      BasicBlock block,
      LinkedListNode<INode> node)
  {
      Operation operation = node.Value as Operation;
      Operand globalAddress = operation.GetSource(0);

      SearchResult result = FindUniqueBaseAddressCb(gtsContext, block, globalAddress, needsOffset: true);

      if (!result.Found)
      {
          // No single storage buffer could be identified. Walk the address through phi chains
          // to collect every constant buffer that may hold the base address. Then call a helper
          // that checks each candidate and uses the one that contains the address.
          return TryGenerateMultiTargetStorageOp(
              gtsContext,
              resourceManager,
              gpuAccessor,
              targetLanguage,
              block,
              operation,
              out int multiTargetFunctionId)
              ? GenerateCallStorageOp(node, operation, null, multiTargetFunctionId)
              : null;
      }

      Operand offset = result.Offset;

      if (gpuAccessor.QueryHasUnalignedStorageBuffer())
      {
          // Compute the offset from the host-aligned base address instead of the found offset.
          Operand baseAddressMasked = Local();
          Operand hostOffset = Local();
          int alignment = gpuAccessor.QueryHostStorageBufferOffsetAlignment();

          Operation maskOp = new(Instruction.BitwiseAnd, baseAddressMasked, Cbuf(result.SbCbSlot, result.SbCbOffset), Const(-alignment));
          Operation subOp = new(Instruction.Subtract, hostOffset, globalAddress, baseAddressMasked);

          node.List.AddBefore(node, maskOp);
          node.List.AddBefore(node, subOp);

          offset = hostOffset;
      }
      else if (result.ConstOffset != 0)
      {
          // Fold the constant part of the address into the offset.
          Operand adjustedOffset = Local();

          node.List.AddBefore(node, new Operation(Instruction.Add, adjustedOffset, offset, Const(result.ConstOffset)));

          offset = adjustedOffset;
      }

      if (CanUseInlineStorageOp(operation, targetLanguage))
      {
          return GenerateInlineStorageOp(resourceManager, node, operation, offset, result);
      }

      return TryGenerateSingleTargetStorageOp(
          gtsContext,
          resourceManager,
          targetLanguage,
          operation,
          result,
          out int functionId)
          ? GenerateCallStorageOp(node, operation, offset, functionId)
          : null;
  }
  /// <summary>
  /// Decides whether the access can be emitted inline rather than through a helper function.
  /// Only plain 32-bit global accesses qualify. Signed atomic min/max additionally require a
  /// SPIR-V target.
  /// </summary>
  private static bool CanUseInlineStorageOp(Operation operation, TargetLanguage targetLanguage)
  {
      if (operation.StorageKind != StorageKind.GlobalMemory)
      {
          return false;
      }

      bool isSignedMinMax = operation.Inst is Instruction.AtomicMaxS32 or Instruction.AtomicMinS32;

      return !isSignedMinMax || targetLanguage == TargetLanguage.Spirv;
  }
  /// <summary>
  /// Replaces <paramref name="operation"/> in place with the equivalent storage buffer operation.
  /// </summary>
  /// <param name="offset">Byte offset into the storage buffer.</param>
  /// <returns>The new storage operation's node, or null if no storage buffer binding was available.</returns>
  private static LinkedListNode<INode> GenerateInlineStorageOp(
      ResourceManager resourceManager,
      LinkedListNode<INode> node,
      Operation operation,
      Operand offset,
      SearchResult result)
  {
      bool writesMemory = operation.Inst == Instruction.Store || operation.Inst.IsAtomic();

      if (!resourceManager.TryGetStorageBufferBinding(result.SbCbSlot, result.SbCbOffset, writesMemory, out int binding))
      {
          return null;
      }

      Operand wordOffset = Local();
      int lastSource = operation.SourcesCount - 1;

      // Storage sources: binding, field index 0, word offset, then any data operands
      // taken from the end of the original operation.
      Operand[] sources;

      if (operation.Inst == Instruction.AtomicCompareAndSwap)
      {
          Operand compare = operation.GetSource(lastSource - 1);
          Operand value = operation.GetSource(lastSource);
          sources = [Const(binding), Const(0), wordOffset, compare, value];
      }
      else if (writesMemory)
      {
          sources = [Const(binding), Const(0), wordOffset, operation.GetSource(lastSource)];
      }
      else
      {
          sources = [Const(binding), Const(0), wordOffset];
      }

      // Convert the byte offset into a 32-bit word index.
      Operation shiftOp = new(Instruction.ShiftRightU32, wordOffset, offset, Const(2));
      Operation storageOp = new(operation.Inst, StorageKind.StorageBuffer, operation.Dest, sources);

      node.List.AddBefore(node, shiftOp);
      LinkedListNode<INode> storageNode = node.List.AddBefore(node, storageOp);

      Utils.DeleteNode(node, operation);

      return storageNode;
  }
  /// <summary>
  /// Replaces <paramref name="operation"/> with a call to the helper function <paramref name="functionId"/>.
  /// </summary>
  /// <param name="offset">
  /// When supplied, this storage buffer offset replaces the two global address sources. When null,
  /// the 64-bit global address is passed as its two 32-bit halves.
  /// </param>
  /// <returns>The node iteration should continue from.</returns>
  private static LinkedListNode<INode> GenerateCallStorageOp(LinkedListNode<INode> node, Operation operation, Operand offset, int functionId)
  {
      bool hasOffset = offset != null;

      // With an offset, the original sources are copied starting after the global address (low, high).
      int firstCopiedSource = hasOffset ? 2 : 0;
      int copiedCount = operation.SourcesCount - firstCopiedSource;

      Operand[] sources = new Operand[1 + (hasOffset ? 1 : 0) + copiedCount];
      int write = 0;

      sources[write++] = Const(functionId);

      if (hasOffset)
      {
          sources[write++] = offset;
      }

      for (int srcIndex = firstCopiedSource; srcIndex < operation.SourcesCount; srcIndex++)
      {
          sources[write++] = operation.GetSource(srcIndex);
      }

      Operand returnValue = operation.Dest != null ? Local() : null;

      LinkedListNode<INode> callNode = node.List.AddBefore(node, new Operation(Instruction.Call, returnValue, sources));

      if (returnValue == null)
      {
          Utils.DeleteNode(node, operation);
          return callNode;
      }

      // The original operation becomes a copy of the call result into its destination.
      operation.TurnIntoCopy(returnValue);
      return node;
  }
  /// <summary>
  /// Gets, or generates and registers, a helper function for <paramref name="operation"/>.
  /// The helper accesses the single storage buffer identified by <paramref name="result"/>.
  /// Its arguments are the byte offset followed by the compare and/or value operands.
  /// </summary>
  /// <returns>False (with <paramref name="functionId"/> = 0) if the storage access could not be generated.</returns>
  private static bool TryGenerateSingleTargetStorageOp(
      GtsContext gtsContext,
      ResourceManager resourceManager,
      TargetLanguage targetLanguage,
      Operation operation,
      SearchResult result,
      out int functionId)
  {
      List<uint> targetCbs = [PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset)];

      // Reuse a helper already generated for the same access shape and buffer.
      if (gtsContext.TryGetFunctionId(operation, isMultiTarget: false, targetCbs, out functionId))
      {
          return true;
      }

      Instruction inst = operation.Inst;

      int inArgumentsCount = inst switch
      {
          Instruction.AtomicCompareAndSwap => 3,
          Instruction.Store => 2,
          _ => inst.IsAtomic() ? 2 : 1,
      };

      EmitterContext context = new();

      Operand offset = Argument(0);
      Operand compare = inArgumentsCount == 3 ? Argument(1) : null;
      Operand value = inArgumentsCount switch
      {
          3 => Argument(2),
          2 => Argument(1),
          _ => null,
      };

      bool generated = TryGenerateStorageOp(
          resourceManager,
          targetLanguage,
          context,
          inst,
          operation.StorageKind,
          offset,
          compare,
          value,
          result,
          out Operand resultValue);

      if (!generated)
      {
          functionId = 0;
          return false;
      }

      bool returnsValue = resultValue != null;

      if (!returnsValue)
      {
          context.Return();
      }
      else
      {
          context.Return(resultValue);
      }

      string functionName = GetFunctionName(operation, isMultiTarget: false, targetCbs);

      Function function = new(
          ControlFlowGraph.Create(context.GetOperations()).Blocks,
          functionName,
          returnsValue,
          inArgumentsCount,
          0);

      functionId = gtsContext.AddFunction(operation, isMultiTarget: false, targetCbs, function);
      return true;
  }
  /// <summary>
  /// Gets, or generates and registers, a helper function for <paramref name="operation"/>
  /// whose storage buffer is not known statically.
  /// For each candidate constant buffer, the helper checks whether the 64-bit global address lies
  /// in [base, base + size) and, if so, performs the access on that buffer and returns. If no
  /// candidate matches, loads return 0 and stores do nothing.
  /// Arguments are the global address as (low, high), followed by the compare and/or value operands.
  /// </summary>
  /// <returns>False (with <paramref name="functionId"/> = 0) if a storage access could not be generated.</returns>
  private static bool TryGenerateMultiTargetStorageOp(
      GtsContext gtsContext,
      ResourceManager resourceManager,
      IGpuAccessor gpuAccessor,
      TargetLanguage targetLanguage,
      BasicBlock block,
      Operation operation,
      out int functionId)
  {
      List<uint> targetCbs = CollectMultiTargetCbs(gtsContext, block, operation);

      if (targetCbs.Count == 0)
      {
          // The helper is still generated below. With no candidates it only takes the fallback path.
          gpuAccessor.Log($"Failed to find storage buffer for global memory operation \"{operation.Inst}\".");
      }

      if (gtsContext.TryGetFunctionId(operation, isMultiTarget: true, targetCbs, out functionId))
      {
          return true;
      }

      int inArgumentsCount = 2;

      if (operation.Inst == Instruction.AtomicCompareAndSwap)
      {
          inArgumentsCount = 4;
      }
      else if (operation.Inst == Instruction.Store || operation.Inst.IsAtomic())
      {
          inArgumentsCount = 3;
      }

      EmitterContext context = new();

      Operand globalAddressLow = Argument(0);
      Operand globalAddressHigh = Argument(1);

      // These do not depend on the candidate buffer, so select them once.
      Operand compare = inArgumentsCount == 4 ? Argument(2) : null;
      Operand value = inArgumentsCount switch
      {
          4 => Argument(3),
          3 => Argument(2),
          _ => null,
      };

      int alignment = gpuAccessor.QueryHostStorageBufferOffsetAlignment();

      foreach (uint targetCb in targetCbs)
      {
          (int sbCbSlot, int sbCbOffset) = UnpackCbSlotAndOffset(targetCb);

          // Consecutive constant buffer words: base address low, base address high, size.
          Operand baseAddrLow = Cbuf(sbCbSlot, sbCbOffset);
          Operand baseAddrHigh = Cbuf(sbCbSlot, sbCbOffset + 1);
          Operand size = Cbuf(sbCbSlot, sbCbOffset + 2);

          // Range check on the 64-bit address. The low word is subtracted and compared against
          // the size. The high word, adjusted by the borrow, must equal the base high word.
          // NOTE(review): adding the boolean borrow relies on the IR's boolean-to-int conversion — confirm.
          Operand offset = context.ISubtract(globalAddressLow, baseAddrLow);
          Operand borrow = context.ICompareLessUnsigned(globalAddressLow, baseAddrLow);
          Operand inRangeLow = context.ICompareLessUnsigned(offset, size);
          Operand addrHighBorrowed = context.IAdd(globalAddressHigh, borrow);
          Operand inRangeHigh = context.ICompareEqual(addrHighBorrowed, baseAddrHigh);
          Operand inRange = context.BitwiseAnd(inRangeLow, inRangeHigh);

          Operand lblSkip = Label();
          context.BranchIfFalse(lblSkip, inRange);

          // Offset relative to the base address rounded down to the host alignment.
          Operand baseAddressMasked = context.BitwiseAnd(baseAddrLow, Const(-alignment));
          Operand hostOffset = context.ISubtract(globalAddressLow, baseAddressMasked);

          if (!TryGenerateStorageOp(
              resourceManager,
              targetLanguage,
              context,
              operation.Inst,
              operation.StorageKind,
              hostOffset,
              compare,
              value,
              new SearchResult(sbCbSlot, sbCbOffset),
              out Operand resultValue))
          {
              functionId = 0;
              return false;
          }

          if (resultValue != null)
          {
              context.Return(resultValue);
          }
          else
          {
              context.Return();
          }

          context.MarkLabel(lblSkip);
      }

      // Fallback when no candidate contains the address.
      bool returnsValue = operation.Dest != null;

      if (returnsValue)
      {
          context.Return(Const(0));
      }
      else
      {
          context.Return();
      }

      string functionName = GetFunctionName(operation, isMultiTarget: true, targetCbs);

      Function function = new(
          ControlFlowGraph.Create(context.GetOperations()).Blocks,
          functionName,
          returnsValue,
          inArgumentsCount,
          0);

      functionId = gtsContext.AddFunction(operation, isMultiTarget: true, targetCbs, function);
      return true;
  }

  /// <summary>
  /// Collects the candidate constant buffer locations (packed slot/offset) that may hold the base
  /// address used by <paramref name="operation"/>. If the address (optionally plus a constant)
  /// comes from a phi, the phi sources are searched, following nested phis. Otherwise a single
  /// direct search is done. The returned list is sorted and has no duplicates.
  /// </summary>
  private static List<uint> CollectMultiTargetCbs(GtsContext gtsContext, BasicBlock block, Operation operation)
  {
      Queue<PhiNode> pending = new();
      HashSet<PhiNode> visited = [];
      List<uint> targetCbs = [];

      Operand globalAddress = operation.GetSource(0);

      // Look through "local + constant" so that a phi feeding the base is still found.
      if (globalAddress.AsgOp is Operation addOp && addOp.Inst == Instruction.Add)
      {
          Operand src1 = addOp.GetSource(0);
          Operand src2 = addOp.GetSource(1);

          if (src1.Type == OperandType.Constant && src2.Type == OperandType.LocalVariable)
          {
              globalAddress = src2;
          }
          else if (src1.Type == OperandType.LocalVariable && src2.Type == OperandType.Constant)
          {
              globalAddress = src1;
          }
      }

      if (globalAddress.AsgOp is PhiNode rootPhi && visited.Add(rootPhi))
      {
          pending.Enqueue(rootPhi);
      }
      else
      {
          SearchResult result = FindUniqueBaseAddressCb(gtsContext, block, operation.GetSource(0), needsOffset: false);

          if (result.Found)
          {
              targetCbs.Add(PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset));
          }
      }

      while (pending.TryDequeue(out PhiNode phi))
      {
          for (int srcIndex = 0; srcIndex < phi.SourcesCount; srcIndex++)
          {
              BasicBlock phiBlock = phi.GetBlock(srcIndex);
              Operand phiSource = phi.GetSource(srcIndex);

              SearchResult result = FindUniqueBaseAddressCb(gtsContext, phiBlock, phiSource, needsOffset: false);

              if (result.Found)
              {
                  uint targetCb = PackCbSlotAndOffset(result.SbCbSlot, result.SbCbOffset);

                  if (!targetCbs.Contains(targetCb))
                  {
                      targetCbs.Add(targetCb);
                  }
              }
              else if (phiSource.AsgOp is PhiNode nestedPhi && visited.Add(nestedPhi))
              {
                  pending.Enqueue(nestedPhi);
              }
          }
      }

      targetCbs.Sort();
      return targetCbs;
  }
  /// <summary>
  /// Packs a constant buffer slot (low 16 bits) and offset (high 16 bits) into one key.
  /// Each value is truncated to 16 bits.
  /// </summary>
  private static uint PackCbSlotAndOffset(int cbSlot, int cbOffset)
  {
      uint slotBits = (ushort)cbSlot;
      uint offsetBits = (ushort)cbOffset;

      return slotBits | (offsetBits << 16);
  }
  /// <summary>
  /// Inverse of the slot/offset packing: returns (slot from the low 16 bits, offset from the high 16 bits).
  /// </summary>
  private static (int, int) UnpackCbSlotAndOffset(uint packed)
  {
      int slot = (int)(packed & 0xFFFFu);
      int offset = (int)(packed >> 16);

      return (slot, offset);
  }
  /// <summary>
  /// Builds a descriptive helper function name. It is the instruction name, an optional
  /// S8/S16/U8/U16 size suffix and "Multi" for multi-target helpers, followed by "_c{slot}o{offset}"
  /// for each target (e.g. "StoreU8_c0o16").
  /// </summary>
  private static string GetFunctionName(Operation baseOp, bool isMultiTarget, IReadOnlyList<uint> targetCbs)
  {
      string sizeSuffix = baseOp.StorageKind switch
      {
          StorageKind.GlobalMemoryS8 => "S8",
          StorageKind.GlobalMemoryS16 => "S16",
          StorageKind.GlobalMemoryU8 => "U8",
          StorageKind.GlobalMemoryU16 => "U16",
          _ => string.Empty,
      };

      StringBuilder name = new StringBuilder()
          .Append(baseOp.Inst.ToString())
          .Append(sizeSuffix)
          .Append(isMultiTarget ? "Multi" : string.Empty);

      foreach (uint targetCb in targetCbs)
      {
          (int slot, int cbOffset) = UnpackCbSlotAndOffset(targetCb);
          name.Append($"_c{slot}o{cbOffset}");
      }

      return name.ToString();
  }
  /// <summary>
  /// Emits into <paramref name="context"/> the storage buffer access equivalent to a global
  /// memory <paramref name="inst"/> of kind <paramref name="storageKind"/>.
  /// </summary>
  /// <param name="offset">Byte offset into the storage buffer.</param>
  /// <param name="compare">Comparand for compare-and-swap; unused otherwise.</param>
  /// <param name="value">Value to store or atomic operand; unused for loads.</param>
  /// <param name="result">Constant buffer location identifying the storage buffer.</param>
  /// <param name="resultValue">Loaded or atomic result; null for stores.</param>
  /// <returns>False if no storage buffer binding was available; true otherwise.</returns>
  /// <exception cref="NotImplementedException">Atomic instruction on an 8- or 16-bit storage kind.</exception>
  private static bool TryGenerateStorageOp(
      ResourceManager resourceManager,
      TargetLanguage targetLanguage,
      EmitterContext context,
      Instruction inst,
      StorageKind storageKind,
      Operand offset,
      Operand compare,
      Operand value,
      SearchResult result,
      out Operand resultValue)
  {
      resultValue = null;

      bool isAtomic = inst.IsAtomic();
      bool writesMemory = isAtomic || inst == Instruction.Store;

      if (!resourceManager.TryGetStorageBufferBinding(result.SbCbSlot, result.SbCbOffset, writesMemory, out int binding))
      {
          return false;
      }

      // Storage buffers are indexed in 32-bit words.
      Operand wordOffset = context.ShiftRightU32(offset, Const(2));

      if (isAtomic)
      {
          if (IsSmallInt(storageKind))
          {
              throw new NotImplementedException();
          }

          // Signed min/max are emitted directly for SPIR-V; other targets use a CAS loop.
          bool nativeSignedMinMax = targetLanguage == TargetLanguage.Spirv;

          resultValue = inst switch
          {
              Instruction.AtomicAdd => context.AtomicAdd(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicAnd => context.AtomicAnd(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicCompareAndSwap => context.AtomicCompareAndSwap(StorageKind.StorageBuffer, binding, Const(0), wordOffset, compare, value),
              Instruction.AtomicMaxS32 when nativeSignedMinMax => context.AtomicMaxS32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicMaxS32 => GenerateAtomicCasLoop(context, wordOffset, binding, memValue => context.IMaximumS32(memValue, value)),
              Instruction.AtomicMaxU32 => context.AtomicMaxU32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicMinS32 when nativeSignedMinMax => context.AtomicMinS32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicMinS32 => GenerateAtomicCasLoop(context, wordOffset, binding, memValue => context.IMinimumS32(memValue, value)),
              Instruction.AtomicMinU32 => context.AtomicMinU32(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicOr => context.AtomicOr(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicSwap => context.AtomicSwap(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              Instruction.AtomicXor => context.AtomicXor(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value),
              _ => null,
          };
      }
      else if (inst == Instruction.Store)
      {
          int bitSize = storageKind switch
          {
              StorageKind.GlobalMemoryS8 or
              StorageKind.GlobalMemoryU8 => 8,
              StorageKind.GlobalMemoryS16 or
              StorageKind.GlobalMemoryU16 => 16,
              _ => 32,
          };

          if (bitSize >= 32)
          {
              context.Store(StorageKind.StorageBuffer, binding, Const(0), wordOffset, value);
          }
          else
          {
              // Sub-word stores merge into the containing word atomically, so neighbouring bytes are preserved.
              Operand bitOffset = HelperFunctionManager.GetBitOffset(context, offset);

              GenerateAtomicCasLoop(context, wordOffset, binding, memValue =>
                  context.BitfieldInsert(memValue, value, bitOffset, Const(bitSize)));
          }
      }
      else
      {
          Operand loaded = context.Load(StorageKind.StorageBuffer, binding, Const(0), wordOffset);

          if (IsSmallInt(storageKind))
          {
              // Shift the addressed byte/halfword down, then sign- or zero-extend it.
              Operand bitOffset = HelperFunctionManager.GetBitOffset(context, offset);

              loaded = storageKind switch
              {
                  StorageKind.GlobalMemoryS8 => context.BitfieldExtractS32(context.ShiftRightS32(loaded, bitOffset), Const(0), Const(8)),
                  StorageKind.GlobalMemoryS16 => context.BitfieldExtractS32(context.ShiftRightS32(loaded, bitOffset), Const(0), Const(16)),
                  StorageKind.GlobalMemoryU8 => context.BitwiseAnd(context.ShiftRightU32(loaded, bitOffset), Const(byte.MaxValue)),
                  StorageKind.GlobalMemoryU16 => context.BitwiseAnd(context.ShiftRightU32(loaded, bitOffset), Const(ushort.MaxValue)),
                  _ => loaded,
              };
          }

          resultValue = loaded;
      }

      return true;
  }
  /// <summary>
  /// Emits a read-modify-write retry loop on a storage buffer word. The loop loads the word,
  /// computes the new value with <paramref name="opCallback"/> and attempts a compare-and-swap.
  /// It branches back to the load while the swap observed a different value.
  /// </summary>
  /// <returns>The operand holding the value loaded at the top of the loop.</returns>
  private static Operand GenerateAtomicCasLoop(EmitterContext context, Operand wordOffset, int binding, Func<Operand, Operand> opCallback)
  {
      Operand retryLabel = Label();
      context.MarkLabel(retryLabel);

      Operand observed = context.Load(StorageKind.StorageBuffer, binding, Const(0), wordOffset);
      Operand desired = opCallback(observed);
      Operand previous = context.AtomicCompareAndSwap(StorageKind.StorageBuffer, binding, Const(0), wordOffset, observed, desired);

      context.BranchIfTrue(retryLabel, context.ICompareNotEqual(previous, observed));

      return observed;
  }
  /// <summary>
  /// Tries to locate the constant buffer slot/offset holding the base address used by a global memory access.
  /// </summary>
  /// <param name="gtsContext">Context used to resolve addresses that were loaded back from shared or local memory</param>
  /// <param name="block">Block where the address is used; operand definitions are looked up from here</param>
  /// <param name="globalAddress">Address operand of the global memory access</param>
  /// <param name="needsOffset">When resolving through memory, only accept results that carry an offset operand</param>
  /// <returns>The search result, or SearchResult.NotFound when no suitable constant buffer location was found</returns>
  private static SearchResult FindUniqueBaseAddressCb(GtsContext gtsContext, BasicBlock block, Operand globalAddress, bool needsOffset)
  {
      globalAddress = Utils.FindLastOperation(globalAddress, block);

      // The address is read directly from a constant buffer, with no offset applied.
      if (globalAddress.Type == OperandType.ConstantBuffer)
      {
          return GetBaseAddressCbWithOffset(globalAddress, Const(0), 0);
      }

      Operation operation = globalAddress.AsgOp as Operation;

      // Not produced by an addition: the address may have been loaded from shared/local memory.
      if (operation == null || operation.Inst != Instruction.Add)
      {
          return FindBaseAddressCbFromMemory(gtsContext, operation, 0, needsOffset);
      }

      Operand src1 = operation.GetSource(0);
      Operand src2 = operation.GetSource(1);

      int constOffset = 0;

      // Pattern "local + constant" (either operand order).
      if ((src1.Type == OperandType.LocalVariable && src2.Type == OperandType.Constant) ||
          (src2.Type == OperandType.LocalVariable && src1.Type == OperandType.Constant))
      {
          Operand baseAddr;
          Operand offset;

          if (src1.Type == OperandType.LocalVariable)
          {
              baseAddr = Utils.FindLastOperation(src1, block);
              offset = src2;
          }
          else
          {
              baseAddr = Utils.FindLastOperation(src2, block);
              offset = src1;
          }

          // The local may itself be a constant buffer read, with the constant as its offset.
          SearchResult result = GetBaseAddressCbWithOffset(baseAddr, offset, 0);
          if (result.Found)
          {
              return result;
          }

          // Otherwise keep the constant and look one level deeper: either another
          // addition (examined below) or a value loaded from memory.
          constOffset = offset.Value;
          operation = baseAddr.AsgOp as Operation;

          if (operation == null || operation.Inst != Instruction.Add)
          {
              return FindBaseAddressCbFromMemory(gtsContext, operation, constOffset, needsOffset);
          }
      }

      src1 = operation.GetSource(0);
      src2 = operation.GetSource(1);

      // Decide which addend is the constant buffer base address and which is the offset:
      // - src1 is not a constant buffer: src2 is the candidate base.
      // - both are constant buffers: src2 is preferred, because constant buffers are always
      //   encoded as the second operand, so src2 comes from the last instruction.
      // - only src1 is a constant buffer: src1 is the base.
      // NOTE(review): this condition previously also tested "src2 is a constant buffer in
      // DriverReservedCb", but that clause could never change the outcome (it is subsumed by
      // the cases above). As a result a driver reserved buffer in src1 is never preferred over
      // a constant buffer in src2. Confirm this priority is the intended one.
      if (src1.Type != OperandType.ConstantBuffer || src2.Type == OperandType.ConstantBuffer)
      {
          return GetBaseAddressCbWithOffset(src2, src1, constOffset);
      }

      return GetBaseAddressCbWithOffset(src1, src2, constOffset);
  }
  /// <summary>
  /// Resolves an address that was loaded from shared or local memory to the constant buffer
  /// location recorded for that memory slot in the context.
  /// </summary>
  /// <param name="gtsContext">Context holding the recorded memory targets</param>
  /// <param name="operation">Operation that produced the address; may be null</param>
  /// <param name="constOffset">Extra constant offset to add to the resolved result</param>
  /// <param name="needsOffset">When true, results without an offset operand are rejected</param>
  /// <returns>The resolved search result, or SearchResult.NotFound</returns>
  private static SearchResult FindBaseAddressCbFromMemory(GtsContext gtsContext, Operation operation, int constOffset, bool needsOffset)
  {
      if (operation == null ||
          !TryGetMemoryOffsets(operation, out LsMemoryType memoryType, out Operand memBaseOffset, out int memConstOffset) ||
          !gtsContext.TryGetMemoryTargetCb(memoryType, memBaseOffset, memConstOffset, out SearchResult target))
      {
          return SearchResult.NotFound;
      }

      // The caller requires an offset operand, but the recorded target has none.
      if (needsOffset && target.Offset == null)
      {
          return SearchResult.NotFound;
      }

      if (constOffset == 0)
      {
          return target;
      }

      return new SearchResult(target.SbCbSlot, target.SbCbOffset, target.Offset, target.ConstOffset + constOffset);
  }
  /// <summary>
  /// Builds a search result when the base address operand is read from a suitably aligned constant buffer location.
  /// </summary>
  /// <param name="baseAddress">Candidate base address operand</param>
  /// <param name="offset">Operand added to the base address</param>
  /// <param name="constOffset">Constant offset added to the base address</param>
  /// <returns>The search result, or SearchResult.NotFound if the base is not an aligned constant buffer read</returns>
  private static SearchResult GetBaseAddressCbWithOffset(Operand baseAddress, Operand offset, int constOffset)
  {
      if (baseAddress.Type != OperandType.ConstantBuffer)
      {
          return SearchResult.NotFound;
      }

      int slot = baseAddress.GetCbufSlot();
      int cbOffset = baseAddress.GetCbufOffset();

      // The base address is 64-bit and the GPU only supports aligned accesses, so the
      // constant buffer offset must be even. NOTE(review): this assumes the offset counts
      // 32-bit elements, making "even" equivalent to 64-bit alignment; confirm the unit.
      bool isAligned = (cbOffset & 1) == 0;

      return isAligned
          ? new SearchResult(slot, cbOffset, offset, constOffset)
          : SearchResult.NotFound;
  }
  /// <summary>
  /// Extracts the memory type and offset parts addressed by a shared or local memory load/store.
  /// </summary>
  /// <param name="operation">Operation to inspect</param>
  /// <param name="type">Memory type accessed; default when the operation is not a supported load/store</param>
  /// <param name="baseOffset">Variable part of the offset (shared memory only); null for local memory and unsupported operations</param>
  /// <param name="constOffset">Constant part of the offset; 0 for unsupported operations</param>
  /// <returns>True if the operation is a supported load/store whose offsets could be extracted</returns>
  private static bool TryGetMemoryOffsets(Operation operation, out LsMemoryType type, out Operand baseOffset, out int constOffset)
  {
      baseOffset = null;

      bool isLoadOrStore = operation.Inst == Instruction.Load || operation.Inst == Instruction.Store;

      if (isLoadOrStore)
      {
          switch (operation.StorageKind)
          {
              case StorageKind.SharedMemory:
                  type = LsMemoryType.Shared;
                  return TryGetSharedMemoryOffsets(operation, out baseOffset, out constOffset);

              case StorageKind.LocalMemory:
                  // Local memory offsets are only accepted as constants, so there is no variable part.
                  type = LsMemoryType.Local;
                  return TryGetLocalMemoryOffset(operation, out constOffset);
          }
      }

      type = default;
      constOffset = 0;
      return false;
  }
  /// <summary>
  /// Recovers the byte offset of a shared memory access, split into a local variable part and a constant part.
  /// </summary>
  /// <param name="operation">Shared memory load/store; source 1 is its word offset</param>
  /// <param name="baseOffset">Variable part of the byte offset, or null if the offset pattern is not recognized</param>
  /// <param name="constOffset">Constant added to the variable part of the byte offset; 0 if none</param>
  /// <returns>True if the variable part is a local variable</returns>
  private static bool TryGetSharedMemoryOffsets(Operation operation, out Operand baseOffset, out int constOffset)
  {
      baseOffset = null;
      constOffset = 0;

      // The word offset is computed as "byteOffset >> 2". Every such shift yields a fresh
      // local variable that would never match, so walk back to the byte offset itself.
      if (!(operation.GetSource(1).AsgOp is Operation shiftOp &&
            shiftOp.Inst == Instruction.ShiftRightU32 &&
            shiftOp.GetSource(1).Type == OperandType.Constant &&
            shiftOp.GetSource(1).Value == 2))
      {
          return false;
      }

      baseOffset = shiftOp.GetSource(0);

      // Split "local + constant" (in either operand order) into its two parts.
      if (baseOffset?.AsgOp is Operation addOp && addOp.Inst == Instruction.Add)
      {
          Operand lhs = addOp.GetSource(0);
          Operand rhs = addOp.GetSource(1);

          if (lhs.Type == OperandType.Constant && rhs.Type == OperandType.LocalVariable)
          {
              baseOffset = rhs;
              constOffset = lhs.Value;
          }
          else if (lhs.Type == OperandType.LocalVariable && rhs.Type == OperandType.Constant)
          {
              baseOffset = lhs;
              constOffset = rhs.Value;
          }
      }

      return baseOffset?.Type == OperandType.LocalVariable;
  }
  /// <summary>
  /// Reads the offset of a local memory access when it is a constant.
  /// </summary>
  /// <param name="operation">Local memory load/store; source 1 is its offset</param>
  /// <param name="constOffset">The constant offset, or 0 when the offset is not a constant</param>
  /// <returns>True if the offset is a constant</returns>
  private static bool TryGetLocalMemoryOffset(Operation operation, out int constOffset)
  {
      Operand offset = operation.GetSource(1);
      bool isConstant = offset.Type == OperandType.Constant;

      constOffset = isConstant ? offset.Value : 0;

      return isConstant;
  }
  976. }
  977. }