Просмотр исходного кода

Add missing check for thread termination on ArbitrateLock (#4722)

* Add missing check for thread termination on ArbitrateLock

* Use TerminationRequested in all places where it can be used
gdkchan 3 лет назад
Родитель
Сommit
097562bc6c

+ 2 - 4
Ryujinx.HLE/HOS/Kernel/Ipc/KServerSession.cs

@@ -188,8 +188,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
 
             if (request.AsyncEvent == null)
             {
-                if (request.ClientThread.ShallBeTerminated ||
-                    request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending)
+                if (request.ClientThread.TerminationRequested)
                 {
                     return KernelResult.ThreadTerminating;
                 }
@@ -1104,8 +1103,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Ipc
         {
             foreach (KSessionRequest request in IterateWithRemovalOfAllRequests())
             {
-                if (request.ClientThread.ShallBeTerminated ||
-                    request.ClientThread.SchedFlags == ThreadSchedState.TerminationPending)
+                if (request.ClientThread.TerminationRequested)
                 {
                     continue;
                 }

+ 10 - 6
Ryujinx.HLE/HOS/Kernel/Threading/KAddressArbiter.cs

@@ -31,6 +31,13 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             _context.CriticalSection.Enter();
 
+            if (currentThread.TerminationRequested)
+            {
+                _context.CriticalSection.Leave();
+
+                return KernelResult.ThreadTerminating;
+            }
+
             currentThread.SignaledObj   = null;
             currentThread.ObjSyncResult = Result.Success;
 
@@ -114,8 +121,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
             currentThread.SignaledObj   = null;
             currentThread.ObjSyncResult = KernelResult.TimedOut;
 
-            if (currentThread.ShallBeTerminated ||
-                currentThread.SchedFlags == ThreadSchedState.TerminationPending)
+            if (currentThread.TerminationRequested)
             {
                 _context.CriticalSection.Leave();
 
@@ -280,8 +286,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             _context.CriticalSection.Enter();
 
-            if (currentThread.ShallBeTerminated ||
-                currentThread.SchedFlags == ThreadSchedState.TerminationPending)
+            if (currentThread.TerminationRequested)
             {
                 _context.CriticalSection.Leave();
 
@@ -351,8 +356,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             _context.CriticalSection.Enter();
 
-            if (currentThread.ShallBeTerminated ||
-                currentThread.SchedFlags == ThreadSchedState.TerminationPending)
+            if (currentThread.TerminationRequested)
             {
                 _context.CriticalSection.Leave();
 

+ 1 - 2
Ryujinx.HLE/HOS/Kernel/Threading/KConditionVariable.cs

@@ -19,8 +19,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             currentThread.WithholderNode = threadList.AddLast(currentThread);
 
-            if (currentThread.ShallBeTerminated ||
-                currentThread.SchedFlags == ThreadSchedState.TerminationPending)
+            if (currentThread.TerminationRequested)
             {
                 threadList.Remove(currentThread.WithholderNode);
 

+ 2 - 3
Ryujinx.HLE/HOS/Kernel/Threading/KSynchronization.cs

@@ -47,8 +47,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             KThread currentThread = KernelStatic.GetCurrentThread();
 
-            if (currentThread.ShallBeTerminated ||
-                currentThread.SchedFlags == ThreadSchedState.TerminationPending)
+            if (currentThread.TerminationRequested)
             {
                 result = KernelResult.ThreadTerminating;
             }
@@ -61,7 +60,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
             else
             {
                 LinkedListNode<KThread>[] syncNodesArray = ArrayPool<LinkedListNode<KThread>>.Shared.Rent(syncObjs.Length);
-                
+
                 Span<LinkedListNode<KThread>> syncNodes = syncNodesArray.AsSpan(0, syncObjs.Length);
 
                 for (int index = 0; index < syncObjs.Length; index++)

+ 4 - 8
Ryujinx.HLE/HOS/Kernel/Threading/KThread.cs

@@ -99,11 +99,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
         private int _shallBeTerminated;
 
-        public bool ShallBeTerminated
-        {
-            get => _shallBeTerminated != 0;
-            set => _shallBeTerminated = value ? 1 : 0;
-        }
+        private bool ShallBeTerminated => _shallBeTerminated != 0;
 
         public bool TerminationRequested => ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending;
 
@@ -322,7 +318,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
 
             ThreadSchedState result;
 
-            if (Interlocked.CompareExchange(ref _shallBeTerminated, 1, 0) == 0)
+            if (Interlocked.Exchange(ref _shallBeTerminated, 1) == 0)
             {
                 if ((SchedFlags & ThreadSchedState.LowMask) == ThreadSchedState.None)
                 {
@@ -470,7 +466,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
         {
             KernelContext.CriticalSection.Enter();
 
-            if (ShallBeTerminated || SchedFlags == ThreadSchedState.TerminationPending)
+            if (TerminationRequested)
             {
                 KernelContext.CriticalSection.Leave();
 
@@ -552,7 +548,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
                     return KernelResult.InvalidState;
                 }
 
-                if (!ShallBeTerminated && SchedFlags != ThreadSchedState.TerminationPending)
+                if (!TerminationRequested)
                 {
                     if (pause)
                     {