CodeGenContext.cs 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. using Ryujinx.Graphics.Shader.StructuredIr;
  2. using Ryujinx.Graphics.Shader.Translation;
  3. using Spv.Generator;
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using static Spv.Specification;
  namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
  {
      using IrConsts = IntermediateRepresentation.IrConsts;
      using IrOperandType = IntermediateRepresentation.OperandType;

      /// <summary>
      /// SPIR-V module builder that also tracks the per-shader state used while emitting code:
      /// declared buffers, samplers, images, input/output variables, locals, function arguments,
      /// functions and the labels allocated for structured AST blocks.
      /// </summary>
      partial class CodeGenContext : Module
      {
          private const uint SpirvVersionMajor = 1;
          private const uint SpirvVersionMinor = 3;
          private const uint SpirvVersionRevision = 0;

          // Packed as (major << 16) | (minor << 8) | revision and handed to the Module base constructor.
          private const uint SpirvVersionPacked = (SpirvVersionMajor << 16) | (SpirvVersionMinor << 8) | SpirvVersionRevision;

          private readonly StructuredProgramInfo _info;

          /// <summary>Shader configuration (stage, used features, GPU accessor) this module is generated for.</summary>
          public ShaderConfig Config { get; }

          /// <summary>
          /// Number of vertices per input primitive. Only assigned for geometry shaders; 0 otherwise.
          /// </summary>
          public int InputVertices { get; }

          /// <summary>Uniform buffer variables keyed by constant buffer slot (see <see cref="GetConstantBuffer"/>).</summary>
          public Dictionary<int, Instruction> UniformBuffers { get; } = new Dictionary<int, Instruction>();

          /// <summary>
          /// Support buffer variable. Within this class, member 2 is read for the viewport-inverse values
          /// and member 4 (element 0) is used as the fragment position scale.
          /// </summary>
          public Instruction SupportBuffer { get; set; }

          /// <summary>
          /// When non-null, uniform buffers are accessed through this array variable (indexed by slot)
          /// instead of the per-slot entries in <see cref="UniformBuffers"/>.
          /// </summary>
          public Instruction UniformBuffersArray { get; set; }

          public Instruction StorageBuffersArray { get; set; }
          public Instruction LocalMemory { get; set; }
          public Instruction SharedMemory { get; set; }

          /// <summary>
          /// Input variable used when user attributes are accessed with dynamic indexing; it is
          /// indexed as [vertex]? [vector] [component].
          /// </summary>
          public Instruction InputsArray { get; set; }

          /// <summary>Output counterpart of <see cref="InputsArray"/>.</summary>
          public Instruction OutputsArray { get; set; }

          public Dictionary<TextureMeta, SamplerType> SamplersTypes { get; } = new Dictionary<TextureMeta, SamplerType>();
          public Dictionary<TextureMeta, (Instruction, Instruction, Instruction)> Samplers { get; } = new Dictionary<TextureMeta, (Instruction, Instruction, Instruction)>();
          public Dictionary<TextureMeta, (Instruction, Instruction)> Images { get; } = new Dictionary<TextureMeta, (Instruction, Instruction)>();

          /// <summary>Input variables keyed by attribute offset (the base value of the attribute).</summary>
          public Dictionary<int, Instruction> Inputs { get; } = new Dictionary<int, Instruction>();

          /// <summary>Output variables keyed by attribute offset (the base value of the attribute).</summary>
          public Dictionary<int, Instruction> Outputs { get; } = new Dictionary<int, Instruction>();

          /// <summary>Per-patch input variables keyed by attribute offset.</summary>
          public Dictionary<int, Instruction> InputsPerPatch { get; } = new Dictionary<int, Instruction>();

          /// <summary>Per-patch output variables keyed by attribute offset.</summary>
          public Dictionary<int, Instruction> OutputsPerPatch { get; } = new Dictionary<int, Instruction>();

          public Instruction CoordTemp { get; set; }

          // Per-function state, reset by StartFunction.
          private readonly Dictionary<AstOperand, Instruction> _locals = new Dictionary<AstOperand, Instruction>();
          private readonly Dictionary<int, Instruction[]> _localForArgs = new Dictionary<int, Instruction[]>();
          private readonly Dictionary<int, Instruction> _funcArgs = new Dictionary<int, Instruction>();

          // Module-wide state: functions stay registered across StartFunction calls.
          private readonly Dictionary<int, (StructuredFunction, Instruction)> _functions = new Dictionary<int, (StructuredFunction, Instruction)>();

          /// <summary>
          /// Lazily allocated sequence of labels for one AST block. Entering the block consumes
          /// the next label; later labels can be requested ahead of time (e.g. as branch targets).
          /// </summary>
          private class BlockState
          {
              private int _entryCount;
              private readonly List<Instruction> _labels = new List<Instruction>();

              /// <summary>Returns the label the next entry will use, without consuming it.</summary>
              public Instruction GetNextLabel(CodeGenContext context)
              {
                  return GetLabel(context, _entryCount);
              }

              /// <summary>Returns the label for the next entry and advances the entry counter.</summary>
              public Instruction GetNextLabelAutoIncrement(CodeGenContext context)
              {
                  return GetLabel(context, _entryCount++);
              }

              /// <summary>Returns the label at <paramref name="index"/>, creating labels up to it as needed.</summary>
              public Instruction GetLabel(CodeGenContext context, int index)
              {
                  while (index >= _labels.Count)
                  {
                      _labels.Add(context.Label());
                  }

                  return _labels[index];
              }
          }

          private readonly Dictionary<AstBlock, BlockState> _labels = new Dictionary<AstBlock, BlockState>();

          public Dictionary<AstBlock, (Instruction, Instruction)> LoopTargets { get; set; }

          /// <summary>The AST block most recently entered via <see cref="EnterBlock"/>.</summary>
          public AstBlock CurrentBlock { get; private set; }

          public SpirvDelegates Delegates { get; }

          /// <summary>
          /// Creates the module (SPIR-V 1.3, Logical addressing, GLSL450 memory model) with the
          /// Shader and Float64 capabilities.
          /// </summary>
          /// <exception cref="InvalidOperationException">
          /// Geometry stage with an input topology not handled below.
          /// </exception>
          public CodeGenContext(
              StructuredProgramInfo info,
              ShaderConfig config,
              GeneratorPool<Instruction> instPool,
              GeneratorPool<LiteralInteger> integerPool) : base(SpirvVersionPacked, instPool, integerPool)
          {
              _info = info;
              Config = config;

              if (config.Stage == ShaderStage.Geometry)
              {
                  InputTopology inPrimitive = config.GpuAccessor.QueryPrimitiveTopology();

                  // NOTE(review): adjacency topologies normally deliver 4 (lines) and 6 (triangles)
                  // vertices per primitive; here they map to the non-adjacency counts. Confirm this is
                  // intentional before relying on InputVertices to size per-vertex input arrays.
                  InputVertices = inPrimitive switch
                  {
                      InputTopology.Points => 1,
                      InputTopology.Lines => 2,
                      InputTopology.LinesAdjacency => 2,
                      InputTopology.Triangles => 3,
                      InputTopology.TrianglesAdjacency => 3,
                      _ => throw new InvalidOperationException($"Invalid input topology \"{inPrimitive}\".")
                  };
              }

              AddCapability(Capability.Shader);
              AddCapability(Capability.Float64);

              SetMemoryModel(AddressingModel.Logical, MemoryModel.GLSL450);

              Delegates = new SpirvDelegates(this);
          }

          /// <summary>Resets per-function state (locals, argument locals and arguments). Functions stay declared.</summary>
          public void StartFunction()
          {
              _locals.Clear();
              _localForArgs.Clear();
              _funcArgs.Clear();
          }

          /// <summary>Makes <paramref name="block"/> current and emits its next label, consuming it.</summary>
          public void EnterBlock(AstBlock block)
          {
              CurrentBlock = block;
              AddLabel(GetBlockStateLazy(block).GetNextLabelAutoIncrement(this));
          }

          /// <summary>Returns the first label of <paramref name="block"/>, allocating it if needed.</summary>
          public Instruction GetFirstLabel(AstBlock block)
          {
              return GetBlockStateLazy(block).GetLabel(this, 0);
          }

          /// <summary>Returns the label the next <see cref="EnterBlock"/> call for <paramref name="block"/> will emit.</summary>
          public Instruction GetNextLabel(AstBlock block)
          {
              return GetBlockStateLazy(block).GetNextLabel(this);
          }

          private BlockState GetBlockStateLazy(AstBlock block)
          {
              if (!_labels.TryGetValue(block, out var blockState))
              {
                  blockState = new BlockState();

                  _labels.Add(block, blockState);
              }

              return blockState;
          }

          /// <summary>
          /// Ends the current SPIR-V block with a branch to a fresh label, then starts that label.
          /// </summary>
          /// <returns>The new label.</returns>
          public Instruction NewBlock()
          {
              var label = Label();
              Branch(label);
              AddLabel(label);

              return label;
          }

          /// <summary>
          /// Collects every declared input/output variable (regular, per-patch and the indexed arrays).
          /// Presumably used as the entry point interface list; confirm at the caller.
          /// </summary>
          public Instruction[] GetMainInterface()
          {
              var mainInterface = new List<Instruction>();

              mainInterface.AddRange(Inputs.Values);
              mainInterface.AddRange(Outputs.Values);
              mainInterface.AddRange(InputsPerPatch.Values);
              mainInterface.AddRange(OutputsPerPatch.Values);

              if (InputsArray != null)
              {
                  mainInterface.Add(InputsArray);
              }

              if (OutputsArray != null)
              {
                  mainInterface.Add(OutputsArray);
              }

              return mainInterface.ToArray();
          }

          /// <exception cref="ArgumentException">The local is already declared in this function.</exception>
          public void DeclareLocal(AstOperand local, Instruction spvLocal)
          {
              _locals.Add(local, spvLocal);
          }

          /// <exception cref="ArgumentException">Locals for this function index are already declared.</exception>
          public void DeclareLocalForArgs(int funcIndex, Instruction[] spvLocals)
          {
              _localForArgs.Add(funcIndex, spvLocals);
          }

          /// <exception cref="ArgumentException">The argument index is already declared.</exception>
          public void DeclareArgument(int argIndex, Instruction spvLocal)
          {
              _funcArgs.Add(argIndex, spvLocal);
          }

          /// <exception cref="ArgumentException">The function index is already declared.</exception>
          public void DeclareFunction(int funcIndex, StructuredFunction function, Instruction spvFunc)
          {
              _functions.Add(funcIndex, (function, spvFunc));
          }

          public Instruction GetFP32(IAstNode node)
          {
              return Get(AggregateType.FP32, node);
          }

          public Instruction GetFP64(IAstNode node)
          {
              return Get(AggregateType.FP64, node);
          }

          public Instruction GetS32(IAstNode node)
          {
              return Get(AggregateType.S32, node);
          }

          public Instruction GetU32(IAstNode node)
          {
              return Get(AggregateType.U32, node);
          }

          /// <summary>
          /// Emits code that evaluates <paramref name="node"/> and returns the value converted to
          /// <paramref name="type"/> (see <see cref="BitcastIfNeeded"/>).
          /// </summary>
          /// <exception cref="ArgumentException">The operand type is not supported.</exception>
          /// <exception cref="NotImplementedException">The node is neither an operation nor an operand.</exception>
          public Instruction Get(AggregateType type, IAstNode node)
          {
              if (node is AstOperation operation)
              {
                  var opResult = Instructions.Generate(this, operation);

                  return BitcastIfNeeded(type, opResult.Type, opResult.Value);
              }
              else if (node is AstOperand operand)
              {
                  return operand.Type switch
                  {
                      IrOperandType.Argument => GetArgument(type, operand),
                      IrOperandType.Attribute => GetAttribute(type, operand.Value & AttributeConsts.Mask, (operand.Value & AttributeConsts.LoadOutputMask) != 0),
                      IrOperandType.AttributePerPatch => GetAttributePerPatch(type, operand.Value & AttributeConsts.Mask, (operand.Value & AttributeConsts.LoadOutputMask) != 0),
                      IrOperandType.Constant => GetConstant(type, operand),
                      IrOperandType.ConstantBuffer => GetConstantBuffer(type, operand),
                      IrOperandType.LocalVariable => GetLocal(type, operand),
                      IrOperandType.Undefined => GetUndefined(type),
                      _ => throw new ArgumentException($"Invalid operand type \"{operand.Type}\".")
                  };
              }

              throw new NotImplementedException(node.GetType().Name);
          }

          /// <summary>
          /// Value used for undefined operands: a zero/false constant of the requested type
          /// (not an OpUndef).
          /// </summary>
          private Instruction GetUndefined(AggregateType type)
          {
              return type switch
              {
                  AggregateType.Bool => ConstantFalse(TypeBool()),
                  AggregateType.FP32 => Constant(TypeFP32(), 0f),
                  AggregateType.FP64 => Constant(TypeFP64(), 0d),
                  _ => Constant(GetType(type), 0)
              };
          }

          /// <summary>
          /// Returns a pointer to the scalar element of attribute <paramref name="attr"/>.
          /// </summary>
          /// <param name="attr">Attribute offset.</param>
          /// <param name="isOutAttr">True to address the output variable, false for the input.</param>
          /// <param name="index">
          /// Per-vertex index, used only when the stage/direction uses arrayed I/O.
          /// </param>
          /// <param name="elemType">Receives the scalar type the returned pointer points to.</param>
          public Instruction GetAttributeElemPointer(int attr, bool isOutAttr, Instruction index, out AggregateType elemType)
          {
              var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
              var attrInfo = AttributeInfo.From(Config, attr, isOutAttr);

              int attrOffset = attrInfo.BaseValue;
              AggregateType type = attrInfo.Type;

              Instruction ioVariable, elemIndex;

              bool isUserAttr = attr >= AttributeConsts.UserAttributeBase && attr < AttributeConsts.UserAttributeEnd;

              // With indexing enabled, user attributes live in a single array variable.
              if (isUserAttr &&
                  ((!isOutAttr && Config.UsedFeatures.HasFlag(FeatureFlags.IaIndexing)) ||
                  (isOutAttr && Config.UsedFeatures.HasFlag(FeatureFlags.OaIndexing))))
              {
                  elemType = AggregateType.FP32;
                  ioVariable = isOutAttr ? OutputsArray : InputsArray;
                  elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());

                  // Each user attribute vector spans 16 bytes of attribute space.
                  var vecIndex = Constant(TypeU32(), (attr - AttributeConsts.UserAttributeBase) >> 4);

                  if (AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr))
                  {
                      return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex);
                  }
                  else
                  {
                      return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, vecIndex, elemIndex);
                  }
              }

              // Viewport-inverse values are read from member 2 of the support buffer, not from I/O.
              bool isViewportInverse = attr == AttributeConsts.SupportBlockViewInverseX || attr == AttributeConsts.SupportBlockViewInverseY;

              if (isViewportInverse)
              {
                  elemType = AggregateType.FP32;
                  elemIndex = Constant(TypeU32(), (attr - AttributeConsts.SupportBlockViewInverseX) >> 2);

                  return AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), SupportBuffer, Constant(TypeU32(), 2), elemIndex);
              }

              elemType = attrInfo.Type & AggregateType.ElementTypeMask;

              // With transform feedback, user attributes are declared as individual scalar variables
              // keyed by the exact attribute offset rather than by vector base.
              if (isUserAttr && Config.TransformFeedbackEnabled &&
                  ((isOutAttr && Config.LastInVertexPipeline) ||
                  (!isOutAttr && Config.Stage == ShaderStage.Fragment)))
              {
                  attrOffset = attr;
                  type = elemType;
              }

              ioVariable = isOutAttr ? Outputs[attrOffset] : Inputs[attrOffset];

              bool isIndexed = AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr) && (!attrInfo.IsBuiltin || AttributeInfo.IsArrayBuiltIn(attr));

              if ((type & (AggregateType.Array | AggregateType.Vector)) == 0)
              {
                  return isIndexed ? AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index) : ioVariable;
              }

              elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());

              if (isIndexed)
              {
                  return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, elemIndex);
              }
              else
              {
                  return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, elemIndex);
              }
          }

          /// <summary>
          /// Returns a pointer to a dynamically indexed user attribute component in
          /// <see cref="InputsArray"/> / <see cref="OutputsArray"/>.
          /// </summary>
          /// <param name="attrIndex">
          /// 32-bit component index: bits 2 and up select the vector, bits 0-1 the component.
          /// </param>
          public Instruction GetAttributeElemPointer(Instruction attrIndex, bool isOutAttr, Instruction index, out AggregateType elemType)
          {
              var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;

              elemType = AggregateType.FP32;

              var ioVariable = isOutAttr ? OutputsArray : InputsArray;
              var vecIndex = ShiftRightLogical(TypeS32(), attrIndex, Constant(TypeS32(), 2));
              var elemIndex = BitwiseAnd(TypeS32(), attrIndex, Constant(TypeS32(), 3));

              if (AttributeInfo.IsArrayAttributeSpirv(Config.Stage, isOutAttr))
              {
                  return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, index, vecIndex, elemIndex);
              }
              else
              {
                  return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, vecIndex, elemIndex);
              }
          }

          /// <summary>
          /// Loads attribute <paramref name="attr"/> and converts it to <paramref name="type"/>.
          /// Attributes rejected by <c>AttributeInfo.Validate</c> read as constant zero.
          /// </summary>
          public Instruction GetAttribute(AggregateType type, int attr, bool isOutAttr, Instruction index = null)
          {
              // Validate in the same direction the pointer lookup below uses; previously this
              // always validated as an input even when reading an output.
              if (!AttributeInfo.Validate(Config, attr, isOutAttr))
              {
                  return GetConstant(type, new AstOperand(IrOperandType.Constant, 0));
              }

              var elemPointer = GetAttributeElemPointer(attr, isOutAttr, index, out var elemType);
              var value = Load(GetType(elemType), elemPointer);

              if (Config.Stage == ShaderStage.Fragment)
              {
                  if (attr == AttributeConsts.PositionX || attr == AttributeConsts.PositionY)
                  {
                      // Divide the fragment position by the scale stored in support buffer member 4, element 0.
                      var pointerType = TypePointer(StorageClass.Uniform, TypeFP32());
                      var fieldIndex = Constant(TypeU32(), 4);
                      var scaleIndex = Constant(TypeU32(), 0);

                      var scaleElemPointer = AccessChain(pointerType, SupportBuffer, fieldIndex, scaleIndex);
                      var scale = Load(TypeFP32(), scaleElemPointer);

                      value = FDiv(TypeFP32(), value, scale);
                  }
                  else if (attr == AttributeConsts.FrontFacing && Config.GpuAccessor.QueryHostHasFrontFacingBug())
                  {
                      // Workaround for what appears to be a bug on Intel compiler.
                      // Recomputes the same boolean through float/int math: true -> 1.0f, whose bits are
                      // positive, so their negation is < 0; false -> 0.0f -> 0, which is not < 0.
                      var valueFloat = Select(TypeFP32(), value, Constant(TypeFP32(), 1f), Constant(TypeFP32(), 0f));
                      var valueAsInt = Bitcast(TypeS32(), valueFloat);
                      var valueNegated = SNegate(TypeS32(), valueAsInt);

                      value = SLessThan(TypeBool(), valueNegated, Constant(TypeS32(), 0));
                  }
              }

              return BitcastIfNeeded(type, elemType, value);
          }

          /// <summary>
          /// Returns a pointer to the scalar element of per-patch attribute <paramref name="attr"/>.
          /// </summary>
          public Instruction GetAttributePerPatchElemPointer(int attr, bool isOutAttr, out AggregateType elemType)
          {
              var storageClass = isOutAttr ? StorageClass.Output : StorageClass.Input;
              var attrInfo = AttributeInfo.From(Config, attr, isOutAttr);

              int attrOffset = attrInfo.BaseValue;

              elemType = attrInfo.Type & AggregateType.ElementTypeMask;

              Instruction ioVariable = isOutAttr ? OutputsPerPatch[attrOffset] : InputsPerPatch[attrOffset];

              // Scalar attributes are addressed directly; vectors/arrays need the component index.
              if ((attrInfo.Type & (AggregateType.Array | AggregateType.Vector)) == 0)
              {
                  return ioVariable;
              }

              var elemIndex = Constant(TypeU32(), attrInfo.GetInnermostIndex());

              return AccessChain(TypePointer(storageClass, GetType(elemType)), ioVariable, elemIndex);
          }

          /// <summary>
          /// Loads per-patch attribute <paramref name="attr"/> converted to <paramref name="type"/>.
          /// Attributes rejected by <c>AttributeInfo.Validate</c> read as constant zero.
          /// </summary>
          public Instruction GetAttributePerPatch(AggregateType type, int attr, bool isOutAttr)
          {
              // Validate in the same direction the pointer lookup uses.
              if (!AttributeInfo.Validate(Config, attr, isOutAttr))
              {
                  return GetConstant(type, new AstOperand(IrOperandType.Constant, 0));
              }

              var elemPointer = GetAttributePerPatchElemPointer(attr, isOutAttr, out var elemType);

              return BitcastIfNeeded(type, elemType, Load(GetType(elemType), elemPointer));
          }

          /// <summary>Loads a dynamically indexed user attribute component converted to <paramref name="type"/>.</summary>
          public Instruction GetAttribute(AggregateType type, Instruction attr, bool isOutAttr, Instruction index = null)
          {
              var elemPointer = GetAttributeElemPointer(attr, isOutAttr, index, out var elemType);

              return BitcastIfNeeded(type, elemType, Load(GetType(elemType), elemPointer));
          }

          /// <summary>
          /// Materializes the 32-bit constant in <paramref name="operand"/> as <paramref name="type"/>.
          /// The raw bits are reinterpreted for FP32, and widened from FP32 for FP64.
          /// </summary>
          /// <exception cref="ArgumentException"><paramref name="type"/> is not a scalar type.</exception>
          public Instruction GetConstant(AggregateType type, AstOperand operand)
          {
              return type switch
              {
                  AggregateType.Bool => operand.Value != 0 ? ConstantTrue(TypeBool()) : ConstantFalse(TypeBool()),
                  AggregateType.FP32 => Constant(TypeFP32(), BitConverter.Int32BitsToSingle(operand.Value)),
                  AggregateType.FP64 => Constant(TypeFP64(), (double)BitConverter.Int32BitsToSingle(operand.Value)),
                  AggregateType.S32 => Constant(TypeS32(), operand.Value),
                  AggregateType.U32 => Constant(TypeU32(), (uint)operand.Value),
                  _ => throw new ArgumentException($"Invalid type \"{type}\".")
              };
          }

          /// <summary>
          /// Loads a 32-bit constant buffer value (as FP32) and converts it to <paramref name="type"/>.
          /// <c>CbufOffset</c> is treated as a word offset: <c>&gt;&gt; 2</c> selects the vec4, <c>&amp; 3</c> the component.
          /// </summary>
          public Instruction GetConstantBuffer(AggregateType type, AstOperand operand)
          {
              var i1 = Constant(TypeS32(), 0);
              var i2 = Constant(TypeS32(), operand.CbufOffset >> 2);
              var i3 = Constant(TypeU32(), operand.CbufOffset & 3);

              Instruction elemPointer;

              if (UniformBuffersArray != null)
              {
                  var ubVariable = UniformBuffersArray;
                  var i0 = Constant(TypeS32(), operand.CbufSlot);

                  elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i0, i1, i2, i3);
              }
              else
              {
                  var ubVariable = UniformBuffers[operand.CbufSlot];

                  elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i1, i2, i3);
              }

              return BitcastIfNeeded(type, AggregateType.FP32, Load(TypeFP32(), elemPointer));
          }

          public Instruction GetLocalPointer(AstOperand local)
          {
              return _locals[local];
          }

          public Instruction[] GetLocalForArgsPointers(int funcIndex)
          {
              return _localForArgs[funcIndex];
          }

          public Instruction GetArgumentPointer(AstOperand funcArg)
          {
              return _funcArgs[funcArg.Value];
          }

          /// <summary>Loads a local variable and converts it from its declared type to <paramref name="dstType"/>.</summary>
          public Instruction GetLocal(AggregateType dstType, AstOperand local)
          {
              var srcType = local.VarType.Convert();

              return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetLocalPointer(local)));
          }

          /// <summary>Loads a function argument and converts it from its declared type to <paramref name="dstType"/>.</summary>
          public Instruction GetArgument(AggregateType dstType, AstOperand funcArg)
          {
              var srcType = funcArg.VarType.Convert();

              return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetArgumentPointer(funcArg)));
          }

          public (StructuredFunction, Instruction) GetFunction(int funcIndex)
          {
              return _functions[funcIndex];
          }

          /// <summary>Returns the transform feedback output for a user attribute location and component.</summary>
          public TransformFeedbackOutput GetTransformFeedbackOutput(int location, int component)
          {
              int index = (AttributeConsts.UserAttributeBase / 4) + location * 4 + component;

              return _info.TransformFeedbackOutputs[index];
          }

          /// <summary>Returns the transform feedback output at index <c>location / 4</c>.</summary>
          /// <remarks>
          /// NOTE(review): dividing by 4 suggests callers pass an attribute byte offset rather than a
          /// location; confirm against the call sites.
          /// </remarks>
          public TransformFeedbackOutput GetTransformFeedbackOutput(int location)
          {
              int index = location / 4;

              return _info.TransformFeedbackOutputs[index];
          }

          /// <summary>
          /// Returns the SPIR-V type for <paramref name="type"/>. Array types use <paramref name="length"/>
          /// as the element count, vector types as the component count.
          /// </summary>
          /// <exception cref="ArgumentException">The base type is not supported.</exception>
          public Instruction GetType(AggregateType type, int length = 1)
          {
              if (type.HasFlag(AggregateType.Array))
              {
                  return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
              }
              else if (type.HasFlag(AggregateType.Vector))
              {
                  return TypeVector(GetType(type & ~AggregateType.Vector), length);
              }

              return type switch
              {
                  AggregateType.Void => TypeVoid(),
                  AggregateType.Bool => TypeBool(),
                  AggregateType.FP32 => TypeFP32(),
                  AggregateType.FP64 => TypeFP64(),
                  AggregateType.S32 => TypeS32(),
                  AggregateType.U32 => TypeU32(),
                  _ => throw new ArgumentException($"Invalid attribute type \"{type}\".")
              };
          }

          /// <summary>
          /// Converts <paramref name="value"/> from <paramref name="srcType"/> to <paramref name="dstType"/>.
          /// Bool targets compare against integer zero. Bool sources select IrConsts.True/False as S32.
          /// Everything else is a plain bitcast.
          /// </summary>
          public Instruction BitcastIfNeeded(AggregateType dstType, AggregateType srcType, Instruction value)
          {
              if (dstType == srcType)
              {
                  return value;
              }

              if (dstType == AggregateType.Bool)
              {
                  return INotEqual(TypeBool(), BitcastIfNeeded(AggregateType.S32, srcType, value), Constant(TypeS32(), 0));
              }
              else if (srcType == AggregateType.Bool)
              {
                  var intTrue = Constant(TypeS32(), IrConsts.True);
                  var intFalse = Constant(TypeS32(), IrConsts.False);

                  return BitcastIfNeeded(dstType, AggregateType.S32, Select(TypeS32(), value, intTrue, intFalse));
              }
              else
              {
                  return Bitcast(GetType(dstType, 1), value);
              }
          }

          public Instruction TypeS32()
          {
              return TypeInt(32, true);
          }

          public Instruction TypeU32()
          {
              return TypeInt(32, false);
          }

          public Instruction TypeFP32()
          {
              return TypeFloat(32);
          }

          public Instruction TypeFP64()
          {
              return TypeFloat(64);
          }
      }
  }