Răsfoiți Sursa

Declare and use gl_PerVertex block for VTG per-vertex built-ins (#5576)

* Declare and use gl_PerVertex block for VTG per-vertex built-ins

* Shader cache version bump
gdkchan 2 ani în urmă
părinte
comite
17354d59d1

+ 1 - 1
src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs

@@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache
         private const ushort FileFormatVersionMajor = 1;
         private const ushort FileFormatVersionMinor = 2;
         private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor;
-        private const uint CodeGenVersion = 5529;
+        private const uint CodeGenVersion = 5576;
 
         private const string SharedTocFileName = "shared.toc";
         private const string SharedDataFileName = "shared.data";

+ 86 - 0
src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs

@@ -348,12 +348,98 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                         }
                     }
                 }
+                else if (IoMap.IsPerVertexBuiltIn(ioDefinition.IoVariable))
+                {
+                    continue;
+                }
 
                 bool isOutput = ioDefinition.StorageKind.IsOutput();
                 bool isPerPatch = ioDefinition.StorageKind.IsPerPatch();
 
                 DeclareInputOrOutput(context, ioDefinition, isOutput, isPerPatch, iq, firstLocation);
             }
+
+            DeclarePerVertexBlock(context);
+        }
+
+        private static void DeclarePerVertexBlock(CodeGenContext context)
+        {
+            if (context.Definitions.Stage.IsVtg())
+            {
+                if (context.Definitions.Stage != ShaderStage.Vertex)
+                {
+                    var perVertexInputStructType = CreatePerVertexStructType(context);
+                    int arraySize = context.Definitions.Stage == ShaderStage.Geometry ? context.InputVertices : 32;
+                    var perVertexInputArrayType = context.TypeArray(perVertexInputStructType, context.Constant(context.TypeU32(), arraySize));
+                    var perVertexInputPointerType = context.TypePointer(StorageClass.Input, perVertexInputArrayType);
+                    var perVertexInputVariable = context.Variable(perVertexInputPointerType, StorageClass.Input);
+
+                    context.Name(perVertexInputVariable, "gl_in");
+
+                    context.AddGlobalVariable(perVertexInputVariable);
+                    context.Inputs.Add(new IoDefinition(StorageKind.Input, IoVariable.Position), perVertexInputVariable);
+                }
+
+                var perVertexOutputStructType = CreatePerVertexStructType(context);
+
+                void DecorateTfo(IoVariable ioVariable, int fieldIndex)
+                {
+                    if (context.Definitions.TryGetTransformFeedbackOutput(ioVariable, 0, 0, out var transformFeedbackOutput))
+                    {
+                        context.MemberDecorate(perVertexOutputStructType, fieldIndex, Decoration.XfbBuffer, (LiteralInteger)transformFeedbackOutput.Buffer);
+                        context.MemberDecorate(perVertexOutputStructType, fieldIndex, Decoration.XfbStride, (LiteralInteger)transformFeedbackOutput.Stride);
+                        context.MemberDecorate(perVertexOutputStructType, fieldIndex, Decoration.Offset, (LiteralInteger)transformFeedbackOutput.Offset);
+                    }
+                }
+
+                DecorateTfo(IoVariable.Position, 0);
+                DecorateTfo(IoVariable.PointSize, 1);
+                DecorateTfo(IoVariable.ClipDistance, 2);
+
+                SpvInstruction perVertexOutputArrayType;
+
+                if (context.Definitions.Stage == ShaderStage.TessellationControl)
+                {
+                    int arraySize = context.Definitions.ThreadsPerInputPrimitive;
+                    perVertexOutputArrayType = context.TypeArray(perVertexOutputStructType, context.Constant(context.TypeU32(), arraySize));
+                }
+                else
+                {
+                    perVertexOutputArrayType = perVertexOutputStructType;
+                }
+
+                var perVertexOutputPointerType = context.TypePointer(StorageClass.Output, perVertexOutputArrayType);
+                var perVertexOutputVariable = context.Variable(perVertexOutputPointerType, StorageClass.Output);
+
+                context.AddGlobalVariable(perVertexOutputVariable);
+                context.Outputs.Add(new IoDefinition(StorageKind.Output, IoVariable.Position), perVertexOutputVariable);
+            }
+        }
+
+        private static SpvInstruction CreatePerVertexStructType(CodeGenContext context)
+        {
+            var vec4FloatType = context.TypeVector(context.TypeFP32(), 4);
+            var floatType = context.TypeFP32();
+            var array8FloatType = context.TypeArray(context.TypeFP32(), context.Constant(context.TypeU32(), 8));
+            var array1FloatType = context.TypeArray(context.TypeFP32(), context.Constant(context.TypeU32(), 1));
+
+            var perVertexStructType = context.TypeStruct(true, vec4FloatType, floatType, array8FloatType, array1FloatType);
+
+            context.Name(perVertexStructType, "gl_PerVertex");
+
+            context.MemberName(perVertexStructType, 0, "gl_Position");
+            context.MemberName(perVertexStructType, 1, "gl_PointSize");
+            context.MemberName(perVertexStructType, 2, "gl_ClipDistance");
+            context.MemberName(perVertexStructType, 3, "gl_CullDistance");
+
+            context.Decorate(perVertexStructType, Decoration.Block);
+
+            context.MemberDecorate(perVertexStructType, 0, Decoration.BuiltIn, (LiteralInteger)BuiltIn.Position);
+            context.MemberDecorate(perVertexStructType, 1, Decoration.BuiltIn, (LiteralInteger)BuiltIn.PointSize);
+            context.MemberDecorate(perVertexStructType, 2, Decoration.BuiltIn, (LiteralInteger)BuiltIn.ClipDistance);
+            context.MemberDecorate(perVertexStructType, 3, Decoration.BuiltIn, (LiteralInteger)BuiltIn.CullDistance);
+
+            return perVertexStructType;
         }
 
         private static void DeclareInputOrOutput(

+ 32 - 0
src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs

@@ -1788,6 +1788,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             StorageClass storageClass;
             SpvInstruction baseObj;
             int srcIndex = 0;
+            IoVariable? perVertexBuiltIn = null;
 
             switch (storageKind)
             {
@@ -1881,6 +1882,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
                     else
                     {
                         (_, varType) = IoMap.GetSpirvBuiltIn(ioVariable);
+
+                        if (IoMap.IsPerVertexBuiltIn(ioVariable))
+                        {
+                            perVertexBuiltIn = ioVariable;
+                            ioVariable = IoVariable.Position;
+                        }
                     }
 
                     varType &= AggregateType.ElementTypeMask;
@@ -1902,6 +1909,31 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
             bool isStoreOrAtomic = operation.Inst == Instruction.Store || operation.Inst.IsAtomic();
             int inputsCount = (isStoreOrAtomic ? operation.SourcesCount - 1 : operation.SourcesCount) - srcIndex;
 
+            if (perVertexBuiltIn.HasValue)
+            {
+                int fieldIndex = IoMap.GetPerVertexStructFieldIndex(perVertexBuiltIn.Value);
+
+                var indexes = new SpvInstruction[inputsCount + 1];
+                int index = 0;
+
+                if (IoMap.IsPerVertexArrayBuiltIn(storageKind, context.Definitions.Stage))
+                {
+                    indexes[index++] = context.Get(AggregateType.S32, operation.GetSource(srcIndex++));
+                    indexes[index++] = context.Constant(context.TypeS32(), fieldIndex);
+                }
+                else
+                {
+                    indexes[index++] = context.Constant(context.TypeS32(), fieldIndex);
+                }
+
+                for (; index < inputsCount + 1; srcIndex++, index++)
+                {
+                    indexes[index] = context.Get(AggregateType.S32, operation.GetSource(srcIndex));
+                }
+
+                return context.AccessChain(context.TypePointer(storageClass, context.GetType(varType)), baseObj, indexes);
+            }
+
             if (operation.Inst == Instruction.AtomicCompareAndSwap)
             {
                 inputsCount--;

+ 39 - 0
src/Ryujinx.Graphics.Shader/CodeGen/Spirv/IoMap.cs

@@ -1,5 +1,6 @@
 using Ryujinx.Graphics.Shader.IntermediateRepresentation;
 using Ryujinx.Graphics.Shader.Translation;
+using System;
 using static Spv.Specification;
 
 namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
@@ -80,5 +81,43 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
 
             return false;
         }
+
+        public static bool IsPerVertexBuiltIn(IoVariable ioVariable)
+        {
+            switch (ioVariable)
+            {
+                case IoVariable.Position:
+                case IoVariable.PointSize:
+                case IoVariable.ClipDistance:
+                    return true;
+            }
+
+            return false;
+        }
+
+        public static bool IsPerVertexArrayBuiltIn(StorageKind storageKind, ShaderStage stage)
+        {
+            if (storageKind == StorageKind.Output)
+            {
+                return stage == ShaderStage.TessellationControl;
+            }
+            else
+            {
+                return stage == ShaderStage.TessellationControl ||
+                       stage == ShaderStage.TessellationEvaluation ||
+                       stage == ShaderStage.Geometry;
+            }
+        }
+
+        public static int GetPerVertexStructFieldIndex(IoVariable ioVariable)
+        {
+            return ioVariable switch
+            {
+                IoVariable.Position => 0,
+                IoVariable.PointSize => 1,
+                IoVariable.ClipDistance => 2,
+                _ => throw new ArgumentException($"Invalid built-in variable {ioVariable}.")
+            };
+        }
     }
 }

+ 13 - 0
src/Ryujinx.Graphics.Shader/ShaderStage.cs

@@ -23,5 +23,18 @@ namespace Ryujinx.Graphics.Shader
         {
             return stage == ShaderStage.Vertex || stage == ShaderStage.Fragment || stage == ShaderStage.Compute;
         }
+
+        /// <summary>
+        /// Checks if the shader stage is vertex, tessellation or geometry.
+        /// </summary>
+        /// <param name="stage">Shader stage</param>
+        /// <returns>True if the shader stage is vertex, tessellation or geometry, false otherwise</returns>
+        public static bool IsVtg(this ShaderStage stage)
+        {
+            return stage == ShaderStage.Vertex ||
+                   stage == ShaderStage.TessellationControl ||
+                   stage == ShaderStage.TessellationEvaluation ||
+                   stage == ShaderStage.Geometry;
+        }
     }
 }

+ 5 - 2
src/Spv.Generator/Module.cs

@@ -28,6 +28,7 @@ namespace Spv.Generator
 
         // In the declaration block.
         private readonly Dictionary<TypeDeclarationKey, Instruction> _typeDeclarations;
+        private readonly List<Instruction> _typeDeclarationsList;
         // In the declaration block.
         private readonly List<Instruction> _globals;
         // In the declaration block.
@@ -54,6 +55,7 @@ namespace Spv.Generator
             _debug = new List<Instruction>();
             _annotations = new List<Instruction>();
             _typeDeclarations = new Dictionary<TypeDeclarationKey, Instruction>();
+            _typeDeclarationsList = new List<Instruction>();
             _constants = new Dictionary<ConstantKey, Instruction>();
             _globals = new List<Instruction>();
             _functionsDeclarations = new List<Instruction>();
@@ -126,7 +128,8 @@ namespace Spv.Generator
 
             instruction.SetId(GetNewId());
 
-            _typeDeclarations.Add(key, instruction);
+            _typeDeclarations[key] = instruction;
+            _typeDeclarationsList.Add(instruction);
         }
 
         public void AddEntryPoint(ExecutionModel executionModel, Instruction function, string name, params Instruction[] interfaces)
@@ -330,7 +333,7 @@ namespace Spv.Generator
 
             // Ensure that everything is in the right order in the declarations section.
             List<Instruction> declarations = new();
-            declarations.AddRange(_typeDeclarations.Values);
+            declarations.AddRange(_typeDeclarationsList);
             declarations.AddRange(_globals);
             declarations.AddRange(_constants.Values);
             declarations.Sort((Instruction x, Instruction y) => x.Id.CompareTo(y.Id));