Просмотр исходного кода

Update BSD service implementation (#363)

* Update BSD service to handle libnx's 'smart IPC buffers' for address info

* Use existing "GetBufferType0x21" for certain BSD socket methods

* Parse address port as unsigned short

* Fix bounds check on reading the IPC buffer

* Implement Read, Write methods

* rebased and cleaned

* addressed nits

* remove unused swap method

* fixed alignments
emmauss 7 лет назад
Родитель
Сommit
da7e702751

+ 18 - 18
Ryujinx.HLE/HOS/Ipc/IpcMessage.cs

@@ -174,39 +174,39 @@ namespace Ryujinx.HLE.HOS.Ipc
             return 0;
         }
 
-        public (long Position, long Size) GetBufferType0x21()
+        public (long Position, long Size) GetBufferType0x21(int Index = 0)
         {
-            if (PtrBuff.Count       != 0 &&
-                PtrBuff[0].Position != 0 &&
-                PtrBuff[0].Size     != 0)
+            if (PtrBuff.Count > Index &&
+                PtrBuff[Index].Position != 0 &&
+                PtrBuff[Index].Size     != 0)
             {
-                return (PtrBuff[0].Position, PtrBuff[0].Size);
+                return (PtrBuff[Index].Position, PtrBuff[Index].Size);
             }
 
-            if (SendBuff.Count       != 0 &&
-                SendBuff[0].Position != 0 &&
-                SendBuff[0].Size     != 0)
+            if (SendBuff.Count > Index &&
+                SendBuff[Index].Position != 0 &&
+                SendBuff[Index].Size     != 0)
             {
-                return (SendBuff[0].Position, SendBuff[0].Size);
+                return (SendBuff[Index].Position, SendBuff[Index].Size);
             }
 
             return (0, 0);
         }
 
-        public (long Position, long Size) GetBufferType0x22()
+        public (long Position, long Size) GetBufferType0x22(int Index = 0)
         {
-            if (RecvListBuff.Count       != 0 &&
-                RecvListBuff[0].Position != 0 &&
-                RecvListBuff[0].Size     != 0)
+            if (RecvListBuff.Count > Index &&
+                RecvListBuff[Index].Position != 0 &&
+                RecvListBuff[Index].Size     != 0)
             {
-                return (RecvListBuff[0].Position, RecvListBuff[0].Size);
+                return (RecvListBuff[Index].Position, RecvListBuff[Index].Size);
             }
 
-            if (ReceiveBuff.Count       != 0 &&
-                ReceiveBuff[0].Position != 0 &&
-                ReceiveBuff[0].Size     != 0)
+            if (ReceiveBuff.Count > Index &&
+                ReceiveBuff[Index].Position != 0 &&
+                ReceiveBuff[Index].Size     != 0)
             {
-                return (ReceiveBuff[0].Position, ReceiveBuff[0].Size);
+                return (ReceiveBuff[Index].Position, ReceiveBuff[Index].Size);
             }
 
             return (0, 0);

+ 79 - 24
Ryujinx.HLE/HOS/Services/Bsd/IClient.cs

@@ -32,6 +32,8 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
                 { 14, Connect         },
                 { 18, Listen          },
                 { 21, SetSockOpt      },
+                { 24, Write           },
+                { 25, Read            },
                 { 26, Close           }
             };
         }
@@ -122,15 +124,15 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
             int SocketId    = Context.RequestData.ReadInt32();
             int SocketFlags = Context.RequestData.ReadInt32();
 
-            byte[] ReceivedBuffer = new byte[Context.Request.ReceiveBuff[0].Size];
+            (long ReceivePosition, long ReceiveLength) = Context.Request.GetBufferType0x22();
+
+            byte[] ReceivedBuffer = new byte[ReceiveLength];
 
             try
             {
                 int BytesRead = Sockets[SocketId].Handle.Receive(ReceivedBuffer);
 
-                //Logging.Debug("Received Buffer:" + Environment.NewLine + Logging.HexDump(ReceivedBuffer));
-
-                Context.Memory.WriteBytes(Context.Request.ReceiveBuff[0].Position, ReceivedBuffer);
+                Context.Memory.WriteBytes(ReceivePosition, ReceivedBuffer);
 
                 Context.ResponseData.Write(BytesRead);
                 Context.ResponseData.Write(0);
@@ -150,13 +152,12 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
             int SocketId    = Context.RequestData.ReadInt32();
             int SocketFlags = Context.RequestData.ReadInt32();
 
-            byte[] SentBuffer = Context.Memory.ReadBytes(Context.Request.SendBuff[0].Position,
-                                                         Context.Request.SendBuff[0].Size);
+            (long SentPosition, long SentSize) = Context.Request.GetBufferType0x21();
+
+            byte[] SentBuffer = Context.Memory.ReadBytes(SentPosition, SentSize);
 
             try
             {
-                //Logging.Debug("Sent Buffer:" + Environment.NewLine + Logging.HexDump(SentBuffer));
-
                 int BytesSent = Sockets[SocketId].Handle.Send(SentBuffer);
 
                 Context.ResponseData.Write(BytesSent);
@@ -180,8 +181,9 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
             byte[] SentBuffer = Context.Memory.ReadBytes(Context.Request.SendBuff[0].Position,
                                                          Context.Request.SendBuff[0].Size);
 
-            byte[] AddressBuffer = Context.Memory.ReadBytes(Context.Request.SendBuff[1].Position,
-                                                            Context.Request.SendBuff[1].Size);
+            (long AddressPosition, long AddressSize) = Context.Request.GetBufferType0x21(Index: 1);
+
+            byte[] AddressBuffer = Context.Memory.ReadBytes(AddressPosition, AddressSize);
 
             if (!Sockets[SocketId].Handle.Connected)
             {
@@ -200,8 +202,6 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
 
             try
             {
-                //Logging.Debug("Sent Buffer:" + Environment.NewLine + Logging.HexDump(SentBuffer));
-
                 int BytesSent = Sockets[SocketId].Handle.Send(SentBuffer);
 
                 Context.ResponseData.Write(BytesSent);
@@ -221,7 +221,7 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
         {
             int SocketId = Context.RequestData.ReadInt32();
 
-            long AddrBufferPtr = Context.Request.ReceiveBuff[0].Position;
+            (long AddrBufferPosition, long AddrBuffSize) = Context.Request.GetBufferType0x22();
 
             Socket HandleAccept = null;
 
@@ -246,7 +246,7 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
                 {
                     IpAddress = ((IPEndPoint)Sockets[SocketId].Handle.LocalEndPoint).Address,
                     RemoteEP  = ((IPEndPoint)Sockets[SocketId].Handle.LocalEndPoint),
-                    Handle    = HandleAccept
+                    Handle = HandleAccept
                 };
 
                 Sockets.Add(NewBsdSocket);
@@ -265,7 +265,7 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
 
                     Writer.Write(IpAddress);
 
-                    Context.Memory.WriteBytes(AddrBufferPtr, MS.ToArray());
+                    Context.Memory.WriteBytes(AddrBufferPosition, MS.ToArray());
 
                     Context.ResponseData.Write(Sockets.Count - 1);
                     Context.ResponseData.Write(0);
@@ -286,8 +286,9 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
         {
             int SocketId = Context.RequestData.ReadInt32();
 
-            byte[] AddressBuffer = Context.Memory.ReadBytes(Context.Request.SendBuff[0].Position,
-                                                            Context.Request.SendBuff[0].Size);
+            (long AddressPosition, long AddressSize) = Context.Request.GetBufferType0x21();
+
+            byte[] AddressBuffer = Context.Memory.ReadBytes(AddressPosition, AddressSize);
 
             try
             {
@@ -310,8 +311,9 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
         {
             int SocketId = Context.RequestData.ReadInt32();
 
-            byte[] AddressBuffer = Context.Memory.ReadBytes(Context.Request.SendBuff[0].Position,
-                                                            Context.Request.SendBuff[0].Size);
+            (long AddressPosition, long AddressSize) = Context.Request.GetBufferType0x21();
+
+            byte[] AddressBuffer = Context.Memory.ReadBytes(AddressPosition, AddressSize);
 
             try
             {
@@ -359,8 +361,8 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
         {
             int SocketId = Context.RequestData.ReadInt32();
 
-            SocketOptionLevel SocketLevel      = (SocketOptionLevel)Context.RequestData.ReadInt32();
-            SocketOptionName  SocketOptionName =  (SocketOptionName)Context.RequestData.ReadInt32();
+            SocketOptionLevel SocketLevel     = (SocketOptionLevel)Context.RequestData.ReadInt32();
+            SocketOptionName SocketOptionName = (SocketOptionName)Context.RequestData.ReadInt32();
 
             byte[] SocketOptionValue = Context.Memory.ReadBytes(Context.Request.PtrBuff[0].Position,
                                                                 Context.Request.PtrBuff[0].Size);
@@ -383,6 +385,60 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
             return 0;
         }
 
+        //(u32 socket, buffer<i8, 0x21, 0> message) -> (i32 ret, u32 bsd_errno)
+        public long Write(ServiceCtx Context)
+        {
+            int SocketId = Context.RequestData.ReadInt32();
+
+            (long SentPosition, long SentSize) = Context.Request.GetBufferType0x21();
+
+            byte[] SentBuffer = Context.Memory.ReadBytes(SentPosition, SentSize);
+
+            try
+            {
+                //Logging.Debug("Wrote Buffer:" + Environment.NewLine + Logging.HexDump(SentBuffer));
+
+                int BytesSent = Sockets[SocketId].Handle.Send(SentBuffer);
+
+                Context.ResponseData.Write(BytesSent);
+                Context.ResponseData.Write(0);
+            }
+            catch (SocketException Ex)
+            {
+                Context.ResponseData.Write(-1);
+                Context.ResponseData.Write(Ex.ErrorCode - 10000);
+            }
+
+            return 0;
+        }
+
+        //(u32 socket) -> (i32 ret, u32 bsd_errno, buffer<i8, 0x22, 0> message)
+        public long Read(ServiceCtx Context)
+        {
+            int SocketId = Context.RequestData.ReadInt32();
+
+            (long ReceivePosition, long ReceiveLength) = Context.Request.GetBufferType0x22();
+
+            byte[] ReceivedBuffer = new byte[ReceiveLength];
+
+            try
+            {
+                int BytesRead = Sockets[SocketId].Handle.Receive(ReceivedBuffer);
+
+                Context.Memory.WriteBytes(ReceivePosition, ReceivedBuffer);
+
+                Context.ResponseData.Write(BytesRead);
+                Context.ResponseData.Write(0);
+            }
+            catch (SocketException Ex)
+            {
+                Context.ResponseData.Write(-1);
+                Context.ResponseData.Write(Ex.ErrorCode - 10000);
+            }
+
+            return 0;
+        }
+
         //(u32 socket) -> (i32 ret, u32 bsd_errno)
         public long Close(ServiceCtx Context)
         {
@@ -413,7 +469,7 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
 
                 int Size   = Reader.ReadByte();
                 int Family = Reader.ReadByte();
-                int Port   = EndianSwap.Swap16(Reader.ReadInt16());
+                int Port   = EndianSwap.Swap16(Reader.ReadUInt16());
 
                 string IpAddress = Reader.ReadByte().ToString() + "." +
                                    Reader.ReadByte().ToString() + "." +
@@ -421,8 +477,7 @@ namespace Ryujinx.HLE.HOS.Services.Bsd
                                    Reader.ReadByte().ToString();
 
                 Sockets[SocketId].IpAddress = IPAddress.Parse(IpAddress);
-
-                Sockets[SocketId].RemoteEP = new IPEndPoint(Sockets[SocketId].IpAddress, Port);
+                Sockets[SocketId].RemoteEP  = new IPEndPoint(Sockets[SocketId].IpAddress, Port);
             }
         }
 

+ 1 - 1
Ryujinx.HLE/Utilities/EndianSwap.cs

@@ -2,7 +2,7 @@
 {
     static class EndianSwap
     {
-        public static short Swap16(short Value) => (short)(((Value >> 8) & 0xff) | (Value << 8));
+        public static ushort Swap16(ushort Value) => (ushort)(((Value >> 8) & 0xff) | (Value << 8));
 
         public static int Swap32(int Value)
         {