// SslManagedSocketConnection.cs
using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
using Ryujinx.HLE.HOS.Services.Sockets.Bsd.Impl;
using Ryujinx.HLE.HOS.Services.Ssl.Types;
using System;
using System.IO;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
  9. namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
  10. {
  11. class SslManagedSocketConnection : ISslConnectionBase
  12. {
  13. public int SocketFd { get; }
  14. public ISocket Socket { get; }
  15. private BsdContext _bsdContext;
  16. private SslVersion _sslVersion;
  17. private SslStream _stream;
  18. private bool _isBlockingSocket;
  19. private int _previousReadTimeout;
  20. public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
  21. {
  22. _bsdContext = bsdContext;
  23. _sslVersion = sslVersion;
  24. SocketFd = socketFd;
  25. Socket = socket;
  26. }
  27. private void StartSslOperation()
  28. {
  29. // Save blocking state
  30. _isBlockingSocket = Socket.Blocking;
  31. // Force blocking for SslStream
  32. Socket.Blocking = true;
  33. }
  34. private void EndSslOperation()
  35. {
  36. // Restore blocking state
  37. Socket.Blocking = _isBlockingSocket;
  38. }
  39. private void StartSslReadOperation()
  40. {
  41. StartSslOperation();
  42. if (!_isBlockingSocket)
  43. {
  44. _previousReadTimeout = _stream.ReadTimeout;
  45. _stream.ReadTimeout = 1;
  46. }
  47. }
  48. private void EndSslReadOperation()
  49. {
  50. if (!_isBlockingSocket)
  51. {
  52. _stream.ReadTimeout = _previousReadTimeout;
  53. }
  54. EndSslOperation();
  55. }
  56. // NOTE: We silence warnings about TLS 1.0 and 1.1 as games will likely use it.
  57. #pragma warning disable SYSLIB0039
  58. private static SslProtocols TranslateSslVersion(SslVersion version)
  59. {
  60. switch (version & SslVersion.VersionMask)
  61. {
  62. case SslVersion.Auto:
  63. return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13;
  64. case SslVersion.TlsV10:
  65. return SslProtocols.Tls;
  66. case SslVersion.TlsV11:
  67. return SslProtocols.Tls11;
  68. case SslVersion.TlsV12:
  69. return SslProtocols.Tls12;
  70. case SslVersion.TlsV13:
  71. return SslProtocols.Tls13;
  72. default:
  73. throw new NotImplementedException(version.ToString());
  74. }
  75. }
  76. #pragma warning restore SYSLIB0039
  77. public ResultCode Handshake(string hostName)
  78. {
  79. StartSslOperation();
  80. _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
  81. _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
  82. EndSslOperation();
  83. return ResultCode.Success;
  84. }
  85. public ResultCode Peek(out int peekCount, Memory<byte> buffer)
  86. {
  87. // NOTE: We cannot support that on .NET SSL API.
  88. // As Nintendo's curl implementation detail check if a connection is alive via Peek, we just return that it would block to let it know that it's alive.
  89. peekCount = -1;
  90. return ResultCode.WouldBlock;
  91. }
  92. public int Pending()
  93. {
  94. // Unsupported
  95. return 0;
  96. }
  97. private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
  98. {
  99. switch (error)
  100. {
  101. case WsaError.WSAETIMEDOUT:
  102. resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
  103. return true;
  104. case WsaError.WSAECONNABORTED:
  105. resultCode = ResultCode.ConnectionAbort;
  106. return true;
  107. case WsaError.WSAECONNRESET:
  108. resultCode = ResultCode.ConnectionReset;
  109. return true;
  110. default:
  111. resultCode = ResultCode.Success;
  112. return false;
  113. }
  114. }
  115. public ResultCode Read(out int readCount, Memory<byte> buffer)
  116. {
  117. if (!Socket.Poll(0, SelectMode.SelectRead))
  118. {
  119. readCount = -1;
  120. return ResultCode.WouldBlock;
  121. }
  122. StartSslReadOperation();
  123. try
  124. {
  125. readCount = _stream.Read(buffer.Span);
  126. }
  127. catch (IOException exception)
  128. {
  129. readCount = -1;
  130. if (exception.InnerException is SocketException socketException)
  131. {
  132. WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
  133. if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
  134. {
  135. return result;
  136. }
  137. else
  138. {
  139. throw socketException;
  140. }
  141. }
  142. else
  143. {
  144. throw exception;
  145. }
  146. }
  147. finally
  148. {
  149. EndSslReadOperation();
  150. }
  151. return ResultCode.Success;
  152. }
  153. public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
  154. {
  155. if (!Socket.Poll(0, SelectMode.SelectWrite))
  156. {
  157. writtenCount = 0;
  158. return ResultCode.WouldBlock;
  159. }
  160. StartSslOperation();
  161. try
  162. {
  163. _stream.Write(buffer.Span);
  164. }
  165. catch (IOException exception)
  166. {
  167. writtenCount = -1;
  168. if (exception.InnerException is SocketException socketException)
  169. {
  170. WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;
  171. if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
  172. {
  173. return result;
  174. }
  175. else
  176. {
  177. throw socketException;
  178. }
  179. }
  180. else
  181. {
  182. throw exception;
  183. }
  184. }
  185. finally
  186. {
  187. EndSslOperation();
  188. }
  189. // .NET API doesn't provide the size written, assume all written.
  190. writtenCount = buffer.Length;
  191. return ResultCode.Success;
  192. }
  193. public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
  194. {
  195. byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();
  196. storageSize = (uint)rawCertData.Length;
  197. certificateCount = 1;
  198. if (rawCertData.Length > certificates.Length)
  199. {
  200. return ResultCode.CertBufferTooSmall;
  201. }
  202. rawCertData.CopyTo(certificates);
  203. return ResultCode.Success;
  204. }
  205. public void Dispose()
  206. {
  207. _bsdContext.CloseFileDescriptor(SocketFd);
  208. }
  209. }
  210. }