Jelajahi Sumber

Use multiple dest operands for shader call instructions (#1975)

* Use multiple dest operands for shader call instructions

* Passing opNode is no longer needed
gdkchan 5 tahun lalu
induk
melakukan
053dcfdb05

+ 0 - 1
Ryujinx.Graphics.Shader/IntermediateRepresentation/Instruction.cs

@@ -32,7 +32,6 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
         BranchIfFalse,
         BranchIfFalse,
         BranchIfTrue,
         BranchIfTrue,
         Call,
         Call,
-        CallOutArgument,
         Ceiling,
         Ceiling,
         Clamp,
         Clamp,
         ClampU32,
         ClampU32,

+ 22 - 1
Ryujinx.Graphics.Shader/IntermediateRepresentation/Operation.cs

@@ -1,4 +1,5 @@
 using System;
 using System;
+using System.Diagnostics;
 
 
 namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
 namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
 {
 {
@@ -96,7 +97,27 @@ namespace Ryujinx.Graphics.Shader.IntermediateRepresentation
             Index = index;
             Index = index;
         }
         }
 
 
-        public void AppendOperands(params Operand[] operands)
+        public void AppendDests(Operand[] operands)
+        {
+            int startIndex = _dests.Length;
+
+            Array.Resize(ref _dests, startIndex + operands.Length);
+
+            for (int index = 0; index < operands.Length; index++)
+            {
+                Operand dest = operands[index];
+
+                if (dest != null && dest.Type == OperandType.LocalVariable)
+                {
+                    Debug.Assert(dest.AsgOp == null);
+                    dest.AsgOp = this;
+                }
+
+                _dests[startIndex + index] = dest;
+            }
+        }
+
+        public void AppendSources(Operand[] operands)
         {
         {
             int startIndex = _sources.Length;
             int startIndex = _sources.Length;
 
 

+ 3 - 32
Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs

@@ -51,9 +51,9 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
                         {
                         {
                             context.LeaveBlock(block, operation);
                             context.LeaveBlock(block, operation);
                         }
                         }
-                        else if (operation.Inst != Instruction.CallOutArgument)
+                        else
                         {
                         {
-                            AddOperation(context, opNode);
+                            AddOperation(context, operation);
                         }
                         }
                     }
                     }
                 }
                 }
@@ -68,32 +68,13 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
             return context.Info;
             return context.Info;
         }
         }
 
 
-        private static void AddOperation(StructuredProgramContext context, LinkedListNode<INode> opNode)
+        private static void AddOperation(StructuredProgramContext context, Operation operation)
         {
         {
-            Operation operation = (Operation)opNode.Value;
-
             Instruction inst = operation.Inst;
             Instruction inst = operation.Inst;
 
 
-            bool isCall = inst == Instruction.Call;
-
             int sourcesCount = operation.SourcesCount;
             int sourcesCount = operation.SourcesCount;
             int outDestsCount = operation.DestsCount != 0 ? operation.DestsCount - 1 : 0;
             int outDestsCount = operation.DestsCount != 0 ? operation.DestsCount - 1 : 0;
 
 
-            List<Operand> callOutOperands = new List<Operand>();
-
-            if (isCall)
-            {
-                LinkedListNode<INode> scan = opNode.Next;
-
-                while (scan != null && scan.Value is Operation nextOp && nextOp.Inst == Instruction.CallOutArgument)
-                {
-                    callOutOperands.Add(nextOp.Dest);
-                    scan = scan.Next;
-                }
-
-                sourcesCount += callOutOperands.Count;
-            }
-
             IAstNode[] sources = new IAstNode[sourcesCount + outDestsCount];
             IAstNode[] sources = new IAstNode[sourcesCount + outDestsCount];
 
 
             for (int index = 0; index < operation.SourcesCount; index++)
             for (int index = 0; index < operation.SourcesCount; index++)
@@ -101,16 +82,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr
                 sources[index] = context.GetOperandUse(operation.GetSource(index));
                 sources[index] = context.GetOperandUse(operation.GetSource(index));
             }
             }
 
 
-            if (isCall)
-            {
-                for (int index = 0; index < callOutOperands.Count; index++)
-                {
-                    sources[operation.SourcesCount + index] = context.GetOperandDef(callOutOperands[index]);
-                }
-
-                callOutOperands.Clear();
-            }
-
             for (int index = 0; index < outDestsCount; index++)
             for (int index = 0; index < outDestsCount; index++)
             {
             {
                 AstOperand oper = context.GetOperandDef(operation.GetDest(1 + index));
                 AstOperand oper = context.GetOperandDef(operation.GetDest(1 + index));

+ 6 - 3
Ryujinx.Graphics.Shader/Translation/Optimizations/Optimizer.cs

@@ -289,7 +289,6 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
                     case Instruction.AtomicSwap:
                     case Instruction.AtomicSwap:
                     case Instruction.AtomicXor:
                     case Instruction.AtomicXor:
                     case Instruction.Call:
                     case Instruction.Call:
-                    case Instruction.CallOutArgument:
                         return true;
                         return true;
                 }
                 }
             }
             }
@@ -306,7 +305,9 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
 
 
             for (int index = 0; index < node.DestsCount; index++)
             for (int index = 0; index < node.DestsCount; index++)
             {
             {
-                if (node.GetDest(index).Type != OperandType.LocalVariable)
+                Operand dest = node.GetDest(index);
+
+                if (dest != null && dest.Type != OperandType.LocalVariable)
                 {
                 {
                     return false;
                     return false;
                 }
                 }
@@ -319,7 +320,9 @@ namespace Ryujinx.Graphics.Shader.Translation.Optimizations
         {
         {
             for (int index = 0; index < node.DestsCount; index++)
             for (int index = 0; index < node.DestsCount; index++)
             {
             {
-                if (node.GetDest(index).UseOps.Count != 0)
+                Operand dest = node.GetDest(index);
+
+                if (dest != null && dest.UseOps.Count != 0)
                 {
                 {
                     return false;
                     return false;
                 }
                 }

+ 8 - 6
Ryujinx.Graphics.Shader/Translation/RegisterUsage.cs

@@ -299,21 +299,23 @@ namespace Ryujinx.Graphics.Shader.Translation
 
 
                         var fru = frus[funcId.Value];
                         var fru = frus[funcId.Value];
 
 
-                        Operand[] regs = new Operand[fru.InArguments.Length];
+                        Operand[] inRegs = new Operand[fru.InArguments.Length];
 
 
                         for (int i = 0; i < fru.InArguments.Length; i++)
                         for (int i = 0; i < fru.InArguments.Length; i++)
                         {
                         {
-                            regs[i] = OperandHelper.Register(fru.InArguments[i]);
+                            inRegs[i] = OperandHelper.Register(fru.InArguments[i]);
                         }
                         }
 
 
-                        operation.AppendOperands(regs);
+                        operation.AppendSources(inRegs);
+
+                        Operand[] outRegs = new Operand[1 + fru.OutArguments.Length];
 
 
                         for (int i = 0; i < fru.OutArguments.Length; i++)
                         for (int i = 0; i < fru.OutArguments.Length; i++)
                         {
                         {
-                            Operation callOutArgOp = new Operation(Instruction.CallOutArgument, OperandHelper.Register(fru.OutArguments[i]));
-
-                            node = block.Operations.AddAfter(node, callOutArgOp);
+                            outRegs[1 + i] = OperandHelper.Register(fru.OutArguments[i]);
                         }
                         }
+
+                        operation.AppendDests(outRegs);
                     }
                     }
                 }
                 }
             }
             }

+ 1 - 1
Ryujinx.Graphics.Shader/Translation/Ssa.cs

@@ -120,7 +120,7 @@ namespace Ryujinx.Graphics.Shader.Translation
                         {
                         {
                             Operand dest = operation.GetDest(index);
                             Operand dest = operation.GetDest(index);
 
 
-                            if (dest.Type == OperandType.Register)
+                            if (dest != null && dest.Type == OperandType.Register)
                             {
                             {
                                 Operand local = Local();
                                 Operand local = Local();
 
 

+ 0 - 1
Ryujinx.Graphics.Shader/Translation/Translator.cs

@@ -88,7 +88,6 @@ namespace Ryujinx.Graphics.Shader.Translation
                     RegisterUsage.FixupCalls(cfg.Blocks, frus);
                     RegisterUsage.FixupCalls(cfg.Blocks, frus);
 
 
                     Dominance.FindDominators(cfg);
                     Dominance.FindDominators(cfg);
-
                     Dominance.FindDominanceFrontiers(cfg.Blocks);
                     Dominance.FindDominanceFrontiers(cfg.Blocks);
 
 
                     Ssa.Rename(cfg.Blocks);
                     Ssa.Rename(cfg.Blocks);