// CodeGenContext.cs (14 KB)
  1. using Ryujinx.Graphics.Shader.StructuredIr;
  2. using Ryujinx.Graphics.Shader.Translation;
  3. using Spv.Generator;
  4. using System;
  5. using System.Collections.Generic;
  6. using static Spv.Specification;
  7. namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
  8. {
  9. using IrConsts = IntermediateRepresentation.IrConsts;
  10. using IrOperandType = IntermediateRepresentation.OperandType;
  11. partial class CodeGenContext : Module
  12. {
  // SPIR-V version emitted in the module header. The three components are
  // combined into a single word: major in bits 16-23, minor in bits 8-15 and
  // revision in bits 0-7.
  private const uint SpirvVersionMajor = 1;
  private const uint SpirvVersionMinor = 3;
  private const uint SpirvVersionRevision = 0;

  private const uint SpirvVersionPacked =
      (SpirvVersionMajor << 16) |
      (SpirvVersionMinor << 8) |
      SpirvVersionRevision;
  /// <summary>Structured program information supplied to the constructor.</summary>
  public StructuredProgramInfo Info { get; }

  /// <summary>Shader configuration supplied to the constructor.</summary>
  public ShaderConfig Config { get; }

  /// <summary>
  /// Vertex count derived from the input primitive topology. Only assigned by the
  /// constructor when the stage is geometry; it keeps its default of 0 otherwise.
  /// </summary>
  public int InputVertices { get; }

  // Lookup tables for resources and I/O variables. Nothing in this part of the
  // partial class fills them; presumably the declaration code does so
  // (NOTE(review): confirm against the other partial class files). The meaning
  // of the integer keys is likewise not visible from here.
  public Dictionary<int, Instruction> ConstantBuffers { get; } = new Dictionary<int, Instruction>();
  public Dictionary<int, Instruction> StorageBuffers { get; } = new Dictionary<int, Instruction>();
  public Dictionary<int, Instruction> LocalMemories { get; } = new Dictionary<int, Instruction>();
  public Dictionary<int, Instruction> SharedMemories { get; } = new Dictionary<int, Instruction>();

  public Dictionary<TextureMeta, SamplerType> SamplersTypes { get; } =
      new Dictionary<TextureMeta, SamplerType>();

  public Dictionary<TextureMeta, (Instruction, Instruction, Instruction)> Samplers { get; } =
      new Dictionary<TextureMeta, (Instruction, Instruction, Instruction)>();

  public Dictionary<TextureMeta, (Instruction, Instruction)> Images { get; } =
      new Dictionary<TextureMeta, (Instruction, Instruction)>();

  // The four tables below are the ones gathered by GetMainInterface().
  public Dictionary<IoDefinition, Instruction> Inputs { get; } = new Dictionary<IoDefinition, Instruction>();
  public Dictionary<IoDefinition, Instruction> Outputs { get; } = new Dictionary<IoDefinition, Instruction>();
  public Dictionary<IoDefinition, Instruction> InputsPerPatch { get; } = new Dictionary<IoDefinition, Instruction>();
  public Dictionary<IoDefinition, Instruction> OutputsPerPatch { get; } = new Dictionary<IoDefinition, Instruction>();

  /// <summary>Function currently being processed; assigned by external code.</summary>
  public StructuredFunction CurrentFunction { get; set; }
  // Per-function state: these three tables are emptied by StartFunction().
  private readonly Dictionary<AstOperand, Instruction> _locals =
      new Dictionary<AstOperand, Instruction>();

  private readonly Dictionary<int, Instruction[]> _localForArgs =
      new Dictionary<int, Instruction[]>();

  private readonly Dictionary<int, Instruction> _funcArgs =
      new Dictionary<int, Instruction>();

  // Module-wide state: function index -> (structured function, SPIR-V function).
  // Unlike the tables above, this one is not cleared by StartFunction().
  private readonly Dictionary<int, (StructuredFunction, Instruction)> _functions =
      new Dictionary<int, (StructuredFunction, Instruction)>();
  /// <summary>
  /// Label bookkeeping for a single AST block. Labels are created on demand,
  /// cached by index, and an internal counter tracks which label is "next".
  /// </summary>
  private class BlockState
  {
      private readonly List<Instruction> _labels = new List<Instruction>();
      private int _entryCount;

      /// <summary>Returns the label at the current counter without advancing it.</summary>
      public Instruction GetNextLabel(CodeGenContext context) => GetLabel(context, _entryCount);

      /// <summary>Returns the label at the current counter, then advances the counter.</summary>
      public Instruction GetNextLabelAutoIncrement(CodeGenContext context)
      {
          int current = _entryCount;
          _entryCount = current + 1;

          return GetLabel(context, current);
      }

      /// <summary>
      /// Returns the label stored at <paramref name="index"/>, first creating every
      /// missing label up to and including that index.
      /// </summary>
      public Instruction GetLabel(CodeGenContext context, int index)
      {
          for (int count = _labels.Count; count <= index; count++)
          {
              _labels.Add(context.Label());
          }

          return _labels[index];
      }
  }
  // Label state for each AST block, created lazily by GetBlockStateLazy().
  private readonly Dictionary<AstBlock, BlockState> _labels =
      new Dictionary<AstBlock, BlockState>();

  /// <summary>
  /// Pair of instructions associated with each loop block. Has no initializer
  /// here, so it is null until external code assigns it.
  /// </summary>
  public Dictionary<AstBlock, (Instruction, Instruction)> LoopTargets { get; set; }

  /// <summary>Block most recently passed to <c>EnterBlock</c>.</summary>
  public AstBlock CurrentBlock { get; private set; }

  /// <summary>Delegate table created by the constructor for this context.</summary>
  public SpirvDelegates Delegates { get; }

  /// <summary>Value passed to the latest <c>StartFunction</c> call.</summary>
  public bool IsMainFunction { get; private set; }

  /// <summary>Reset to false by <c>StartFunction</c>; set by external code.</summary>
  public bool MayHaveReturned { get; set; }
  /// <summary>
  /// Creates the code generation context and sets up the module-level
  /// declarations (capabilities and memory model).
  /// </summary>
  /// <param name="info">Structured program information, exposed as <c>Info</c>.</param>
  /// <param name="config">Shader configuration, exposed as <c>Config</c>.</param>
  /// <param name="instPool">Instruction pool forwarded to the base module.</param>
  /// <param name="integerPool">Literal integer pool forwarded to the base module.</param>
  /// <exception cref="InvalidOperationException">
  /// The stage is geometry and the queried input topology is not a known value.
  /// </exception>
  public CodeGenContext(
      StructuredProgramInfo info,
      ShaderConfig config,
      GeneratorPool<Instruction> instPool,
      GeneratorPool<LiteralInteger> integerPool) : base(SpirvVersionPacked, instPool, integerPool)
  {
      Info = info;
      Config = config;

      if (config.Stage == ShaderStage.Geometry)
      {
          InputTopology inPrimitive = config.GpuAccessor.QueryPrimitiveTopology();

          // Number of vertices delivered to the geometry shader per input primitive.
          // Adjacency primitives also carry their adjacent vertices, so a line with
          // adjacency has 4 vertices and a triangle with adjacency has 6.
          InputVertices = inPrimitive switch
          {
              InputTopology.Points => 1,
              InputTopology.Lines => 2,
              InputTopology.LinesAdjacency => 4,
              InputTopology.Triangles => 3,
              InputTopology.TrianglesAdjacency => 6,
              _ => throw new InvalidOperationException($"Invalid input topology \"{inPrimitive}\".")
          };
      }

      AddCapability(Capability.Shader);
      AddCapability(Capability.Float64);

      SetMemoryModel(AddressingModel.Logical, MemoryModel.GLSL450);

      Delegates = new SpirvDelegates(this);
  }
  /// <summary>
  /// Resets all per-function state before a new function is generated.
  /// </summary>
  /// <param name="isMainFunction">Stored in <c>IsMainFunction</c>.</param>
  public void StartFunction(bool isMainFunction)
  {
      // Drop everything declared for the previous function.
      _funcArgs.Clear();
      _localForArgs.Clear();
      _locals.Clear();

      MayHaveReturned = false;
      IsMainFunction = isMainFunction;
  }
  /// <summary>
  /// Makes <paramref name="block"/> the current block and emits its next label,
  /// advancing the block's label counter.
  /// </summary>
  public void EnterBlock(AstBlock block)
  {
      CurrentBlock = block;

      BlockState state = GetBlockStateLazy(block);
      Instruction entryLabel = state.GetNextLabelAutoIncrement(this);

      AddLabel(entryLabel);
  }
  /// <summary>
  /// Returns the label at index 0 for <paramref name="block"/>, creating it if needed.
  /// </summary>
  public Instruction GetFirstLabel(AstBlock block) => GetBlockStateLazy(block).GetLabel(this, 0);
  /// <summary>
  /// Returns the label <paramref name="block"/> will use the next time it is
  /// entered, without advancing the block's label counter.
  /// </summary>
  public Instruction GetNextLabel(AstBlock block) => GetBlockStateLazy(block).GetNextLabel(this);
  /// <summary>
  /// Returns the label state for <paramref name="block"/>, creating and
  /// registering a fresh one the first time the block is seen.
  /// </summary>
  private BlockState GetBlockStateLazy(AstBlock block)
  {
      if (_labels.TryGetValue(block, out BlockState existing))
      {
          return existing;
      }

      BlockState created = new BlockState();
      _labels.Add(block, created);

      return created;
  }
  /// <summary>
  /// Creates a new label, emits a branch to it and then places the label.
  /// </summary>
  /// <returns>The label that was just placed.</returns>
  public Instruction NewBlock()
  {
      Instruction target = Label();

      // Branch first, then place the label that was branched to.
      Branch(target);
      AddLabel(target);

      return target;
  }
  /// <summary>
  /// Collects every declared I/O variable into one array, in this order:
  /// inputs, outputs, per-patch inputs, per-patch outputs.
  /// </summary>
  public Instruction[] GetMainInterface()
  {
      var groups = new[] { Inputs, Outputs, InputsPerPatch, OutputsPerPatch };

      int total = 0;

      foreach (var group in groups)
      {
          total += group.Count;
      }

      var variables = new List<Instruction>(total);

      foreach (var group in groups)
      {
          variables.AddRange(group.Values);
      }

      return variables.ToArray();
  }
  /// <summary>
  /// Registers the SPIR-V variable backing <paramref name="local"/>.
  /// </summary>
  /// <exception cref="System.ArgumentException">The local was already declared.</exception>
  public void DeclareLocal(AstOperand local, Instruction spvLocal) => _locals.Add(local, spvLocal);
  /// <summary>
  /// Registers the SPIR-V variables associated with function index
  /// <paramref name="funcIndex"/>.
  /// </summary>
  /// <exception cref="System.ArgumentException">The index was already registered.</exception>
  public void DeclareLocalForArgs(int funcIndex, Instruction[] spvLocals) => _localForArgs.Add(funcIndex, spvLocals);
  /// <summary>
  /// Registers the SPIR-V instruction for the argument at
  /// <paramref name="argIndex"/> of the current function.
  /// </summary>
  /// <exception cref="System.ArgumentException">The index was already registered.</exception>
  public void DeclareArgument(int argIndex, Instruction spvLocal) => _funcArgs.Add(argIndex, spvLocal);
  /// <summary>
  /// Registers a function and its SPIR-V counterpart under
  /// <paramref name="funcIndex"/>.
  /// </summary>
  /// <exception cref="System.ArgumentException">The index was already registered.</exception>
  public void DeclareFunction(int funcIndex, StructuredFunction function, Instruction spvFunc)
  {
      var entry = (function, spvFunc);

      _functions.Add(funcIndex, entry);
  }
  /// <summary>Evaluates <paramref name="node"/> as an <c>FP32</c> value.</summary>
  public Instruction GetFP32(IAstNode node) => Get(AggregateType.FP32, node);
  /// <summary>Evaluates <paramref name="node"/> as an <c>FP64</c> value.</summary>
  public Instruction GetFP64(IAstNode node) => Get(AggregateType.FP64, node);
  /// <summary>Evaluates <paramref name="node"/> as an <c>S32</c> value.</summary>
  public Instruction GetS32(IAstNode node) => Get(AggregateType.S32, node);
  /// <summary>Evaluates <paramref name="node"/> as a <c>U32</c> value.</summary>
  public Instruction GetU32(IAstNode node) => Get(AggregateType.U32, node);
  /// <summary>
  /// Produces the SPIR-V value for <paramref name="node"/>, converted to
  /// <paramref name="type"/>.
  /// </summary>
  /// <param name="type">Type the returned value must have.</param>
  /// <param name="node">Operation or operand to evaluate.</param>
  /// <exception cref="ArgumentException">The operand has an unsupported operand type.</exception>
  /// <exception cref="NotImplementedException">The node is neither an operation nor an operand.</exception>
  public Instruction Get(AggregateType type, IAstNode node)
  {
      switch (node)
      {
          case AstOperation operation:
          {
              // Generate the operation, then convert its natural result type.
              var generated = Instructions.Generate(this, operation);

              return BitcastIfNeeded(type, generated.Type, generated.Value);
          }

          case AstOperand operand:
              switch (operand.Type)
              {
                  case IrOperandType.Argument:
                      return GetArgument(type, operand);
                  case IrOperandType.Constant:
                      return GetConstant(type, operand);
                  case IrOperandType.LocalVariable:
                      return GetLocal(type, operand);
                  case IrOperandType.Undefined:
                      return GetUndefined(type);
                  default:
                      throw new ArgumentException($"Invalid operand type \"{operand.Type}\".");
              }

          default:
              throw new NotImplementedException(node.GetType().Name);
      }
  }
  /// <summary>
  /// Produces the SPIR-V value for <paramref name="node"/> in its own natural
  /// type and reports that type through <paramref name="type"/>.
  /// </summary>
  /// <param name="node">Operation or local variable operand to evaluate.</param>
  /// <param name="type">Receives the type of the returned value.</param>
  /// <exception cref="ArgumentException">The operand is not a local variable.</exception>
  /// <exception cref="NotImplementedException">The node is neither an operation nor an operand.</exception>
  public Instruction GetWithType(IAstNode node, out AggregateType type)
  {
      switch (node)
      {
          case AstOperation operation:
          {
              var generated = Instructions.Generate(this, operation);

              type = generated.Type;
              return generated.Value;
          }

          case AstOperand operand:
          {
              // Local variables are the only operand kind supported here.
              if (operand.Type != IrOperandType.LocalVariable)
              {
                  throw new ArgumentException($"Invalid operand type \"{operand.Type}\".");
              }

              type = operand.VarType;
              return GetLocal(type, operand);
          }

          default:
              throw new NotImplementedException(node.GetType().Name);
      }
  }
  /// <summary>
  /// Returns the value used for undefined operands: a false/zero constant of
  /// the requested type.
  /// </summary>
  private Instruction GetUndefined(AggregateType type)
  {
      switch (type)
      {
          case AggregateType.Bool:
              return ConstantFalse(TypeBool());
          case AggregateType.FP32:
              return Constant(TypeFP32(), 0f);
          case AggregateType.FP64:
              return Constant(TypeFP64(), 0d);
          default:
              // Any remaining type gets an integer zero of that type.
              return Constant(GetType(type), 0);
      }
  }
  /// <summary>
  /// Builds a SPIR-V constant of the requested type from the 32-bit value
  /// stored in <paramref name="operand"/>.
  /// </summary>
  /// <param name="type">Scalar type of the constant to create.</param>
  /// <param name="operand">Constant operand whose <c>Value</c> holds the raw bits.</param>
  /// <exception cref="ArgumentException"><paramref name="type"/> is not a supported scalar type.</exception>
  public Instruction GetConstant(AggregateType type, AstOperand operand)
  {
      switch (type)
      {
          case AggregateType.Bool:
              // Any non-zero value counts as true.
              if (operand.Value != 0)
              {
                  return ConstantTrue(TypeBool());
              }

              return ConstantFalse(TypeBool());

          case AggregateType.FP32:
              return Constant(TypeFP32(), BitConverter.Int32BitsToSingle(operand.Value));

          case AggregateType.FP64:
              // The bits are read as a 32-bit float, then widened to double.
              return Constant(TypeFP64(), (double)BitConverter.Int32BitsToSingle(operand.Value));

          case AggregateType.S32:
              return Constant(TypeS32(), operand.Value);

          case AggregateType.U32:
              return Constant(TypeU32(), (uint)operand.Value);

          default:
              throw new ArgumentException($"Invalid type \"{type}\".");
      }
  }
  /// <summary>
  /// Returns the SPIR-V instruction registered for <paramref name="local"/> by
  /// <c>DeclareLocal</c>.
  /// </summary>
  /// <exception cref="KeyNotFoundException">The local was never declared.</exception>
  public Instruction GetLocalPointer(AstOperand local) => _locals[local];
  /// <summary>
  /// Returns the instructions registered for <paramref name="funcIndex"/> by
  /// <c>DeclareLocalForArgs</c>.
  /// </summary>
  /// <exception cref="KeyNotFoundException">Nothing was registered for the index.</exception>
  public Instruction[] GetLocalForArgsPointers(int funcIndex) => _localForArgs[funcIndex];
  /// <summary>
  /// Returns the instruction registered by <c>DeclareArgument</c> for the
  /// argument whose index is stored in <c>funcArg.Value</c>.
  /// </summary>
  /// <exception cref="KeyNotFoundException">The argument was never declared.</exception>
  public Instruction GetArgumentPointer(AstOperand funcArg)
  {
      int argIndex = funcArg.Value;

      return _funcArgs[argIndex];
  }
  /// <summary>
  /// Loads a local variable using its declared type and converts the loaded
  /// value to <paramref name="dstType"/> when the two types differ.
  /// </summary>
  public Instruction GetLocal(AggregateType dstType, AstOperand local)
  {
      AggregateType declaredType = local.VarType;

      // Resolve the type before the pointer lookup, then load.
      Instruction spvType = GetType(declaredType);
      Instruction pointer = GetLocalPointer(local);
      Instruction loaded = Load(spvType, pointer);

      return BitcastIfNeeded(dstType, declaredType, loaded);
  }
  /// <summary>
  /// Loads a function argument using its declared type and converts the loaded
  /// value to <paramref name="dstType"/> when the two types differ.
  /// </summary>
  public Instruction GetArgument(AggregateType dstType, AstOperand funcArg)
  {
      AggregateType declaredType = funcArg.VarType;

      // Resolve the type before the pointer lookup, then load.
      Instruction spvType = GetType(declaredType);
      Instruction pointer = GetArgumentPointer(funcArg);
      Instruction loaded = Load(spvType, pointer);

      return BitcastIfNeeded(dstType, declaredType, loaded);
  }
  /// <summary>
  /// Returns the (structured function, SPIR-V function) pair registered for
  /// <paramref name="funcIndex"/> by <c>DeclareFunction</c>.
  /// </summary>
  /// <exception cref="KeyNotFoundException">No function was registered for the index.</exception>
  public (StructuredFunction, Instruction) GetFunction(int funcIndex) => _functions[funcIndex];
  /// <summary>
  /// Maps an <c>AggregateType</c> to the matching SPIR-V type instruction.
  /// </summary>
  /// <param name="type">Scalar, vector or array type to translate.</param>
  /// <param name="length">
  /// Element count, used for array types only. A value of zero or less yields a
  /// runtime array.
  /// </param>
  /// <exception cref="ArgumentException">The scalar type is not supported.</exception>
  public Instruction GetType(AggregateType type, int length = 1)
  {
      // Arrays: translate the element type first, then wrap it.
      if ((type & AggregateType.Array) != 0)
      {
          Instruction arrayElement = GetType(type & ~AggregateType.Array);

          if (length > 0)
          {
              return TypeArray(arrayElement, Constant(TypeU32(), length));
          }

          return TypeRuntimeArray(arrayElement);
      }

      // Vectors: the element-count bits select the number of components.
      AggregateType countBits = type & AggregateType.ElementCountMask;

      if (countBits != 0)
      {
          int vectorLength;

          switch (countBits)
          {
              case AggregateType.Vector2:
                  vectorLength = 2;
                  break;
              case AggregateType.Vector3:
                  vectorLength = 3;
                  break;
              case AggregateType.Vector4:
                  vectorLength = 4;
                  break;
              default:
                  vectorLength = 1;
                  break;
          }

          return TypeVector(GetType(type & ~AggregateType.ElementCountMask), vectorLength);
      }

      // Plain scalars.
      switch (type)
      {
          case AggregateType.Void:
              return TypeVoid();
          case AggregateType.Bool:
              return TypeBool();
          case AggregateType.FP32:
              return TypeFP32();
          case AggregateType.FP64:
              return TypeFP64();
          case AggregateType.S32:
              return TypeS32();
          case AggregateType.U32:
              return TypeU32();
          default:
              throw new ArgumentException($"Invalid attribute type \"{type}\".");
      }
  }
  /// <summary>
  /// Converts <paramref name="value"/> from <paramref name="srcType"/> to
  /// <paramref name="dstType"/>, returning it untouched when the types match.
  /// </summary>
  /// <remarks>
  /// Booleans are not bit-cast directly. A conversion to bool compares the
  /// value, viewed as S32, against zero. A conversion from bool selects between
  /// the IR true/false integer constants and then converts that S32 onwards.
  /// </remarks>
  public Instruction BitcastIfNeeded(AggregateType dstType, AggregateType srcType, Instruction value)
  {
      if (dstType == srcType)
      {
          return value;
      }

      if (dstType == AggregateType.Bool)
      {
          Instruction boolType = TypeBool();
          Instruction asInt = BitcastIfNeeded(AggregateType.S32, srcType, value);
          Instruction zero = Constant(TypeS32(), 0);

          return INotEqual(boolType, asInt, zero);
      }

      if (srcType == AggregateType.Bool)
      {
          Instruction whenTrue = Constant(TypeS32(), IrConsts.True);
          Instruction whenFalse = Constant(TypeS32(), IrConsts.False);
          Instruction asInt = Select(TypeS32(), value, whenTrue, whenFalse);

          return BitcastIfNeeded(dstType, AggregateType.S32, asInt);
      }

      return Bitcast(GetType(dstType), value);
  }
  /// <summary>Returns the 32-bit signed integer type.</summary>
  public Instruction TypeS32() => TypeInt(32, true);
  /// <summary>Returns the 32-bit unsigned integer type.</summary>
  public Instruction TypeU32() => TypeInt(32, false);
  /// <summary>Returns the 32-bit floating-point type.</summary>
  public Instruction TypeFP32() => TypeFloat(32);
  /// <summary>Returns the 64-bit floating-point type.</summary>
  public Instruction TypeFP64() => TypeFloat(64);
  328. }
  329. }