SpirvGenerator.cs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. using Ryujinx.Common;
  2. using Ryujinx.Graphics.Shader.IntermediateRepresentation;
  3. using Ryujinx.Graphics.Shader.StructuredIr;
  4. using Ryujinx.Graphics.Shader.Translation;
  5. using System;
  6. using System.Collections.Generic;
  7. using static Spv.Specification;
  8. namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
  9. {
  10. using SpvInstruction = Spv.Generator.Instruction;
  11. using SpvInstructionPool = Spv.Generator.GeneratorPool<Spv.Generator.Instruction>;
  12. using SpvLiteralInteger = Spv.Generator.LiteralInteger;
  13. using SpvLiteralIntegerPool = Spv.Generator.GeneratorPool<Spv.Generator.LiteralInteger>;
  /// <summary>
  /// Translates a structured shader program into a SPIR-V binary module.
  /// </summary>
  static class SpirvGenerator
  {
      // Resource pools for Spirv generation. Note: Increase count when more threads are being used.
      private const int GeneratorPoolCount = 1;

      private static readonly ObjectPool<SpvInstructionPool> _instructionPool;
      private static readonly ObjectPool<SpvLiteralIntegerPool> _integerPool;
      private static readonly object _poolLock;

      static SpirvGenerator()
      {
          _instructionPool = new(() => new SpvInstructionPool(), GeneratorPoolCount);
          _integerPool = new(() => new SpvLiteralIntegerPool(), GeneratorPoolCount);
          _poolLock = new object();
      }

      // Helper functions in this mask require the SubgroupLaneId input to be declared.
      private const HelperFunctionsMask NeedsInvocationIdMask = HelperFunctionsMask.SwizzleAdd;

      /// <summary>
      /// Generates a SPIR-V module for the given structured program.
      /// </summary>
      /// <param name="info">
      /// Program to translate. Its IO definitions are extended with a SubgroupLaneId input
      /// when a helper function in <see cref="NeedsInvocationIdMask"/> is used.
      /// </param>
      /// <param name="parameters">Code generation parameters; their Definitions select stage-specific capabilities.</param>
      /// <returns>The bytes produced by the code generation context.</returns>
      public static byte[] Generate(StructuredProgramInfo info, CodeGenParameters parameters)
      {
          SpvInstructionPool instPool;
          SpvLiteralIntegerPool integerPool;

          lock (_poolLock)
          {
              instPool = _instructionPool.Allocate();
              integerPool = _integerPool.Allocate();
          }

          try
          {
              CodeGenContext context = new(info, parameters, instPool, integerPool);

              AddCapabilitiesAndExtensions(context, parameters);

              if ((info.HelperFunctionsMask & NeedsInvocationIdMask) != 0)
              {
                  info.IoDefinitions.Add(new IoDefinition(StorageKind.Input, IoVariable.SubgroupLaneId));
              }

              Declarations.DeclareAll(context, info);

              // All functions are declared before any body is emitted, so bodies can reference any function.
              DeclareFunctions(context, info);

              for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
              {
                  Generate(context, info, funcIndex);
              }

              return context.Generate();
          }
          finally
          {
              // Return the borrowed pools even when generation throws (for example on an invalid
              // topology), so a failed translation does not drop them from the shared pools.
              lock (_poolLock)
              {
                  _instructionPool.Release(instPool);
                  _integerPool.Release(integerPool);
              }
          }
      }

      /// <summary>
      /// Registers the capabilities and extensions the module needs, based on the shader stage,
      /// enabled features and the IO variables the program uses.
      /// </summary>
      private static void AddCapabilitiesAndExtensions(CodeGenContext context, CodeGenParameters parameters)
      {
          context.AddCapability(Capability.GroupNonUniformBallot);
          context.AddCapability(Capability.GroupNonUniformShuffle);
          context.AddCapability(Capability.GroupNonUniformVote);
          context.AddCapability(Capability.ImageBuffer);
          context.AddCapability(Capability.ImageGatherExtended);
          context.AddCapability(Capability.ImageQuery);
          context.AddCapability(Capability.SampledBuffer);

          if (parameters.Definitions.TransformFeedbackEnabled && parameters.Definitions.LastInVertexPipeline)
          {
              context.AddCapability(Capability.TransformFeedback);
          }

          if (parameters.Definitions.Stage == ShaderStage.Fragment)
          {
              // Reading the layer as a fragment input requires the Geometry capability.
              if (context.Info.IoDefinitions.Contains(new IoDefinition(StorageKind.Input, IoVariable.Layer)))
              {
                  context.AddCapability(Capability.Geometry);
              }

              if (context.HostCapabilities.SupportsFragmentShaderInterlock)
              {
                  context.AddCapability(Capability.FragmentShaderPixelInterlockEXT);
                  context.AddExtension("SPV_EXT_fragment_shader_interlock");
              }
          }
          else if (parameters.Definitions.Stage == ShaderStage.Geometry)
          {
              context.AddCapability(Capability.Geometry);

              if (parameters.Definitions.GpPassthrough && context.HostCapabilities.SupportsGeometryShaderPassthrough)
              {
                  context.AddExtension("SPV_NV_geometry_shader_passthrough");
                  context.AddCapability(Capability.GeometryShaderPassthroughNV);
              }
          }
          else if (parameters.Definitions.Stage == ShaderStage.TessellationControl ||
                   parameters.Definitions.Stage == ShaderStage.TessellationEvaluation)
          {
              context.AddCapability(Capability.Tessellation);
          }
          else if (parameters.Definitions.Stage == ShaderStage.Vertex)
          {
              context.AddCapability(Capability.DrawParameters);
          }

          if (context.Info.IoDefinitions.Contains(new IoDefinition(StorageKind.Output, IoVariable.ViewportMask)))
          {
              context.AddExtension("SPV_NV_viewport_array2");
              context.AddCapability(Capability.ShaderViewportMaskNV);
          }
      }

      /// <summary>
      /// Declares a SPIR-V function (with its function type) for every function in the program.
      /// Every argument, whether input or output, is passed as a pointer in the Function storage class.
      /// </summary>
      private static void DeclareFunctions(CodeGenContext context, StructuredProgramInfo info)
      {
          for (int funcIndex = 0; funcIndex < info.Functions.Count; funcIndex++)
          {
              var function = info.Functions[funcIndex];
              var retType = context.GetType(function.ReturnType);

              var funcArgs = new SpvInstruction[function.InArguments.Length + function.OutArguments.Length];

              for (int argIndex = 0; argIndex < funcArgs.Length; argIndex++)
              {
                  var argType = context.GetType(function.GetArgumentType(argIndex));
                  var argPointerType = context.TypePointer(StorageClass.Function, argType);
                  funcArgs[argIndex] = argPointerType;
              }

              var funcType = context.TypeFunction(retType, false, funcArgs);
              var spvFunc = context.Function(retType, FunctionControlMask.MaskNone, funcType);

              context.DeclareFunction(funcIndex, function, spvFunc);
          }
      }

      /// <summary>
      /// Emits the body of a previously declared function. For the function at index 0 (the main
      /// function), also emits the entry point and its execution modes.
      /// </summary>
      private static void Generate(CodeGenContext context, StructuredProgramInfo info, int funcIndex)
      {
          var (function, spvFunc) = context.GetFunction(funcIndex);

          context.CurrentFunction = function;
          context.AddFunction(spvFunc);
          context.StartFunction(isMainFunction: funcIndex == 0);

          Declarations.DeclareParameters(context, function);

          context.EnterBlock(function.MainBlock);

          Declarations.DeclareLocals(context, function);
          Declarations.DeclareLocalForArgs(context, info.Functions);

          Generate(context, function.MainBlock);

          // Functions must always end with a return.
          if (function.MainBlock.Last is not AstOperation operation ||
              (operation.Inst != Instruction.Return && operation.Inst != Instruction.Discard))
          {
              context.Return();
          }

          context.FunctionEnd();

          if (funcIndex == 0)
          {
              AddEntryPoint(context, info, spvFunc);
          }
      }

      /// <summary>
      /// Declares <paramref name="spvFunc"/> as the "main" entry point and adds the execution modes
      /// (and the ShaderLayer capability, when needed) for the current shader stage.
      /// </summary>
      private static void AddEntryPoint(CodeGenContext context, StructuredProgramInfo info, SpvInstruction spvFunc)
      {
          context.AddEntryPoint(context.Definitions.Stage.Convert(), spvFunc, "main", context.GetMainInterface());

          if (context.Definitions.Stage == ShaderStage.TessellationControl)
          {
              context.AddExecutionMode(spvFunc, ExecutionMode.OutputVertices, (SpvLiteralInteger)context.Definitions.ThreadsPerInputPrimitive);
          }
          else if (context.Definitions.Stage == ShaderStage.TessellationEvaluation)
          {
              AddTessellationEvaluationModes(context, spvFunc);
          }
          else if (context.Definitions.Stage == ShaderStage.Geometry)
          {
              context.AddExecutionMode(spvFunc, context.Definitions.InputTopology switch
              {
                  InputTopology.Points => ExecutionMode.InputPoints,
                  InputTopology.Lines => ExecutionMode.InputLines,
                  InputTopology.LinesAdjacency => ExecutionMode.InputLinesAdjacency,
                  InputTopology.Triangles => ExecutionMode.Triangles,
                  InputTopology.TrianglesAdjacency => ExecutionMode.InputTrianglesAdjacency,
                  _ => throw new InvalidOperationException($"Invalid input topology \"{context.Definitions.InputTopology}\"."),
              });

              context.AddExecutionMode(spvFunc, ExecutionMode.Invocations, (SpvLiteralInteger)context.Definitions.ThreadsPerInputPrimitive);

              context.AddExecutionMode(spvFunc, context.Definitions.OutputTopology switch
              {
                  OutputTopology.PointList => ExecutionMode.OutputPoints,
                  OutputTopology.LineStrip => ExecutionMode.OutputLineStrip,
                  OutputTopology.TriangleStrip => ExecutionMode.OutputTriangleStrip,
                  _ => throw new InvalidOperationException($"Invalid output topology \"{context.Definitions.OutputTopology}\"."),
              });

              context.AddExecutionMode(spvFunc, ExecutionMode.OutputVertices, (SpvLiteralInteger)context.Definitions.MaxOutputVertices);
          }
          else if (context.Definitions.Stage == ShaderStage.Fragment)
          {
              context.AddExecutionMode(spvFunc, context.Definitions.OriginUpperLeft
                  ? ExecutionMode.OriginUpperLeft
                  : ExecutionMode.OriginLowerLeft);

              if (context.Info.IoDefinitions.Contains(new IoDefinition(StorageKind.Output, IoVariable.FragmentOutputDepth)))
              {
                  context.AddExecutionMode(spvFunc, ExecutionMode.DepthReplacing);
              }

              if (context.Definitions.EarlyZForce)
              {
                  context.AddExecutionMode(spvFunc, ExecutionMode.EarlyFragmentTests);
              }

              if ((info.HelperFunctionsMask & HelperFunctionsMask.FSI) != 0 &&
                  context.HostCapabilities.SupportsFragmentShaderInterlock)
              {
                  context.AddExecutionMode(spvFunc, ExecutionMode.PixelInterlockOrderedEXT);
              }
          }
          else if (context.Definitions.Stage == ShaderStage.Compute)
          {
              var localSizeX = (SpvLiteralInteger)context.Definitions.ComputeLocalSizeX;
              var localSizeY = (SpvLiteralInteger)context.Definitions.ComputeLocalSizeY;
              var localSizeZ = (SpvLiteralInteger)context.Definitions.ComputeLocalSizeZ;

              context.AddExecutionMode(
                  spvFunc,
                  ExecutionMode.LocalSize,
                  localSizeX,
                  localSizeY,
                  localSizeZ);
          }

          // Writing the layer output from a stage other than fragment, geometry or compute
          // requires the ShaderLayer capability.
          if (context.Definitions.Stage != ShaderStage.Fragment &&
              context.Definitions.Stage != ShaderStage.Geometry &&
              context.Definitions.Stage != ShaderStage.Compute &&
              context.Info.IoDefinitions.Contains(new IoDefinition(StorageKind.Output, IoVariable.Layer)))
          {
              context.AddCapability(Capability.ShaderLayer);
          }

          if (context.Definitions.TransformFeedbackEnabled && context.Definitions.LastInVertexPipeline)
          {
              context.AddExecutionMode(spvFunc, ExecutionMode.Xfb);
          }
      }

      /// <summary>
      /// Adds the patch type, spacing and vertex winding execution modes of a tessellation evaluation shader.
      /// </summary>
      private static void AddTessellationEvaluationModes(CodeGenContext context, SpvInstruction spvFunc)
      {
          switch (context.Definitions.TessPatchType)
          {
              case TessPatchType.Isolines:
                  context.AddExecutionMode(spvFunc, ExecutionMode.Isolines);
                  break;
              case TessPatchType.Triangles:
                  context.AddExecutionMode(spvFunc, ExecutionMode.Triangles);
                  break;
              case TessPatchType.Quads:
                  context.AddExecutionMode(spvFunc, ExecutionMode.Quads);
                  break;
          }

          switch (context.Definitions.TessSpacing)
          {
              case TessSpacing.EqualSpacing:
                  context.AddExecutionMode(spvFunc, ExecutionMode.SpacingEqual);
                  break;
              case TessSpacing.FractionalEventSpacing:
                  context.AddExecutionMode(spvFunc, ExecutionMode.SpacingFractionalEven);
                  break;
              case TessSpacing.FractionalOddSpacing:
                  context.AddExecutionMode(spvFunc, ExecutionMode.SpacingFractionalOdd);
                  break;
          }

          bool tessCw = context.Definitions.TessCw;

          if (context.TargetApi == TargetApi.Vulkan)
          {
              // We invert the front face on Vulkan backend, so we need to do that here as well.
              tessCw = !tessCw;
          }

          context.AddExecutionMode(spvFunc, tessCw ? ExecutionMode.VertexOrderCw : ExecutionMode.VertexOrderCcw);
      }

      /// <summary>
      /// Walks the AST rooted at <paramref name="block"/>, emitting structured control flow
      /// (selection and loop merges with their branches) and the code for each assignment and operation.
      /// </summary>
      private static void Generate(CodeGenContext context, AstBlock block)
      {
          AstBlockVisitor visitor = new(block);

          // Maps each do-while block to (loop header block, continue target label).
          var loopTargets = new Dictionary<AstBlock, (SpvInstruction, SpvInstruction)>();

          context.LoopTargets = loopTargets;

          visitor.BlockEntered += (sender, e) =>
          {
              AstBlock mergeBlock = e.Block.Parent;

              if (e.Block.Type == AstBlockType.If)
              {
                  AstBlock ifTrueBlock = e.Block;
                  AstBlock ifFalseBlock;

                  // The false edge goes to a following Else block when there is one,
                  // otherwise straight to the merge point after this block.
                  if (AstHelper.Next(e.Block) is AstBlock nextBlock && nextBlock.Type == AstBlockType.Else)
                  {
                      ifFalseBlock = nextBlock;
                  }
                  else
                  {
                      ifFalseBlock = mergeBlock;
                  }

                  var condition = context.Get(AggregateType.Bool, e.Block.Condition);

                  context.SelectionMerge(context.GetNextLabel(mergeBlock), SelectionControlMask.MaskNone);
                  context.BranchConditional(condition, context.GetNextLabel(ifTrueBlock), context.GetNextLabel(ifFalseBlock));
              }
              else if (e.Block.Type == AstBlockType.DoWhile)
              {
                  var continueTarget = context.Label();

                  loopTargets.Add(e.Block, (context.NewBlock(), continueTarget));

                  context.LoopMerge(context.GetNextLabel(mergeBlock), continueTarget, LoopControlMask.MaskNone);
                  context.Branch(context.GetFirstLabel(e.Block));
              }

              context.EnterBlock(e.Block);
          };

          visitor.BlockLeft += (sender, e) =>
          {
              if (e.Block.Parent == null)
              {
                  return;
              }

              if (e.Block.Type == AstBlockType.DoWhile)
              {
                  // This is a loop, we need to jump back to the loop header
                  // if the condition is true.
                  AstBlock mergeBlock = e.Block.Parent;

                  var (loopTarget, continueTarget) = loopTargets[e.Block];

                  context.Branch(continueTarget);
                  context.AddLabel(continueTarget);

                  var condition = context.Get(AggregateType.Bool, e.Block.Condition);

                  context.BranchConditional(condition, loopTarget, context.GetNextLabel(mergeBlock));
              }
              else
              {
                  // We only need a branch if the last instruction didn't
                  // already cause the program to exit or jump elsewhere.
                  bool lastIsCf = e.Block.Last is AstOperation lastOp &&
                      (lastOp.Inst == Instruction.Discard ||
                       lastOp.Inst == Instruction.LoopBreak ||
                       lastOp.Inst == Instruction.LoopContinue ||
                       lastOp.Inst == Instruction.Return);

                  if (!lastIsCf)
                  {
                      context.Branch(context.GetNextLabel(e.Block.Parent));
                  }
              }

              bool hasElse = AstHelper.Next(e.Block) is AstBlock nextBlock &&
                  (nextBlock.Type == AstBlockType.Else ||
                   nextBlock.Type == AstBlockType.ElseIf);

              // Re-enter the parent block, unless an else branch follows and will be entered next.
              if (!hasElse)
              {
                  context.EnterBlock(e.Block.Parent);
              }
          };

          foreach (IAstNode node in visitor.Visit())
          {
              if (node is AstAssignment assignment)
              {
                  GenerateAssignment(context, assignment);
              }
              else if (node is AstOperation operation)
              {
                  Instructions.Generate(context, operation);
              }
          }
      }

      /// <summary>
      /// Stores the value of an assignment's source into its destination local variable or argument.
      /// </summary>
      /// <exception cref="NotImplementedException">The destination is neither a local variable nor an argument.</exception>
      private static void GenerateAssignment(CodeGenContext context, AstAssignment assignment)
      {
          var dest = (AstOperand)assignment.Destination;

          if (dest.Type == OperandType.LocalVariable)
          {
              var source = context.Get(dest.VarType, assignment.Source);
              context.Store(context.GetLocalPointer(dest), source);
          }
          else if (dest.Type == OperandType.Argument)
          {
              var source = context.Get(dest.VarType, assignment.Source);
              context.Store(context.GetArgumentPointer(dest), source);
          }
          else
          {
              throw new NotImplementedException(dest.Type.ToString());
          }
      }
  }
  346. }