Prechádzať zdrojové kódy

Fix virtual address overflow near ulong limit (#2044)

* Fix virtual address overflow near ulong limit

* Fix comments

* Improve overflow checking for large size values

* Add overflow checking to AddressSpaceManager class

* Add overflow protection to read and write functions
Caian Benedicto 5 rokov pred
rodič
commit
f7b2daf5ec

+ 75 - 23
Ryujinx.Cpu/MemoryManager.cs

@@ -83,6 +83,8 @@ namespace Ryujinx.Cpu
         /// <param name="size">Size to be mapped</param>
         public void Map(ulong va, ulong pa, ulong size)
         {
+            AssertValidAddressAndSize(va, size);
+
             ulong remainingSize = size;
             ulong oVa = va;
             ulong oPa = pa;
@@ -110,6 +112,8 @@ namespace Ryujinx.Cpu
                 return;
             }
 
+            AssertValidAddressAndSize(va, size);
+
             UnmapEvent?.Invoke(va, size);
 
             ulong remainingSize = size;
@@ -214,6 +218,8 @@ namespace Ryujinx.Cpu
         {
             try
             {
+                AssertValidAddressAndSize(va, (ulong)data.Length);
+
                 if (IsContiguousAndMapped(va, data.Length))
                 {
                     data.CopyTo(_backingMemory.GetSpan(GetPhysicalAddressInternal(va), data.Length));
@@ -345,6 +351,23 @@ namespace Ryujinx.Cpu
             return ref _backingMemory.GetRef<T>(GetPhysicalAddressInternal(va));
         }
 
+        /// <summary>
+        /// Computes the number of pages in a virtual address range.
+        /// </summary>
+        /// <param name="va">Virtual address of the range</param>
+        /// <param name="size">Size of the range</param>
+        /// <param name="startVa">The virtual address of the beginning of the first page</param>
+        /// <remarks>This function does not differentiate between allocated and unallocated pages.</remarks>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private int GetPagesCount(ulong va, uint size, out ulong startVa)
+        {
+            // WARNING: Always check if ulong does not overflow during the operations.
+            startVa = va & ~(ulong)PageMask;
+            ulong vaSpan = (va - startVa + size + PageMask) & ~(ulong)PageMask;
+
+            return (int)(vaSpan / PageSize);
+        }
+
         private void ThrowMemoryNotContiguous() => throw new MemoryNotContiguousException();
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -353,16 +376,12 @@ namespace Ryujinx.Cpu
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private bool IsContiguous(ulong va, int size)
         {
-            if (!ValidateAddress(va))
+            if (!ValidateAddress(va) || !ValidateAddressAndSize(va, (ulong)size))
             {
                 return false;
             }
 
-            ulong endVa = (va + (ulong)size + PageMask) & ~(ulong)PageMask;
-
-            va &= ~(ulong)PageMask;
-
-            int pages = (int)((endVa - va) / PageSize);
+            int pages = GetPagesCount(va, (uint)size, out va);
 
             for (int page = 0; page < pages - 1; page++)
             {
@@ -391,16 +410,12 @@ namespace Ryujinx.Cpu
         /// <returns>Array of physical regions</returns>
         public (ulong address, ulong size)[] GetPhysicalRegions(ulong va, ulong size)
         {
-            if (!ValidateAddress(va))
+            if (!ValidateAddress(va) || !ValidateAddressAndSize(va, size))
             {
                 return null;
             }
 
-            ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;
-
-            va &= ~(ulong)PageMask;
-
-            int pages = (int)((endVa - va) / PageSize);
+            int pages = GetPagesCount(va, (uint)size, out va);
 
             List<(ulong, ulong)> regions = new List<(ulong, ulong)>();
 
@@ -441,6 +456,8 @@ namespace Ryujinx.Cpu
 
             try
             {
+                AssertValidAddressAndSize(va, (ulong)data.Length);
+
                 int offset = 0, size;
 
                 if ((va & PageMask) != 0)
@@ -485,11 +502,14 @@ namespace Ryujinx.Cpu
                 return true;
             }
 
-            ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;
+            if (!ValidateAddressAndSize(va, size))
+            {
+                return false;
+            }
 
-            va &= ~(ulong)PageMask;
+            int pages = GetPagesCount(va, (uint)size, out va);
 
-            while (va < endVa)
+            for (int page = 0; page < pages; page++)
             {
                 if (!IsMapped(va))
                 {
@@ -523,6 +543,32 @@ namespace Ryujinx.Cpu
             return va < _addressSpaceSize;
         }
 
+        /// <summary>
+        /// Checks if the combination of virtual address and size is part of the addressable space.
+        /// </summary>
+        /// <param name="va">Virtual address of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        /// <returns>True if the combination of virtual address and size is part of the addressable space</returns>
+        private bool ValidateAddressAndSize(ulong va, ulong size)
+        {
+            ulong endVa = va + size;
+            return endVa >= va && endVa >= size && endVa <= _addressSpaceSize;
+        }
+
+        /// <summary>
+        /// Ensures the combination of virtual address and size is part of the addressable space.
+        /// </summary>
+        /// <param name="va">Virtual address of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        /// <exception cref="InvalidMemoryRegionException">Throw when the memory region specified outside the addressable space</exception>
+        private void AssertValidAddressAndSize(ulong va, ulong size)
+        {
+            if (!ValidateAddressAndSize(va, size))
+            {
+                throw new InvalidMemoryRegionException($"va=0x{va:X16}, size=0x{size:X16}");
+            }
+        }
+
         /// <summary>
         /// Performs address translation of the address inside a CPU mapped memory range.
         /// </summary>
@@ -555,6 +601,8 @@ namespace Ryujinx.Cpu
         /// <param name="protection">Memory protection to set</param>
         public void TrackingReprotect(ulong va, ulong size, MemoryPermission protection)
         {
+            AssertValidAddressAndSize(va, size);
+
             // Protection is inverted on software pages, since the default value is 0.
             protection = (~protection) & MemoryPermission.ReadAndWrite;
 
@@ -565,12 +613,13 @@ namespace Ryujinx.Cpu
                 _ => 3L << PointerTagBit
             };
 
-            ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;
+            int pages = GetPagesCount(va, (uint)size, out va);
+            ulong pageStart = va >> PageBits;
             long invTagMask = ~(0xffffL << 48);
 
-            while (va < endVa)
+            for (int page = 0; page < pages; page++)
             {
-                ref long pageRef = ref _pageTable.GetRef<long>((va >> PageBits) * PteSize);
+                ref long pageRef = ref _pageTable.GetRef<long>(pageStart * PteSize);
 
                 long pte;
 
@@ -580,7 +629,7 @@ namespace Ryujinx.Cpu
                 }
                 while (Interlocked.CompareExchange(ref pageRef, (pte & invTagMask) | tag, pte) != pte);
 
-                va += PageSize;
+                pageStart++;
             }
         }
 
@@ -627,17 +676,20 @@ namespace Ryujinx.Cpu
         /// <param name="size">Size of the region</param>
         public void SignalMemoryTracking(ulong va, ulong size, bool write)
         {
+            AssertValidAddressAndSize(va, size);
+
             // We emulate guard pages for software memory access. This makes for an easy transition to
             // tracking using host guard pages in future, but also supporting platforms where this is not possible.
 
             // Write tag includes read protection, since we don't have any read actions that aren't performed before write too.
             long tag = (write ? 3L : 1L) << PointerTagBit;
 
-            ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;
+            int pages = GetPagesCount(va, (uint)size, out va);
+            ulong pageStart = va >> PageBits;
 
-            while (va < endVa)
+            for (int page = 0; page < pages; page++)
             {
-                ref long pageRef = ref _pageTable.GetRef<long>((va >> PageBits) * PteSize);
+                ref long pageRef = ref _pageTable.GetRef<long>(pageStart * PteSize);
 
                 long pte;
 
@@ -649,7 +701,7 @@ namespace Ryujinx.Cpu
                     break;
                 }
 
-                va += PageSize;
+                pageStart++;
             }
         }
 

+ 59 - 9
Ryujinx.Memory/AddressSpaceManager.cs

@@ -64,6 +64,8 @@ namespace Ryujinx.Memory
         /// <param name="size">Size to be mapped</param>
         public void Map(ulong va, ulong pa, ulong size)
         {
+            AssertValidAddressAndSize(va, size);
+
             while (size != 0)
             {
                 PtMap(va, pa);
@@ -81,6 +83,8 @@ namespace Ryujinx.Memory
         /// <param name="size">Size of the range to be unmapped</param>
         public void Unmap(ulong va, ulong size)
         {
+            AssertValidAddressAndSize(va, size);
+
             while (size != 0)
             {
                 PtUnmap(va);
@@ -138,6 +142,8 @@ namespace Ryujinx.Memory
                 return;
             }
 
+            AssertValidAddressAndSize(va, (ulong)data.Length);
+
             if (IsContiguousAndMapped(va, data.Length))
             {
                 data.CopyTo(_backingMemory.GetSpan(GetPhysicalAddressInternal(va), data.Length));
@@ -254,6 +260,23 @@ namespace Ryujinx.Memory
             return ref _backingMemory.GetRef<T>(GetPhysicalAddressInternal(va));
         }
 
+        /// <summary>
+        /// Computes the number of pages in a virtual address range.
+        /// </summary>
+        /// <param name="va">Virtual address of the range</param>
+        /// <param name="size">Size of the range</param>
+        /// <param name="startVa">The virtual address of the beginning of the first page</param>
+        /// <remarks>This function does not differentiate between allocated and unallocated pages.</remarks>
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private int GetPagesCount(ulong va, uint size, out ulong startVa)
+        {
+            // WARNING: Always check if ulong does not overflow during the operations.
+            startVa = va & ~(ulong)PageMask;
+            ulong vaSpan = (va - startVa + size + PageMask) & ~(ulong)PageMask;
+
+            return (int)(vaSpan / PageSize);
+        }
+
         private void ThrowMemoryNotContiguous() => throw new MemoryNotContiguousException();
 
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
@@ -262,16 +285,12 @@ namespace Ryujinx.Memory
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private bool IsContiguous(ulong va, int size)
         {
-            if (!ValidateAddress(va))
+            if (!ValidateAddress(va) || !ValidateAddressAndSize(va, (ulong)size))
             {
                 return false;
             }
 
-            ulong endVa = (va + (ulong)size + PageMask) & ~(ulong)PageMask;
-
-            va &= ~(ulong)PageMask;
-
-            int pages = (int)((endVa - va) / PageSize);
+            int pages = GetPagesCount(va, (uint)size, out va);
 
             for (int page = 0; page < pages - 1; page++)
             {
@@ -310,6 +329,8 @@ namespace Ryujinx.Memory
                 return;
             }
 
+            AssertValidAddressAndSize(va, (ulong)data.Length);
+
             int offset = 0, size;
 
             if ((va & PageMask) != 0)
@@ -362,11 +383,14 @@ namespace Ryujinx.Memory
                 return true;
             }
 
-            ulong endVa = (va + size + PageMask) & ~(ulong)PageMask;
+            if (!ValidateAddressAndSize(va, size))
+            {
+                return false;
+            }
 
-            va &= ~(ulong)PageMask;
+            int pages = GetPagesCount(va, (uint)size, out va);
 
-            while (va < endVa)
+            for (int page = 0; page < pages; page++)
             {
                 if (!IsMapped(va))
                 {
@@ -384,6 +408,32 @@ namespace Ryujinx.Memory
             return va < _addressSpaceSize;
         }
 
+        /// <summary>
+        /// Checks if the combination of virtual address and size is part of the addressable space.
+        /// </summary>
+        /// <param name="va">Virtual address of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        /// <returns>True if the combination of virtual address and size is part of the addressable space</returns>
+        private bool ValidateAddressAndSize(ulong va, ulong size)
+        {
+            ulong endVa = va + size;
+            return endVa >= va && endVa >= size && endVa <= _addressSpaceSize;
+        }
+
+        /// <summary>
+        /// Ensures the combination of virtual address and size is part of the addressable space.
+        /// </summary>
+        /// <param name="va">Virtual address of the range</param>
+        /// <param name="size">Size of the range in bytes</param>
+        /// <exception cref="InvalidMemoryRegionException">Throw when the memory region specified outside the addressable space</exception>
+        private void AssertValidAddressAndSize(ulong va, ulong size)
+        {
+            if (!ValidateAddressAndSize(va, size))
+            {
+                throw new InvalidMemoryRegionException($"va=0x{va:X16}, size=0x{size:X16}");
+            }
+        }
+
         /// <summary>
         /// Performs address translation of the address inside a mapped memory range.
         /// </summary>