Explorar el Código

Fix SSL GetCertificates with certificate ID set to All (#3727)

* Fix SSL GetCertificates with certificate ID set to All

* Fix last entry status value
gdkchan hace 3 años
padre
commit
dbe43c1719

+ 11 - 2
Ryujinx.HLE/HOS/Services/Ssl/BuiltInCertificateManager.cs

@@ -181,7 +181,11 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
             }
             }
         }
         }
 
 
-        public bool TryGetCertificates(ReadOnlySpan<CaCertificateId> ids, out CertStoreEntry[] entries)
+        public bool TryGetCertificates(
+            ReadOnlySpan<CaCertificateId> ids,
+            out CertStoreEntry[] entries,
+            out bool hasAllCertificates,
+            out int requiredSize)
         {
         {
             lock (_lock)
             lock (_lock)
             {
             {
@@ -190,7 +194,8 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
                     throw new InvalidSystemResourceException(CertStoreTitleMissingErrorMessage);
                     throw new InvalidSystemResourceException(CertStoreTitleMissingErrorMessage);
                 }
                 }
 
 
-                bool hasAllCertificates = false;
+                requiredSize = 0;
+                hasAllCertificates = false;
 
 
                 foreach (CaCertificateId id in ids)
                 foreach (CaCertificateId id in ids)
                 {
                 {
@@ -205,12 +210,14 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
                 if (hasAllCertificates)
                 if (hasAllCertificates)
                 {
                 {
                     entries = new CertStoreEntry[_certificates.Count];
                     entries = new CertStoreEntry[_certificates.Count];
+                    requiredSize = (_certificates.Count + 1) * Unsafe.SizeOf<BuiltInCertificateInfo>();
 
 
                     int i = 0;
                     int i = 0;
 
 
                     foreach (CertStoreEntry entry in _certificates.Values)
                     foreach (CertStoreEntry entry in _certificates.Values)
                     {
                     {
                         entries[i++] = entry;
                         entries[i++] = entry;
+                        requiredSize += (entry.Data.Length + 3) & ~3;
                     }
                     }
 
 
                     return true;
                     return true;
@@ -218,6 +225,7 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
                 else
                 else
                 {
                 {
                     entries = new CertStoreEntry[ids.Length];
                     entries = new CertStoreEntry[ids.Length];
+                    requiredSize = ids.Length * Unsafe.SizeOf<BuiltInCertificateInfo>();
 
 
                     for (int i = 0; i < ids.Length; i++)
                     for (int i = 0; i < ids.Length; i++)
                     {
                     {
@@ -227,6 +235,7 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
                         }
                         }
 
 
                         entries[i] = entry;
                         entries[i] = entry;
+                        requiredSize += (entry.Data.Length + 3) & ~3;
                     }
                     }
 
 
                     return true;
                     return true;

+ 29 - 20
Ryujinx.HLE/HOS/Services/Ssl/ISslService.cs

@@ -29,42 +29,40 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
             return ResultCode.Success;
             return ResultCode.Success;
         }
         }
 
 
-        private uint ComputeCertificateBufferSizeRequired(ReadOnlySpan<BuiltInCertificateManager.CertStoreEntry> entries)
-        {
-            uint totalSize = 0;
-
-            for (int i = 0; i < entries.Length; i++)
-            {
-                totalSize += (uint)Unsafe.SizeOf<BuiltInCertificateInfo>();
-                totalSize += (uint)entries[i].Data.Length;
-            }
-
-            return totalSize;
-        }
-
         [CommandHipc(2)]
         [CommandHipc(2)]
         // GetCertificates(buffer<CaCertificateId, 5> ids) -> (u32 certificates_count, buffer<bytes, 6> certificates)
         // GetCertificates(buffer<CaCertificateId, 5> ids) -> (u32 certificates_count, buffer<bytes, 6> certificates)
         public ResultCode GetCertificates(ServiceCtx context)
         public ResultCode GetCertificates(ServiceCtx context)
         {
         {
             ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
             ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
 
 
-            if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries))
+            if (!BuiltInCertificateManager.Instance.TryGetCertificates(
+                ids,
+                out BuiltInCertificateManager.CertStoreEntry[] entries,
+                out bool hasAllCertificates,
+                out int requiredSize))
             {
             {
                 throw new InvalidOperationException();
                 throw new InvalidOperationException();
             }
             }
 
 
-            if (ComputeCertificateBufferSizeRequired(entries) > context.Request.ReceiveBuff[0].Size)
+            if ((uint)requiredSize > (uint)context.Request.ReceiveBuff[0].Size)
             {
             {
                 return ResultCode.InvalidCertBufSize;
                 return ResultCode.InvalidCertBufSize;
             }
             }
 
 
+            int infosCount = entries.Length;
+
+            if (hasAllCertificates)
+            {
+                infosCount++;
+            }
+
             using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
             using (WritableRegion region = context.Memory.GetWritableRegion(context.Request.ReceiveBuff[0].Position, (int)context.Request.ReceiveBuff[0].Size))
             {
             {
                 Span<byte> rawData = region.Memory.Span;
                 Span<byte> rawData = region.Memory.Span;
-                Span<BuiltInCertificateInfo> infos = MemoryMarshal.Cast<byte, BuiltInCertificateInfo>(rawData)[..entries.Length];
-                Span<byte> certificatesData = rawData[(Unsafe.SizeOf<BuiltInCertificateInfo>() * entries.Length)..];
+                Span<BuiltInCertificateInfo> infos = MemoryMarshal.Cast<byte, BuiltInCertificateInfo>(rawData)[..infosCount];
+                Span<byte> certificatesData = rawData[(Unsafe.SizeOf<BuiltInCertificateInfo>() * infosCount)..];
 
 
-                for (int i = 0; i < infos.Length; i++)
+                for (int i = 0; i < entries.Length; i++)
                 {
                 {
                     entries[i].Data.CopyTo(certificatesData);
                     entries[i].Data.CopyTo(certificatesData);
 
 
@@ -78,6 +76,17 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
 
 
                     certificatesData = certificatesData[entries[i].Data.Length..];
                     certificatesData = certificatesData[entries[i].Data.Length..];
                 }
                 }
+
+                if (hasAllCertificates)
+                {
+                    infos[entries.Length] = new BuiltInCertificateInfo
+                    {
+                        Id = CaCertificateId.All,
+                        Status = TrustedCertStatus.Invalid,
+                        CertificateDataSize = 0,
+                        CertificateDataOffset = 0
+                    };
+                }
             }
             }
 
 
             context.ResponseData.Write(entries.Length);
             context.ResponseData.Write(entries.Length);
@@ -91,12 +100,12 @@ namespace Ryujinx.HLE.HOS.Services.Ssl
         {
         {
             ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
             ReadOnlySpan<CaCertificateId> ids = MemoryMarshal.Cast<byte, CaCertificateId>(context.Memory.GetSpan(context.Request.SendBuff[0].Position, (int)context.Request.SendBuff[0].Size));
 
 
-            if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out BuiltInCertificateManager.CertStoreEntry[] entries))
+            if (!BuiltInCertificateManager.Instance.TryGetCertificates(ids, out _, out _, out int requiredSize))
             {
             {
                 throw new InvalidOperationException();
                 throw new InvalidOperationException();
             }
             }
 
 
-            context.ResponseData.Write(ComputeCertificateBufferSizeRequired(entries));
+            context.ResponseData.Write(requiredSize);
 
 
             return ResultCode.Success;
             return ResultCode.Success;
         }
         }