CodeGenContext.cs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. using Ryujinx.Graphics.Shader.StructuredIr;
  2. using Ryujinx.Graphics.Shader.Translation;
  3. using Spv.Generator;
  4. using System;
  5. using System.Collections.Generic;
  6. using static Spv.Specification;
  namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
  {
      using IrConsts = IntermediateRepresentation.IrConsts;
      using IrOperandType = IntermediateRepresentation.OperandType;

      /// <summary>
      /// SPIR-V <see cref="Module"/> builder that also tracks the state used while translating a
      /// structured shader program: resource and I/O variables, locals, arguments, functions and
      /// the labels assigned to AST blocks.
      /// </summary>
      partial class CodeGenContext : Module
      {
          // SPIR-V 1.3, packed as (major << 16) | (minor << 8) | revision for the Module base constructor.
          private const uint SpirvVersionMajor = 1;
          private const uint SpirvVersionMinor = 3;
          private const uint SpirvVersionRevision = 0;
          private const uint SpirvVersionPacked = (SpirvVersionMajor << 16) | (SpirvVersionMinor << 8) | SpirvVersionRevision;

          /// <summary>Program information this context was created for.</summary>
          public StructuredProgramInfo Info { get; }

          /// <summary>Shader configuration (stage, GPU accessor, ...).</summary>
          public ShaderConfig Config { get; }

          /// <summary>
          /// Number of vertices per input primitive. Only assigned for geometry shaders; 0 otherwise.
          /// </summary>
          public int InputVertices { get; }

          /// <summary>
          /// Uniform buffer variables keyed by constant buffer slot. Used by
          /// <see cref="GetConstantBuffer"/> when <see cref="UniformBuffersArray"/> is null.
          /// </summary>
          public Dictionary<int, Instruction> UniformBuffers { get; } = new Dictionary<int, Instruction>();

          // Resource variables. Apart from UniformBuffersArray (read by GetConstantBuffer), these are
          // assigned and consumed outside this part of the partial class.
          public Instruction SupportBuffer { get; set; }
          public Instruction UniformBuffersArray { get; set; }
          public Instruction StorageBuffersArray { get; set; }
          public Instruction LocalMemory { get; set; }
          public Instruction SharedMemory { get; set; }

          // Texture/image declarations keyed by texture metadata; populated outside this part of the class.
          public Dictionary<TextureMeta, SamplerType> SamplersTypes { get; } = new Dictionary<TextureMeta, SamplerType>();
          public Dictionary<TextureMeta, (Instruction, Instruction, Instruction)> Samplers { get; } = new Dictionary<TextureMeta, (Instruction, Instruction, Instruction)>();
          public Dictionary<TextureMeta, (Instruction, Instruction)> Images { get; } = new Dictionary<TextureMeta, (Instruction, Instruction)>();

          // Input/output variables; all four maps are concatenated by GetMainInterface.
          public Dictionary<IoDefinition, Instruction> Inputs { get; } = new Dictionary<IoDefinition, Instruction>();
          public Dictionary<IoDefinition, Instruction> Outputs { get; } = new Dictionary<IoDefinition, Instruction>();
          public Dictionary<IoDefinition, Instruction> InputsPerPatch { get; } = new Dictionary<IoDefinition, Instruction>();
          public Dictionary<IoDefinition, Instruction> OutputsPerPatch { get; } = new Dictionary<IoDefinition, Instruction>();

          // Assigned outside this part of the partial class.
          public Instruction CoordTemp { get; set; }

          // Per-function tables, reset by StartFunction.
          private readonly Dictionary<AstOperand, Instruction> _locals = new Dictionary<AstOperand, Instruction>();
          private readonly Dictionary<int, Instruction[]> _localForArgs = new Dictionary<int, Instruction[]>();
          private readonly Dictionary<int, Instruction> _funcArgs = new Dictionary<int, Instruction>();

          // Function declarations keyed by function index; not cleared by StartFunction.
          private readonly Dictionary<int, (StructuredFunction, Instruction)> _functions = new Dictionary<int, (StructuredFunction, Instruction)>();

          /// <summary>
          /// Labels belonging to one AST block. Label n is used for the n-th time the block is
          /// entered; labels are created on demand so they can be referenced before being emitted.
          /// </summary>
          private class BlockState
          {
              // Number of times the block has been entered (labels consumed so far).
              private int _entryCount;

              private readonly List<Instruction> _labels = new List<Instruction>();

              /// <summary>Returns the label for the next entry without consuming it.</summary>
              public Instruction GetNextLabel(CodeGenContext context)
              {
                  return GetLabel(context, _entryCount);
              }

              /// <summary>Returns the label for the next entry and consumes it.</summary>
              public Instruction GetNextLabelAutoIncrement(CodeGenContext context)
              {
                  return GetLabel(context, _entryCount++);
              }

              /// <summary>
              /// Returns label <paramref name="index"/>, creating it and any missing earlier labels first.
              /// </summary>
              public Instruction GetLabel(CodeGenContext context, int index)
              {
                  while (index >= _labels.Count)
                  {
                      _labels.Add(context.Label());
                  }

                  return _labels[index];
              }
          }

          private readonly Dictionary<AstBlock, BlockState> _labels = new Dictionary<AstBlock, BlockState>();

          // Per-loop label pair keyed by loop block; assigned and read outside this part of the class.
          public Dictionary<AstBlock, (Instruction, Instruction)> LoopTargets { get; set; }

          /// <summary>Block most recently entered through <see cref="EnterBlock"/>.</summary>
          public AstBlock CurrentBlock { get; private set; }

          public SpirvDelegates Delegates { get; }

          /// <summary>
          /// Creates the module for SPIR-V 1.3 with the Shader and Float64 capabilities and the
          /// Logical/GLSL450 memory model.
          /// </summary>
          /// <exception cref="InvalidOperationException">
          /// The shader is a geometry shader and the GPU accessor reports an unknown input topology.
          /// </exception>
          public CodeGenContext(
              StructuredProgramInfo info,
              ShaderConfig config,
              GeneratorPool<Instruction> instPool,
              GeneratorPool<LiteralInteger> integerPool) : base(SpirvVersionPacked, instPool, integerPool)
          {
              Info = info;
              Config = config;

              if (config.Stage == ShaderStage.Geometry)
              {
                  InputTopology inPrimitive = config.GpuAccessor.QueryPrimitiveTopology();

                  // NOTE(review): adjacency topologies are mapped to the same counts as their
                  // non-adjacency forms (2 and 3), whereas host APIs deliver 4 and 6 vertices per
                  // primitive for them. Confirm this is intentional.
                  InputVertices = inPrimitive switch
                  {
                      InputTopology.Points => 1,
                      InputTopology.Lines => 2,
                      InputTopology.LinesAdjacency => 2,
                      InputTopology.Triangles => 3,
                      InputTopology.TrianglesAdjacency => 3,
                      _ => throw new InvalidOperationException($"Invalid input topology \"{inPrimitive}\".")
                  };
              }

              AddCapability(Capability.Shader);
              AddCapability(Capability.Float64);

              SetMemoryModel(AddressingModel.Logical, MemoryModel.GLSL450);

              Delegates = new SpirvDelegates(this);
          }

          /// <summary>
          /// Resets the per-function local, argument-local and argument tables.
          /// Declared functions and block labels are kept.
          /// </summary>
          public void StartFunction()
          {
              _locals.Clear();
              _localForArgs.Clear();
              _funcArgs.Clear();
          }

          /// <summary>
          /// Makes <paramref name="block"/> current and emits (consumes) the label for its next entry.
          /// </summary>
          public void EnterBlock(AstBlock block)
          {
              CurrentBlock = block;
              AddLabel(GetBlockStateLazy(block).GetNextLabelAutoIncrement(this));
          }

          /// <summary>Returns the label of the block's first entry, creating it if necessary.</summary>
          public Instruction GetFirstLabel(AstBlock block)
          {
              return GetBlockStateLazy(block).GetLabel(this, 0);
          }

          /// <summary>
          /// Returns the label the next <see cref="EnterBlock"/> call for <paramref name="block"/>
          /// will emit, without consuming it.
          /// </summary>
          public Instruction GetNextLabel(AstBlock block)
          {
              return GetBlockStateLazy(block).GetNextLabel(this);
          }

          // Returns the label state for a block, registering an empty one on first use.
          private BlockState GetBlockStateLazy(AstBlock block)
          {
              if (!_labels.TryGetValue(block, out var blockState))
              {
                  blockState = new BlockState();
                  _labels.Add(block, blockState);
              }

              return blockState;
          }

          /// <summary>
          /// Ends the current block with a branch to a fresh label and starts emitting into it.
          /// </summary>
          /// <returns>The new label.</returns>
          public Instruction NewBlock()
          {
              var label = Label();
              Branch(label);
              AddLabel(label);
              return label;
          }

          /// <summary>
          /// Collects every I/O variable in the order Inputs, Outputs, InputsPerPatch, OutputsPerPatch
          /// (presumably for the entry point's interface list).
          /// </summary>
          public Instruction[] GetMainInterface()
          {
              var mainInterface = new List<Instruction>();

              mainInterface.AddRange(Inputs.Values);
              mainInterface.AddRange(Outputs.Values);
              mainInterface.AddRange(InputsPerPatch.Values);
              mainInterface.AddRange(OutputsPerPatch.Values);

              return mainInterface.ToArray();
          }

          /// <summary>Registers the variable for a local; throws if the local is already declared.</summary>
          public void DeclareLocal(AstOperand local, Instruction spvLocal)
          {
              _locals.Add(local, spvLocal);
          }

          /// <summary>Registers the argument locals for a function index; throws on duplicates.</summary>
          public void DeclareLocalForArgs(int funcIndex, Instruction[] spvLocals)
          {
              _localForArgs.Add(funcIndex, spvLocals);
          }

          /// <summary>Registers the variable loaded for an argument index; throws on duplicates.</summary>
          public void DeclareArgument(int argIndex, Instruction spvLocal)
          {
              _funcArgs.Add(argIndex, spvLocal);
          }

          /// <summary>Registers a function and its SPIR-V function instruction; throws on duplicates.</summary>
          public void DeclareFunction(int funcIndex, StructuredFunction function, Instruction spvFunc)
          {
              _functions.Add(funcIndex, (function, spvFunc));
          }

          /// <summary>Evaluates <paramref name="node"/> as FP32.</summary>
          public Instruction GetFP32(IAstNode node)
          {
              return Get(AggregateType.FP32, node);
          }

          /// <summary>Evaluates <paramref name="node"/> as FP64.</summary>
          public Instruction GetFP64(IAstNode node)
          {
              return Get(AggregateType.FP64, node);
          }

          /// <summary>Evaluates <paramref name="node"/> as S32.</summary>
          public Instruction GetS32(IAstNode node)
          {
              return Get(AggregateType.S32, node);
          }

          /// <summary>Evaluates <paramref name="node"/> as U32.</summary>
          public Instruction GetU32(IAstNode node)
          {
              return Get(AggregateType.U32, node);
          }

          /// <summary>
          /// Evaluates <paramref name="node"/> and returns it as <paramref name="type"/>, converting
          /// through <see cref="BitcastIfNeeded"/> when the produced type differs.
          /// </summary>
          /// <exception cref="ArgumentException">The operand kind is not supported.</exception>
          /// <exception cref="NotImplementedException">The node is neither an operation nor an operand.</exception>
          public Instruction Get(AggregateType type, IAstNode node)
          {
              if (node is AstOperation operation)
              {
                  var opResult = Instructions.Generate(this, operation);

                  return BitcastIfNeeded(type, opResult.Type, opResult.Value);
              }
              else if (node is AstOperand operand)
              {
                  return operand.Type switch
                  {
                      IrOperandType.Argument => GetArgument(type, operand),
                      IrOperandType.Constant => GetConstant(type, operand),
                      IrOperandType.ConstantBuffer => GetConstantBuffer(type, operand),
                      IrOperandType.LocalVariable => GetLocal(type, operand),
                      IrOperandType.Undefined => GetUndefined(type),
                      _ => throw new ArgumentException($"Invalid operand type \"{operand.Type}\".")
                  };
              }

              throw new NotImplementedException(node.GetType().Name);
          }

          /// <summary>
          /// Evaluates <paramref name="node"/> in its natural type, reported through
          /// <paramref name="type"/>. Operations use their result type; argument and local variable
          /// operands use their declared <c>VarType</c>.
          /// </summary>
          /// <exception cref="ArgumentException">The operand is not an argument or local variable.</exception>
          /// <exception cref="NotImplementedException">The node is neither an operation nor an operand.</exception>
          public Instruction GetWithType(IAstNode node, out AggregateType type)
          {
              if (node is AstOperation operation)
              {
                  var opResult = Instructions.Generate(this, operation);
                  type = opResult.Type;
                  return opResult.Value;
              }
              else if (node is AstOperand operand)
              {
                  switch (operand.Type)
                  {
                      // Arguments carry a declared VarType just like locals (see GetArgument),
                      // so they can be loaded without a conversion as well.
                      case IrOperandType.Argument:
                          type = operand.VarType;
                          return GetArgument(type, operand);
                      case IrOperandType.LocalVariable:
                          type = operand.VarType;
                          return GetLocal(type, operand);
                      default:
                          throw new ArgumentException($"Invalid operand type \"{operand.Type}\".");
                  }
              }

              throw new NotImplementedException(node.GetType().Name);
          }

          // Value used for undefined operands: false for Bool, otherwise a zero constant of the type.
          private Instruction GetUndefined(AggregateType type)
          {
              return type switch
              {
                  AggregateType.Bool => ConstantFalse(TypeBool()),
                  AggregateType.FP32 => Constant(TypeFP32(), 0f),
                  AggregateType.FP64 => Constant(TypeFP64(), 0d),
                  _ => Constant(GetType(type), 0)
              };
          }

          /// <summary>
          /// Builds a constant of <paramref name="type"/> from the operand's 32-bit value.
          /// Bool is true for any non-zero value; FP32 reinterprets the bits as a float; FP64
          /// reinterprets them as a float and widens to double.
          /// </summary>
          /// <exception cref="ArgumentException"><paramref name="type"/> is not a supported scalar type.</exception>
          public Instruction GetConstant(AggregateType type, AstOperand operand)
          {
              return type switch
              {
                  AggregateType.Bool => operand.Value != 0 ? ConstantTrue(TypeBool()) : ConstantFalse(TypeBool()),
                  AggregateType.FP32 => Constant(TypeFP32(), BitConverter.Int32BitsToSingle(operand.Value)),
                  AggregateType.FP64 => Constant(TypeFP64(), (double)BitConverter.Int32BitsToSingle(operand.Value)),
                  AggregateType.S32 => Constant(TypeS32(), operand.Value),
                  AggregateType.U32 => Constant(TypeU32(), (uint)operand.Value),
                  _ => throw new ArgumentException($"Invalid type \"{type}\".")
              };
          }

          /// <summary>
          /// Loads one FP32 element of a constant buffer and converts it to <paramref name="type"/>.
          /// </summary>
          /// <remarks>
          /// The access chain indexes member 0, then <c>CbufOffset &gt;&gt; 2</c>, then <c>CbufOffset &amp; 3</c>
          /// (presumably a vec4 element and its component; the buffer layout is declared elsewhere).
          /// When <see cref="UniformBuffersArray"/> is set, the slot is prepended as the first index;
          /// otherwise the per-slot variable from <see cref="UniformBuffers"/> is used.
          /// </remarks>
          public Instruction GetConstantBuffer(AggregateType type, AstOperand operand)
          {
              var i1 = Constant(TypeS32(), 0);
              var i2 = Constant(TypeS32(), operand.CbufOffset >> 2);
              var i3 = Constant(TypeU32(), operand.CbufOffset & 3);

              Instruction elemPointer;

              if (UniformBuffersArray != null)
              {
                  var ubVariable = UniformBuffersArray;
                  var i0 = Constant(TypeS32(), operand.CbufSlot);

                  elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i0, i1, i2, i3);
              }
              else
              {
                  var ubVariable = UniformBuffers[operand.CbufSlot];

                  elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i1, i2, i3);
              }

              return BitcastIfNeeded(type, AggregateType.FP32, Load(TypeFP32(), elemPointer));
          }

          /// <summary>Returns the variable registered for a local through <see cref="DeclareLocal"/>.</summary>
          public Instruction GetLocalPointer(AstOperand local)
          {
              return _locals[local];
          }

          /// <summary>Returns the argument locals registered through <see cref="DeclareLocalForArgs"/>.</summary>
          public Instruction[] GetLocalForArgsPointers(int funcIndex)
          {
              return _localForArgs[funcIndex];
          }

          /// <summary>Returns the variable registered for the argument index held in <c>funcArg.Value</c>.</summary>
          public Instruction GetArgumentPointer(AstOperand funcArg)
          {
              return _funcArgs[funcArg.Value];
          }

          /// <summary>Loads a local at its declared type and converts it to <paramref name="dstType"/>.</summary>
          public Instruction GetLocal(AggregateType dstType, AstOperand local)
          {
              var srcType = local.VarType;

              return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetLocalPointer(local)));
          }

          /// <summary>Loads an argument at its declared type and converts it to <paramref name="dstType"/>.</summary>
          public Instruction GetArgument(AggregateType dstType, AstOperand funcArg)
          {
              var srcType = funcArg.VarType;

              return BitcastIfNeeded(dstType, srcType, Load(GetType(srcType), GetArgumentPointer(funcArg)));
          }

          /// <summary>Returns the function registered through <see cref="DeclareFunction"/>.</summary>
          public (StructuredFunction, Instruction) GetFunction(int funcIndex)
          {
              return _functions[funcIndex];
          }

          /// <summary>
          /// Returns the SPIR-V type for <paramref name="type"/>. Array types become an array of
          /// <paramref name="length"/> elements of the element type. Vector2/3/4 become a vector of
          /// the component type. Scalars map directly.
          /// </summary>
          /// <exception cref="ArgumentException">The element count or scalar type is not supported.</exception>
          public Instruction GetType(AggregateType type, int length = 1)
          {
              if ((type & AggregateType.Array) != 0)
              {
                  return TypeArray(GetType(type & ~AggregateType.Array), Constant(TypeU32(), length));
              }
              else if ((type & AggregateType.ElementCountMask) != 0)
              {
                  // OpTypeVector requires at least 2 components, so an unknown count is rejected
                  // instead of emitting an invalid 1-component vector.
                  int vectorLength = (type & AggregateType.ElementCountMask) switch
                  {
                      AggregateType.Vector2 => 2,
                      AggregateType.Vector3 => 3,
                      AggregateType.Vector4 => 4,
                      _ => throw new ArgumentException($"Invalid vector type \"{type}\".")
                  };

                  return TypeVector(GetType(type & ~AggregateType.ElementCountMask), vectorLength);
              }

              return type switch
              {
                  AggregateType.Void => TypeVoid(),
                  AggregateType.Bool => TypeBool(),
                  AggregateType.FP32 => TypeFP32(),
                  AggregateType.FP64 => TypeFP64(),
                  AggregateType.S32 => TypeS32(),
                  AggregateType.U32 => TypeU32(),
                  _ => throw new ArgumentException($"Invalid attribute type \"{type}\".")
              };
          }

          /// <summary>
          /// Converts <paramref name="value"/> from <paramref name="srcType"/> to <paramref name="dstType"/>.
          /// Same types are returned unchanged. A Bool target compares the S32 view of the value
          /// against 0. A Bool source first selects IrConsts.True/IrConsts.False as S32. Any other
          /// pair uses OpBitcast.
          /// </summary>
          /// <remarks>
          /// NOTE(review): OpBitcast requires equal bit widths, so an FP64 &lt;-&gt; 32-bit pair reaching
          /// the final branch would produce invalid SPIR-V. Confirm callers never request that.
          /// </remarks>
          public Instruction BitcastIfNeeded(AggregateType dstType, AggregateType srcType, Instruction value)
          {
              if (dstType == srcType)
              {
                  return value;
              }

              if (dstType == AggregateType.Bool)
              {
                  return INotEqual(TypeBool(), BitcastIfNeeded(AggregateType.S32, srcType, value), Constant(TypeS32(), 0));
              }
              else if (srcType == AggregateType.Bool)
              {
                  var intTrue = Constant(TypeS32(), IrConsts.True);
                  var intFalse = Constant(TypeS32(), IrConsts.False);

                  return BitcastIfNeeded(dstType, AggregateType.S32, Select(TypeS32(), value, intTrue, intFalse));
              }
              else
              {
                  return Bitcast(GetType(dstType, 1), value);
              }
          }

          /// <summary>Signed 32-bit integer type.</summary>
          public Instruction TypeS32()
          {
              return TypeInt(32, true);
          }

          /// <summary>Unsigned 32-bit integer type.</summary>
          public Instruction TypeU32()
          {
              return TypeInt(32, false);
          }

          /// <summary>32-bit float type.</summary>
          public Instruction TypeFP32()
          {
              return TypeFloat(32);
          }

          /// <summary>64-bit float type.</summary>
          public Instruction TypeFP64()
          {
              return TypeFloat(64);
          }
      }
  }