SpirvGenerator.cs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. using Ryujinx.Common;
  2. using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  3. using Ryujinx.Graphics.Shader.StructuredIr;
  4. using Ryujinx.Graphics.Shader.Translation;
  5. using System;
  6. using System.Collections.Generic;
  7. using static Spv.Specification;
  8. namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
  9. {
  10. using SpvInstruction = Spv.Generator.Instruction;
  11. using SpvLiteralInteger = Spv.Generator.LiteralInteger;
  12. using SpvInstructionPool = Spv.Generator.GeneratorPool<Spv.Generator.Instruction>;
  13. using SpvLiteralIntegerPool = Spv.Generator.GeneratorPool<Spv.Generator.LiteralInteger>;
  /// <summary>
  /// Translates a <see cref="StructuredProgramInfo"/> into a SPIR-V module binary.
  /// </summary>
  static class SpirvGenerator
  {
      // Resource pools for SPIR-V generation, shared by all Generate calls and guarded by PoolLock.
      // Note: Increase the count when more threads generate shaders concurrently.
      private const int GeneratorPoolCount = 1;

      private static readonly ObjectPool<SpvInstructionPool> InstructionPool = new(() => new SpvInstructionPool(), GeneratorPoolCount);
      private static readonly ObjectPool<SpvLiteralIntegerPool> IntegerPool = new(() => new SpvLiteralIntegerPool(), GeneratorPoolCount);
      private static readonly object PoolLock = new object();

      // Helper functions that require Declarations.DeclareInvocationId to be called
      // before any function body is generated.
      private const HelperFunctionsMask NeedsInvocationIdMask =
          HelperFunctionsMask.Shuffle |
          HelperFunctionsMask.ShuffleDown |
          HelperFunctionsMask.ShuffleUp |
          HelperFunctionsMask.ShuffleXor |
          HelperFunctionsMask.SwizzleAdd;

      /// <summary>
      /// Generates a SPIR-V module for the given structured program.
      /// </summary>
      /// <param name="info">Structured program to translate; function 0 becomes the "main" entry point</param>
      /// <param name="config">Shader configuration (stage, host capability queries, target options)</param>
      /// <returns>The bytes produced by the generation context</returns>
      /// <exception cref="InvalidOperationException">Thrown for an invalid geometry input or output topology</exception>
      /// <exception cref="NotImplementedException">Thrown for an assignment to an unsupported destination operand type</exception>
      public static byte[] Generate(StructuredProgramInfo info, ShaderConfig config)
      {
          SpvInstructionPool instPool;
          SpvLiteralIntegerPool integerPool;

          lock (PoolLock)
          {
              instPool = InstructionPool.Allocate();
              integerPool = IntegerPool.Allocate();
          }

          try
          {
              CodeGenContext context = new CodeGenContext(info, config, instPool, integerPool);

              AddCapabilitiesAndExtensions(context, config);

              Declarations.DeclareAll(context, info);

              if ((info.HelperFunctionsMask & NeedsInvocationIdMask) != 0)
              {
                  Declarations.DeclareInvocationId(context);
              }

              // All functions are declared up front so bodies can reference any of them.
              DeclareFunctions(context, info);

              for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
              {
                  Generate(context, info, funcIndex);
              }

              return context.Generate();
          }
          finally
          {
              // Return the pools even when generation throws, so a failed translation
              // does not drop them from the object pools.
              // NOTE(review): on failure the pools are returned without context.Generate()
              // having run on them; confirm GeneratorPool tolerates reuse in that state.
              lock (PoolLock)
              {
                  InstructionPool.Release(instPool);
                  IntegerPool.Release(integerPool);
              }
          }
      }

      // Adds the capabilities and extensions used by the module: a fixed base set,
      // plus those depending on transform feedback, the shader stage and host support.
      private static void AddCapabilitiesAndExtensions(CodeGenContext context, ShaderConfig config)
      {
          context.AddCapability(Capability.GroupNonUniformBallot);
          context.AddCapability(Capability.ImageBuffer);
          context.AddCapability(Capability.ImageGatherExtended);
          context.AddCapability(Capability.ImageQuery);
          context.AddCapability(Capability.SampledBuffer);
          context.AddCapability(Capability.SubgroupBallotKHR);
          context.AddCapability(Capability.SubgroupVoteKHR);

          if (config.TransformFeedbackEnabled && config.LastInVertexPipeline)
          {
              context.AddCapability(Capability.TransformFeedback);
          }

          switch (config.Stage)
          {
              case ShaderStage.Fragment:
                  // Fragment shaders that read the layer input need the Geometry capability.
                  if (context.Info.Inputs.Contains(AttributeConsts.Layer))
                  {
                      context.AddCapability(Capability.Geometry);
                  }

                  if (context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock())
                  {
                      context.AddCapability(Capability.FragmentShaderPixelInterlockEXT);
                      context.AddExtension("SPV_EXT_fragment_shader_interlock");
                  }
                  break;

              case ShaderStage.Geometry:
                  context.AddCapability(Capability.Geometry);

                  if (config.GpPassthrough && context.Config.GpuAccessor.QueryHostSupportsGeometryShaderPassthrough())
                  {
                      context.AddExtension("SPV_NV_geometry_shader_passthrough");
                      context.AddCapability(Capability.GeometryShaderPassthroughNV);
                  }
                  break;

              case ShaderStage.TessellationControl:
              case ShaderStage.TessellationEvaluation:
                  context.AddCapability(Capability.Tessellation);
                  break;

              case ShaderStage.Vertex:
                  context.AddCapability(Capability.DrawParameters);
                  break;
          }

          context.AddExtension("SPV_KHR_shader_ballot");
          context.AddExtension("SPV_KHR_subgroup_vote");
      }

      // Declares the SPIR-V function type and function object for every structured function.
      // Every argument (in and out) is passed as a Function-storage pointer.
      private static void DeclareFunctions(CodeGenContext context, StructuredProgramInfo info)
      {
          for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
          {
              var function = info.Functions[funcIndex];
              var retType = context.GetType(function.ReturnType.Convert());

              var funcArgs = new SpvInstruction[function.InArguments.Length + function.OutArguments.Length];

              for (int argIndex = 0; argIndex < funcArgs.Length; argIndex++)
              {
                  var argType = context.GetType(function.GetArgumentType(argIndex).Convert());

                  funcArgs[argIndex] = context.TypePointer(StorageClass.Function, argType);
              }

              var funcType = context.TypeFunction(retType, false, funcArgs);
              var spvFunc = context.Function(retType, FunctionControlMask.MaskNone, funcType);

              context.DeclareFunction(funcIndex, function, spvFunc);
          }
      }

      // Emits the body of a previously declared function. Function 0 is additionally
      // registered as the entry point.
      private static void Generate(CodeGenContext context, StructuredProgramInfo info, int funcIndex)
      {
          var function = info.Functions[funcIndex];

          (_, var spvFunc) = context.GetFunction(funcIndex);

          context.AddFunction(spvFunc);
          context.StartFunction();

          Declarations.DeclareParameters(context, function);

          context.EnterBlock(function.MainBlock);

          Declarations.DeclareLocals(context, function);
          Declarations.DeclareLocalForArgs(context, info.Functions);

          Generate(context, function.MainBlock);

          // Functions must always end with a return, unless the body already ends in one (or a discard).
          if (function.MainBlock.Last is not AstOperation operation ||
              (operation.Inst != Instruction.Return && operation.Inst != Instruction.Discard))
          {
              context.Return();
          }

          context.FunctionEnd();

          if (funcIndex == 0)
          {
              AddEntryPoint(context, info, spvFunc);
          }
      }

      // Registers spvFunc as the "main" entry point and adds the execution modes for the shader stage.
      private static void AddEntryPoint(CodeGenContext context, StructuredProgramInfo info, SpvInstruction spvFunc)
      {
          context.AddEntryPoint(context.Config.Stage.Convert(), spvFunc, "main", context.GetMainInterface());

          switch (context.Config.Stage)
          {
              case ShaderStage.TessellationControl:
                  context.AddExecutionMode(spvFunc, ExecutionMode.OutputVertices, (SpvLiteralInteger)context.Config.ThreadsPerInputPrimitive);
                  break;

              case ShaderStage.TessellationEvaluation:
                  AddTessellationEvaluationExecutionModes(context, spvFunc);
                  break;

              case ShaderStage.Geometry:
                  AddGeometryExecutionModes(context, spvFunc);
                  break;

              case ShaderStage.Fragment:
                  AddFragmentExecutionModes(context, info, spvFunc);
                  break;

              case ShaderStage.Compute:
                  AddComputeExecutionModes(context, spvFunc);
                  break;
          }

          if (context.Config.TransformFeedbackEnabled && context.Config.LastInVertexPipeline)
          {
              context.AddExecutionMode(spvFunc, ExecutionMode.Xfb);
          }
      }

      // Adds the patch type, spacing and winding order modes queried from the GPU accessor.
      private static void AddTessellationEvaluationExecutionModes(CodeGenContext context, SpvInstruction spvFunc)
      {
          switch (context.Config.GpuAccessor.QueryTessPatchType())
          {
              case TessPatchType.Isolines:
                  context.AddExecutionMode(spvFunc, ExecutionMode.Isolines);
                  break;
              case TessPatchType.Triangles:
                  context.AddExecutionMode(spvFunc, ExecutionMode.Triangles);
                  break;
              case TessPatchType.Quads:
                  context.AddExecutionMode(spvFunc, ExecutionMode.Quads);
                  break;
          }

          switch (context.Config.GpuAccessor.QueryTessSpacing())
          {
              case TessSpacing.EqualSpacing:
                  context.AddExecutionMode(spvFunc, ExecutionMode.SpacingEqual);
                  break;
              case TessSpacing.FractionalEventSpacing:
                  context.AddExecutionMode(spvFunc, ExecutionMode.SpacingFractionalEven);
                  break;
              case TessSpacing.FractionalOddSpacing:
                  context.AddExecutionMode(spvFunc, ExecutionMode.SpacingFractionalOdd);
                  break;
          }

          bool tessCw = context.Config.GpuAccessor.QueryTessCw();

          if (context.Config.Options.TargetApi == TargetApi.Vulkan)
          {
              // We invert the front face on the Vulkan backend, so the winding order must be inverted here as well.
              tessCw = !tessCw;
          }

          context.AddExecutionMode(spvFunc, tessCw ? ExecutionMode.VertexOrderCw : ExecutionMode.VertexOrderCcw);
      }

      // Adds the input primitive, invocation count, output primitive and max output vertex modes.
      private static void AddGeometryExecutionModes(CodeGenContext context, SpvInstruction spvFunc)
      {
          InputTopology inputTopology = context.Config.GpuAccessor.QueryPrimitiveTopology();

          // SPIR-V uses the Triangles execution mode for triangle input primitives.
          context.AddExecutionMode(spvFunc, inputTopology switch
          {
              InputTopology.Points => ExecutionMode.InputPoints,
              InputTopology.Lines => ExecutionMode.InputLines,
              InputTopology.LinesAdjacency => ExecutionMode.InputLinesAdjacency,
              InputTopology.Triangles => ExecutionMode.Triangles,
              InputTopology.TrianglesAdjacency => ExecutionMode.InputTrianglesAdjacency,
              _ => throw new InvalidOperationException($"Invalid input topology \"{inputTopology}\".")
          });

          context.AddExecutionMode(spvFunc, ExecutionMode.Invocations, (SpvLiteralInteger)context.Config.ThreadsPerInputPrimitive);

          context.AddExecutionMode(spvFunc, context.Config.OutputTopology switch
          {
              OutputTopology.PointList => ExecutionMode.OutputPoints,
              OutputTopology.LineStrip => ExecutionMode.OutputLineStrip,
              OutputTopology.TriangleStrip => ExecutionMode.OutputTriangleStrip,
              _ => throw new InvalidOperationException($"Invalid output topology \"{context.Config.OutputTopology}\".")
          });

          // With passthrough, the shader emits exactly as many vertices as it receives.
          int maxOutputVertices = context.Config.GpPassthrough ? context.InputVertices : context.Config.MaxOutputVertices;

          context.AddExecutionMode(spvFunc, ExecutionMode.OutputVertices, (SpvLiteralInteger)maxOutputVertices);
      }

      // Adds the framebuffer origin plus optional depth replacing, early fragment test and interlock modes.
      private static void AddFragmentExecutionModes(CodeGenContext context, StructuredProgramInfo info, SpvInstruction spvFunc)
      {
          context.AddExecutionMode(spvFunc, context.Config.Options.TargetApi == TargetApi.Vulkan
              ? ExecutionMode.OriginUpperLeft
              : ExecutionMode.OriginLowerLeft);

          if (context.Outputs.ContainsKey(AttributeConsts.FragmentOutputDepth))
          {
              context.AddExecutionMode(spvFunc, ExecutionMode.DepthReplacing);
          }

          if (context.Config.GpuAccessor.QueryEarlyZForce())
          {
              context.AddExecutionMode(spvFunc, ExecutionMode.EarlyFragmentTests);
          }

          // Interlock ordering is only requested when the program uses it and the host supports it.
          if ((info.HelperFunctionsMask & HelperFunctionsMask.FSI) != 0 &&
              context.Config.GpuAccessor.QueryHostSupportsFragmentShaderInterlock())
          {
              context.AddExecutionMode(spvFunc, ExecutionMode.PixelInterlockOrderedEXT);
          }
      }

      // Adds the workgroup local size queried from the GPU accessor.
      private static void AddComputeExecutionModes(CodeGenContext context, SpvInstruction spvFunc)
      {
          var localSizeX = (SpvLiteralInteger)context.Config.GpuAccessor.QueryComputeLocalSizeX();
          var localSizeY = (SpvLiteralInteger)context.Config.GpuAccessor.QueryComputeLocalSizeY();
          var localSizeZ = (SpvLiteralInteger)context.Config.GpuAccessor.QueryComputeLocalSizeZ();

          context.AddExecutionMode(
              spvFunc,
              ExecutionMode.LocalSize,
              localSizeX,
              localSizeY,
              localSizeZ);
      }

      // Emits every node of a block tree, along with the selection/loop merges and branches
      // required when entering and leaving nested blocks.
      private static void Generate(CodeGenContext context, AstBlock block)
      {
          AstBlockVisitor visitor = new AstBlockVisitor(block);

          // For each DoWhile block: (label the loop jumps back to, continue target label).
          var loopTargets = new Dictionary<AstBlock, (SpvInstruction, SpvInstruction)>();

          context.LoopTargets = loopTargets;

          visitor.BlockEntered += (sender, e) =>
          {
              AstBlock mergeBlock = e.Block.Parent;

              if (e.Block.Type == AstBlockType.If)
              {
                  AstBlock ifTrueBlock = e.Block;
                  AstBlock ifFalseBlock;

                  // Without an else block, the false path falls through to the merge block.
                  if (AstHelper.Next(e.Block) is AstBlock nextBlock && nextBlock.Type == AstBlockType.Else)
                  {
                      ifFalseBlock = nextBlock;
                  }
                  else
                  {
                      ifFalseBlock = mergeBlock;
                  }

                  var condition = context.Get(AggregateType.Bool, e.Block.Condition);

                  context.SelectionMerge(context.GetNextLabel(mergeBlock), SelectionControlMask.MaskNone);
                  context.BranchConditional(condition, context.GetNextLabel(ifTrueBlock), context.GetNextLabel(ifFalseBlock));
              }
              else if (e.Block.Type == AstBlockType.DoWhile)
              {
                  var continueTarget = context.Label();

                  loopTargets.Add(e.Block, (context.NewBlock(), continueTarget));

                  context.LoopMerge(context.GetNextLabel(mergeBlock), continueTarget, LoopControlMask.MaskNone);
                  context.Branch(context.GetFirstLabel(e.Block));
              }

              context.EnterBlock(e.Block);
          };

          visitor.BlockLeft += (sender, e) =>
          {
              // The root block has no parent to branch to or re-enter.
              if (e.Block.Parent == null)
              {
                  return;
              }

              if (e.Block.Type == AstBlockType.DoWhile)
              {
                  // This is a loop, we need to jump back to the loop header
                  // if the condition is true.
                  AstBlock mergeBlock = e.Block.Parent;

                  (var loopTarget, var continueTarget) = loopTargets[e.Block];

                  context.Branch(continueTarget);
                  context.AddLabel(continueTarget);

                  var condition = context.Get(AggregateType.Bool, e.Block.Condition);

                  context.BranchConditional(condition, loopTarget, context.GetNextLabel(mergeBlock));
              }
              else
              {
                  // We only need a branch if the last instruction didn't
                  // already cause the program to exit or jump elsewhere.
                  bool lastIsCf = e.Block.Last is AstOperation lastOp &&
                      (lastOp.Inst == Instruction.Discard ||
                       lastOp.Inst == Instruction.LoopBreak ||
                       lastOp.Inst == Instruction.LoopContinue ||
                       lastOp.Inst == Instruction.Return);

                  if (!lastIsCf)
                  {
                      context.Branch(context.GetNextLabel(e.Block.Parent));
                  }
              }

              bool hasElse = AstHelper.Next(e.Block) is AstBlock nextBlock &&
                  (nextBlock.Type == AstBlockType.Else ||
                   nextBlock.Type == AstBlockType.ElseIf);

              // Re-enter the parent block, unless an else/else-if sibling follows.
              if (!hasElse)
              {
                  context.EnterBlock(e.Block.Parent);
              }
          };

          foreach (IAstNode node in visitor.Visit())
          {
              if (node is AstAssignment assignment)
              {
                  GenerateAssignment(context, assignment);
              }
              else if (node is AstOperation operation)
              {
                  Instructions.Generate(context, operation);
              }
          }
      }

      // Stores an assignment's source into its destination: a local variable, an (optionally per-patch)
      // output attribute, or a function argument. Invalid output attributes are skipped silently.
      private static void GenerateAssignment(CodeGenContext context, AstAssignment assignment)
      {
          var dest = (AstOperand)assignment.Destination;

          switch (dest.Type)
          {
              case OperandType.LocalVariable:
              {
                  // The source is evaluated before the destination pointer is fetched.
                  var source = context.Get(dest.VarType.Convert(), assignment.Source);

                  context.Store(context.GetLocalPointer(dest), source);
                  break;
              }

              case OperandType.Attribute:
              case OperandType.AttributePerPatch:
              {
                  bool perPatch = dest.Type == OperandType.AttributePerPatch;

                  if (AttributeInfo.Validate(context.Config, dest.Value, isOutAttr: true, perPatch))
                  {
                      AggregateType elemType;

                      var elemPointer = perPatch
                          ? context.GetAttributePerPatchElemPointer(dest.Value, true, out elemType)
                          : context.GetAttributeElemPointer(dest.Value, true, null, out elemType);

                      context.Store(elemPointer, context.Get(elemType, assignment.Source));
                  }
                  break;
              }

              case OperandType.Argument:
              {
                  var source = context.Get(dest.VarType.Convert(), assignment.Source);

                  context.Store(context.GetArgumentPointer(dest), source);
                  break;
              }

              default:
                  throw new NotImplementedException(dest.Type.ToString());
          }
      }
  }
  354. }