Parcourir la source

Use SIMD acceleration for audio upsampler (#4410)

* Use SIMD acceleration for audio upsampler filter kernel for a moderate speedup

* Address formatting. Implement AVX2 fast path for high quality resampling in ResamplerHelper

* now really, are we really getting the benefit of inlining 50+ line methods?

* adding unit tests for resampler + upsampler. The upsampler ones fail for some reason

* Fixing upsampler test. Apparently this algo only works at specific ratios

---------

Co-authored-by: Logan Stromberg <lostromb@microsoft.com>
Logan Stromberg il y a 3 ans
Parent
commit
edfd4d70c0

+ 102 - 81
Ryujinx.Audio/Renderer/Dsp/ResamplerHelper.cs

@@ -1,5 +1,6 @@
 using System;
 using System.Linq;
+using System.Numerics;
 using System.Runtime.CompilerServices;
 using System.Runtime.Intrinsics;
 using System.Runtime.Intrinsics.X86;
@@ -380,7 +381,6 @@ namespace Ryujinx.Audio.Renderer.Dsp
             return _normalCurveLut2F;
         }
 
-        [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private unsafe static void ResampleDefaultQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount, bool needPitch)
         {
             ReadOnlySpan<float> parameters = GetDefaultParameter(ratio);
@@ -394,35 +394,33 @@ namespace Ryujinx.Audio.Renderer.Dsp
                 if (ratio == 1f)
                 {
                     fixed (short* pInput = inputBuffer)
+                    fixed (float* pOutput = outputBuffer, pParameters = parameters)
                     {
-                        fixed (float* pOutput = outputBuffer, pParameters = parameters)
-                        {
-                            Vector128<float> parameter = Sse.LoadVector128(pParameters);
+                        Vector128<float> parameter = Sse.LoadVector128(pParameters);
 
-                            for (; i < (sampleCount & ~3); i += 4)
-                            {
-                                Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + (uint)i);
-                                Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 1);
-                                Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 2);
-                                Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 3);
+                        for (; i < (sampleCount & ~3); i += 4)
+                        {
+                            Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + (uint)i);
+                            Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 1);
+                            Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 2);
+                            Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + (uint)i + 3);
 
-                                Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0);
-                                Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1);
-                                Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2);
-                                Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3);
+                            Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0);
+                            Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1);
+                            Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2);
+                            Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3);
 
-                                Vector128<float> mix0 = Sse.Multiply(input0, parameter);
-                                Vector128<float> mix1 = Sse.Multiply(input1, parameter);
-                                Vector128<float> mix2 = Sse.Multiply(input2, parameter);
-                                Vector128<float> mix3 = Sse.Multiply(input3, parameter);
+                            Vector128<float> mix0 = Sse.Multiply(input0, parameter);
+                            Vector128<float> mix1 = Sse.Multiply(input1, parameter);
+                            Vector128<float> mix2 = Sse.Multiply(input2, parameter);
+                            Vector128<float> mix3 = Sse.Multiply(input3, parameter);
 
-                                Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1);
-                                Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
+                            Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1);
+                            Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
 
-                                Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23);
+                            Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23);
 
-                                Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
-                            }
+                            Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
                         }
                     }
 
@@ -431,62 +429,60 @@ namespace Ryujinx.Audio.Renderer.Dsp
                 else
                 {
                     fixed (short* pInput = inputBuffer)
+                    fixed (float* pOutput = outputBuffer, pParameters = parameters)
                     {
-                        fixed (float* pOutput = outputBuffer, pParameters = parameters)
+                        for (; i < (sampleCount & ~3); i += 4)
                         {
-                            for (; i < (sampleCount & ~3); i += 4)
-                            {
-                                uint baseIndex0 = (uint)(fraction * 128) * 4;
-                                uint inputIndex0 = (uint)inputBufferIndex;
+                            uint baseIndex0 = (uint)(fraction * 128) * 4;
+                            uint inputIndex0 = (uint)inputBufferIndex;
 
-                                fraction += ratio;
+                            fraction += ratio;
 
-                                uint baseIndex1 = ((uint)(fraction * 128) & 127) * 4;
-                                uint inputIndex1 = (uint)inputBufferIndex + (uint)fraction;
+                            uint baseIndex1 = ((uint)(fraction * 128) & 127) * 4;
+                            uint inputIndex1 = (uint)inputBufferIndex + (uint)fraction;
 
-                                fraction += ratio;
+                            fraction += ratio;
 
-                                uint baseIndex2 = ((uint)(fraction * 128) & 127) * 4;
-                                uint inputIndex2 = (uint)inputBufferIndex + (uint)fraction;
+                            uint baseIndex2 = ((uint)(fraction * 128) & 127) * 4;
+                            uint inputIndex2 = (uint)inputBufferIndex + (uint)fraction;
 
-                                fraction += ratio;
+                            fraction += ratio;
 
-                                uint baseIndex3 = ((uint)(fraction * 128) & 127) * 4;
-                                uint inputIndex3 = (uint)inputBufferIndex + (uint)fraction;
+                            uint baseIndex3 = ((uint)(fraction * 128) & 127) * 4;
+                            uint inputIndex3 = (uint)inputBufferIndex + (uint)fraction;
 
-                                fraction += ratio;
-                                inputBufferIndex += (int)fraction;
+                            fraction += ratio;
+                            inputBufferIndex += (int)fraction;
 
-                                // Only keep lower part (safe as fraction isn't supposed to be negative)
-                                fraction -= (int)fraction;
+                            // Only keep lower part (safe as fraction isn't supposed to be negative)
+                            fraction -= (int)fraction;
 
-                                Vector128<float> parameter0 = Sse.LoadVector128(pParameters + baseIndex0);
-                                Vector128<float> parameter1 = Sse.LoadVector128(pParameters + baseIndex1);
-                                Vector128<float> parameter2 = Sse.LoadVector128(pParameters + baseIndex2);
-                                Vector128<float> parameter3 = Sse.LoadVector128(pParameters + baseIndex3);
+                            Vector128<float> parameter0 = Sse.LoadVector128(pParameters + baseIndex0);
+                            Vector128<float> parameter1 = Sse.LoadVector128(pParameters + baseIndex1);
+                            Vector128<float> parameter2 = Sse.LoadVector128(pParameters + baseIndex2);
+                            Vector128<float> parameter3 = Sse.LoadVector128(pParameters + baseIndex3);
 
-                                Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + inputIndex0);
-                                Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + inputIndex1);
-                                Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + inputIndex2);
-                                Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + inputIndex3);
+                            Vector128<int> intInput0 = Sse41.ConvertToVector128Int32(pInput + inputIndex0);
+                            Vector128<int> intInput1 = Sse41.ConvertToVector128Int32(pInput + inputIndex1);
+                            Vector128<int> intInput2 = Sse41.ConvertToVector128Int32(pInput + inputIndex2);
+                            Vector128<int> intInput3 = Sse41.ConvertToVector128Int32(pInput + inputIndex3);
 
-                                Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0);
-                                Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1);
-                                Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2);
-                                Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3);
+                            Vector128<float> input0 = Sse2.ConvertToVector128Single(intInput0);
+                            Vector128<float> input1 = Sse2.ConvertToVector128Single(intInput1);
+                            Vector128<float> input2 = Sse2.ConvertToVector128Single(intInput2);
+                            Vector128<float> input3 = Sse2.ConvertToVector128Single(intInput3);
 
-                                Vector128<float> mix0 = Sse.Multiply(input0, parameter0);
-                                Vector128<float> mix1 = Sse.Multiply(input1, parameter1);
-                                Vector128<float> mix2 = Sse.Multiply(input2, parameter2);
-                                Vector128<float> mix3 = Sse.Multiply(input3, parameter3);
+                            Vector128<float> mix0 = Sse.Multiply(input0, parameter0);
+                            Vector128<float> mix1 = Sse.Multiply(input1, parameter1);
+                            Vector128<float> mix2 = Sse.Multiply(input2, parameter2);
+                            Vector128<float> mix3 = Sse.Multiply(input3, parameter3);
 
-                                Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1);
-                                Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
+                            Vector128<float> mix01 = Sse3.HorizontalAdd(mix0, mix1);
+                            Vector128<float> mix23 = Sse3.HorizontalAdd(mix2, mix3);
 
-                                Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23);
+                            Vector128<float> mix0123 = Sse3.HorizontalAdd(mix01, mix23);
 
-                                Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
-                            }
+                            Sse.Store(pOutput + (uint)i, Sse41.RoundToNearestInteger(mix0123));
                         }
                     }
                 }
@@ -526,34 +522,59 @@ namespace Ryujinx.Audio.Renderer.Dsp
             return _highCurveLut2F;
         }
 
-        [MethodImpl(MethodImplOptions.AggressiveInlining)]
-        private static void ResampleHighQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount)
+        private static unsafe void ResampleHighQuality(Span<float> outputBuffer, ReadOnlySpan<short> inputBuffer, float ratio, ref float fraction, int sampleCount)
         {
             ReadOnlySpan<float> parameters = GetHighParameter(ratio);
 
             int inputBufferIndex = 0;
 
-            // TODO: fast path
-
-            for (int i = 0; i < sampleCount; i++)
+            if (Avx2.IsSupported)
             {
-                int baseIndex = (int)(fraction * 128) * 8;
-                ReadOnlySpan<float> parameter = parameters.Slice(baseIndex, 8);
-                ReadOnlySpan<short> currentInput = inputBuffer.Slice(inputBufferIndex, 8);
+                // Fast path; assumes 256-bit vectors for simplicity because the filter is 8 taps
+                fixed (short* pInput = inputBuffer)
+                fixed (float* pParameters = parameters)
+                {
+                    for (int i = 0; i < sampleCount; i++)
+                    {
+                        int baseIndex = (int)(fraction * 128) * 8;
 
-                outputBuffer[i] = (float)Math.Round(currentInput[0] * parameter[0] +
-                                                    currentInput[1] * parameter[1] +
-                                                    currentInput[2] * parameter[2] +
-                                                    currentInput[3] * parameter[3] +
-                                                    currentInput[4] * parameter[4] +
-                                                    currentInput[5] * parameter[5] +
-                                                    currentInput[6] * parameter[6] +
-                                                    currentInput[7] * parameter[7]);
+                        Vector256<int> intInput = Avx2.ConvertToVector256Int32(pInput + inputBufferIndex);
+                        Vector256<float> floatInput = Avx.ConvertToVector256Single(intInput);
+                        Vector256<float> parameter = Avx.LoadVector256(pParameters + baseIndex);
+                        Vector256<float> dp = Avx.DotProduct(floatInput, parameter, control: 0xFF);
 
-                fraction += ratio;
-                inputBufferIndex += (int)MathF.Truncate(fraction);
+                        // avx2 does an 8-element dot product piecewise so we have to sum up 2 intermediate results
+                        outputBuffer[i] = (float)Math.Round(dp[0] + dp[4]);
 
-                fraction -= (int)fraction;
+                        fraction += ratio;
+                        inputBufferIndex += (int)MathF.Truncate(fraction);
+
+                        fraction -= (int)fraction;
+                    }
+                }
+            }
+            else
+            {
+                for (int i = 0; i < sampleCount; i++)
+                {
+                    int baseIndex = (int)(fraction * 128) * 8;
+                    ReadOnlySpan<float> parameter = parameters.Slice(baseIndex, 8);
+                    ReadOnlySpan<short> currentInput = inputBuffer.Slice(inputBufferIndex, 8);
+
+                    outputBuffer[i] = (float)Math.Round(currentInput[0] * parameter[0] +
+                                                        currentInput[1] * parameter[1] +
+                                                        currentInput[2] * parameter[2] +
+                                                        currentInput[3] * parameter[3] +
+                                                        currentInput[4] * parameter[4] +
+                                                        currentInput[5] * parameter[5] +
+                                                        currentInput[6] * parameter[6] +
+                                                        currentInput[7] * parameter[7]);
+
+                    fraction += ratio;
+                    inputBufferIndex += (int)MathF.Truncate(fraction);
+
+                    fraction -= (int)fraction;
+                }
             }
         }
 

+ 20 - 3
Ryujinx.Audio/Renderer/Dsp/UpsamplerHelper.cs

@@ -2,6 +2,7 @@ using Ryujinx.Audio.Renderer.Server.Upsampler;
 using Ryujinx.Common.Memory;
 using System;
 using System.Diagnostics;
+using System.Numerics;
 using System.Runtime.CompilerServices;
 
 namespace Ryujinx.Audio.Renderer.Dsp
@@ -70,16 +71,32 @@ namespace Ryujinx.Audio.Renderer.Dsp
                 return;
             }
 
-            [MethodImpl(MethodImplOptions.AggressiveInlining)]
             float DoFilterBank(ref UpsamplerBufferState state, in Array20<float> bank)
             {
                 float result = 0.0f;
 
                 Debug.Assert(state.History.Length == HistoryLength);
                 Debug.Assert(bank.Length == FilterBankLength);
-                for (int j = 0; j < FilterBankLength; j++)
+
+                int curIdx = 0;
+                if (Vector.IsHardwareAccelerated)
+                {
+                    // Do SIMD-accelerated block operations where possible.
+                    // Only about a 2x speedup since filter bank length is short
+                    int stopIdx = FilterBankLength - (FilterBankLength % Vector<float>.Count);
+                    while (curIdx < stopIdx)
+                    {
+                        result += Vector.Dot(
+                            new Vector<float>(bank.AsSpan().Slice(curIdx, Vector<float>.Count)),
+                            new Vector<float>(state.History.AsSpan().Slice(curIdx, Vector<float>.Count)));
+                        curIdx += Vector<float>.Count;
+                    }
+                }
+
+                while (curIdx < FilterBankLength)
                 {
-                    result += bank[j] * state.History[j];
+                    result += bank[curIdx] * state.History[curIdx];
+                    curIdx++;
                 }
 
                 return result;

+ 93 - 0
Ryujinx.Tests/Audio/Renderer/Dsp/ResamplerTests.cs

@@ -0,0 +1,93 @@
+using NUnit.Framework;
+using Ryujinx.Audio.Renderer.Dsp;
+using Ryujinx.Audio.Renderer.Parameter;
+using Ryujinx.Audio.Renderer.Server.Upsampler;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Ryujinx.Tests.Audio.Renderer.Dsp
+{
+    class ResamplerTests
+    {
+        [Test]
+        [TestCase(VoiceInParameter.SampleRateConversionQuality.Low)]
+        [TestCase(VoiceInParameter.SampleRateConversionQuality.Default)]
+        [TestCase(VoiceInParameter.SampleRateConversionQuality.High)]
+        public void TestResamplerConsistencyUpsampling(VoiceInParameter.SampleRateConversionQuality quality)
+        {
+            DoResamplingTest(44100, 48000, quality);
+        }
+
+        [Test]
+        [TestCase(VoiceInParameter.SampleRateConversionQuality.Low)]
+        [TestCase(VoiceInParameter.SampleRateConversionQuality.Default)]
+        [TestCase(VoiceInParameter.SampleRateConversionQuality.High)]
+        public void TestResamplerConsistencyDownsampling(VoiceInParameter.SampleRateConversionQuality quality)
+        {
+            DoResamplingTest(48000, 44100, quality);
+        }
+
+        /// <summary>
+        /// Generates a 1-second sine wave sample at input rate, resamples it to output rate, and
+        /// ensures that it resampled at the expected rate with no discontinuities
+        /// </summary>
+        /// <param name="inputRate">The input sample rate to test</param>
+        /// <param name="outputRate">The output sample rate to test</param>
+        /// <param name="quality">The resampler quality to use</param>
+        private static void DoResamplingTest(int inputRate, int outputRate, VoiceInParameter.SampleRateConversionQuality quality)
+        {
+            float inputSampleRate = (float)inputRate;
+            float outputSampleRate = (float)outputRate;
+            int inputSampleCount = inputRate;
+            int outputSampleCount = outputRate;
+            short[] inputBuffer = new short[inputSampleCount + 100]; // add some safety buffer at the end
+            float[] outputBuffer = new float[outputSampleCount + 100];
+            for (int sample = 0; sample < inputBuffer.Length; sample++)
+            {
+                // 440 hz sine wave with amplitude = 0.5f at input sample rate
+                inputBuffer[sample] = (short)(32767 * MathF.Sin((440 / inputSampleRate) * (float)sample * MathF.PI * 2f) * 0.5f);
+            }
+
+            float fraction = 0;
+
+            ResamplerHelper.Resample(
+                outputBuffer.AsSpan(),
+                inputBuffer.AsSpan(),
+                inputSampleRate / outputSampleRate,
+                ref fraction,
+                outputSampleCount,
+                quality,
+                false);
+
+            float[] expectedOutput = new float[outputSampleCount];
+            float sumDifference = 0;
+            int delay = quality switch
+            {
+                VoiceInParameter.SampleRateConversionQuality.High => 3,
+                VoiceInParameter.SampleRateConversionQuality.Default => 1,
+                _ => 0
+            };
+
+            for (int sample = 0; sample < outputSampleCount; sample++)
+            {
+                outputBuffer[sample] /= 32767;
+                // 440 hz sine wave with amplitude = 0.5f at output sample rate
+                expectedOutput[sample] = MathF.Sin((440 / outputSampleRate) * (float)(sample + delay) * MathF.PI * 2f) * 0.5f;
+                float thisDelta = Math.Abs(expectedOutput[sample] - outputBuffer[sample]);
+
+                // Ensure no discontinuities
+                Assert.IsTrue(thisDelta < 0.1f);
+                sumDifference += thisDelta;
+            }
+
+            sumDifference = sumDifference / (float)outputSampleCount;
+            // Expect the output to be 99% similar to the expected resampled sine wave
+            Assert.IsTrue(sumDifference < 0.01f);
+        }
+    }
+}

+ 64 - 0
Ryujinx.Tests/Audio/Renderer/Dsp/UpsamplerTests.cs

@@ -0,0 +1,64 @@
+using NUnit.Framework;
+using Ryujinx.Audio.Renderer.Dsp;
+using Ryujinx.Audio.Renderer.Parameter;
+using Ryujinx.Audio.Renderer.Server.Upsampler;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Ryujinx.Tests.Audio.Renderer.Dsp
+{
+    class UpsamplerTests
+    {
+        [Test]
+        public void TestUpsamplerConsistency()
+        {
+            UpsamplerBufferState bufferState = new UpsamplerBufferState();
+            int inputBlockSize = 160;
+            int numInputSamples = 32000;
+            int numOutputSamples = 48000;
+            float inputSampleRate = numInputSamples;
+            float outputSampleRate = numOutputSamples;
+            float[] inputBuffer = new float[numInputSamples + 100];
+            float[] outputBuffer = new float[numOutputSamples + 100];
+            for (int sample = 0; sample < inputBuffer.Length; sample++)
+            {
+                // 440 hz sine wave with amplitude = 0.5f at input sample rate
+                inputBuffer[sample] = MathF.Sin((440 / inputSampleRate) * (float)sample * MathF.PI * 2f) * 0.5f;
+            }
+
+            int inputIdx = 0;
+            int outputIdx = 0;
+            while (inputIdx + inputBlockSize < numInputSamples)
+            {
+                int outputBufLength = (int)Math.Round((float)(inputIdx + inputBlockSize) * outputSampleRate / inputSampleRate) - outputIdx;
+                UpsamplerHelper.Upsample(
+                    outputBuffer.AsSpan(outputIdx),
+                    inputBuffer.AsSpan(inputIdx),
+                    outputBufLength,
+                    inputBlockSize,
+                    ref bufferState);
+
+                inputIdx += inputBlockSize;
+                outputIdx += outputBufLength;
+            }
+
+            float[] expectedOutput = new float[numOutputSamples];
+            float sumDifference = 0;
+            for (int sample = 0; sample < numOutputSamples; sample++)
+            {
+                // 440 hz sine wave with amplitude = 0.5f at output sample rate with an offset of 15
+                expectedOutput[sample] = MathF.Sin((440 / outputSampleRate) * (float)(sample - 15) * MathF.PI * 2f) * 0.5f;
+                sumDifference += Math.Abs(expectedOutput[sample] - outputBuffer[sample]);
+            }
+
+            sumDifference = sumDifference / (float)expectedOutput.Length;
+            // Expect the output to be 98% similar to the expected resampled sine wave
+            Assert.IsTrue(sumDifference < 0.02f);
+        }
+    }
+}