// SslManagedSocketConnection.cs (7.2 KB)

  using Ryujinx.HLE.HOS.Services.Sockets.Bsd;
  using Ryujinx.HLE.HOS.Services.Ssl.Types;
  using System;
  using System.IO;
  using System.Net.Security;
  using System.Net.Sockets;
  using System.Runtime.ExceptionServices;
  using System.Security.Authentication;

  namespace Ryujinx.HLE.HOS.Services.Ssl.SslService
  {
      /// <summary>
      /// SSL connection implemented on top of a managed socket, using .NET's <see cref="SslStream"/>
      /// for the TLS layer.
      /// </summary>
      class SslManagedSocketConnection : ISslConnectionBase
      {
          /// <summary>BSD file descriptor of the underlying socket; closed by <see cref="Dispose"/>.</summary>
          public int SocketFd { get; }

          /// <summary>
          /// Underlying socket. <see cref="Handshake"/> casts it to <see cref="ManagedSocket"/>,
          /// so it must be one.
          /// </summary>
          public ISocket Socket { get; }

          private readonly BsdContext _bsdContext;
          private readonly SslVersion _sslVersion;

          // Created by Handshake; null until then.
          private SslStream _stream;

          // Guest-visible blocking mode, saved by StartSslOperation and restored by EndSslOperation.
          private bool _isBlockingSocket;

          // ReadTimeout saved by StartSslReadOperation while the short non-blocking timeout is in effect.
          private int _previousReadTimeout;

          /// <summary>Creates a connection wrapper for an already-created socket.</summary>
          /// <param name="bsdContext">BSD context owning <paramref name="socketFd"/>.</param>
          /// <param name="sslVersion">Requested TLS version(s); translated at handshake time.</param>
          /// <param name="socketFd">File descriptor of <paramref name="socket"/>.</param>
          /// <param name="socket">Socket to run TLS over.</param>
          public SslManagedSocketConnection(BsdContext bsdContext, SslVersion sslVersion, int socketFd, ISocket socket)
          {
              _bsdContext = bsdContext;
              _sslVersion = sslVersion;
              SocketFd = socketFd;
              Socket = socket;
          }

          /// <summary>
          /// Saves the socket's blocking mode and forces blocking mode, which <see cref="SslStream"/> requires.
          /// Must be paired with <see cref="EndSslOperation"/>.
          /// </summary>
          private void StartSslOperation()
          {
              _isBlockingSocket = Socket.Blocking;

              Socket.Blocking = true;
          }

          /// <summary>Restores the blocking mode saved by <see cref="StartSslOperation"/>.</summary>
          private void EndSslOperation()
          {
              Socket.Blocking = _isBlockingSocket;
          }

          /// <summary>
          /// Like <see cref="StartSslOperation"/>. For guest sockets that were non-blocking, it also
          /// installs a 1 ms read timeout. A read with no TLS data then fails fast with a timeout,
          /// which the read path reports as <c>WouldBlock</c>.
          /// </summary>
          private void StartSslReadOperation()
          {
              StartSslOperation();

              if (!_isBlockingSocket)
              {
                  _previousReadTimeout = _stream.ReadTimeout;
                  _stream.ReadTimeout = 1;
              }
          }

          /// <summary>Undoes <see cref="StartSslReadOperation"/>.</summary>
          private void EndSslReadOperation()
          {
              if (!_isBlockingSocket)
              {
                  _stream.ReadTimeout = _previousReadTimeout;
              }

              EndSslOperation();
          }

          /// <summary>Maps the guest <see cref="SslVersion"/> (version bits only) to .NET <see cref="SslProtocols"/>.</summary>
          /// <exception cref="NotImplementedException">The version bits do not match a known value.</exception>
          private static SslProtocols TranslateSslVersion(SslVersion version)
          {
              switch (version & SslVersion.VersionMask)
              {
                  case SslVersion.Auto:
                      return SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12 | SslProtocols.Tls13;
                  case SslVersion.TlsV10:
                      return SslProtocols.Tls;
                  case SslVersion.TlsV11:
                      return SslProtocols.Tls11;
                  case SslVersion.TlsV12:
                      return SslProtocols.Tls12;
                  case SslVersion.TlsV13:
                      return SslProtocols.Tls13;
                  default:
                      throw new NotImplementedException(version.ToString());
              }
          }

          /// <summary>
          /// Performs the TLS client handshake against <paramref name="hostName"/>, using default
          /// certificate validation and no client certificates.
          /// </summary>
          /// <param name="hostName">Target host name passed to <see cref="SslStream.AuthenticateAsClient(string, System.Security.Cryptography.X509Certificates.X509CertificateCollection, SslProtocols, bool)"/>.</param>
          /// <returns><see cref="ResultCode.Success"/>; handshake failures propagate as exceptions.</returns>
          public ResultCode Handshake(string hostName)
          {
              StartSslOperation();

              try
              {
                  // ownsSocket: false - the socket's lifetime is owned by the BSD context, not the stream.
                  _stream = new SslStream(new NetworkStream(((ManagedSocket)Socket).Socket, false), false, null, null);
                  _stream.AuthenticateAsClient(hostName, null, TranslateSslVersion(_sslVersion), false);
              }
              finally
              {
                  // Always restore the guest's blocking mode, even if the handshake throws.
                  EndSslOperation();
              }

              return ResultCode.Success;
          }

          /// <summary>Peeking is not supported; always reports <c>WouldBlock</c> with a count of -1.</summary>
          public ResultCode Peek(out int peekCount, Memory<byte> buffer)
          {
              // NOTE: Peek cannot be implemented on top of the .NET SSL API.
              // Nintendo's curl implementation checks whether a connection is alive via Peek, so we report
              // that the call would block, which tells it the connection is still alive.
              peekCount = -1;

              return ResultCode.WouldBlock;
          }

          /// <summary>Pending byte count; not supported, always 0.</summary>
          public int Pending()
          {
              return 0;
          }

          /// <summary>Maps a WinSock error to a guest result code.</summary>
          /// <param name="isBlocking">Guest blocking mode; selects between Timeout and WouldBlock for timeouts.</param>
          /// <param name="error">Error to translate.</param>
          /// <param name="resultCode">Translated code, or <see cref="ResultCode.Success"/> when unmapped.</param>
          /// <returns><c>true</c> if <paramref name="error"/> has a mapping.</returns>
          private static bool TryTranslateWinSockError(bool isBlocking, WsaError error, out ResultCode resultCode)
          {
              switch (error)
              {
                  case WsaError.WSAETIMEDOUT:
                      resultCode = isBlocking ? ResultCode.Timeout : ResultCode.WouldBlock;
                      return true;
                  case WsaError.WSAECONNABORTED:
                      resultCode = ResultCode.ConnectionAbort;
                      return true;
                  case WsaError.WSAECONNRESET:
                      resultCode = ResultCode.ConnectionReset;
                      return true;
                  default:
                      resultCode = ResultCode.Success;
                      return false;
              }
          }

          /// <summary>Reads decrypted data into <paramref name="buffer"/>.</summary>
          /// <param name="readCount">Bytes read, or -1 on WouldBlock or a translated socket error.</param>
          /// <returns>
          /// <c>WouldBlock</c> if the socket is not readable; a translated code for known socket errors;
          /// otherwise <see cref="ResultCode.Success"/>.
          /// </returns>
          /// <exception cref="SocketException">An untranslatable socket error occurred (original stack trace preserved).</exception>
          /// <exception cref="IOException">An I/O error not caused by a socket error occurred.</exception>
          public ResultCode Read(out int readCount, Memory<byte> buffer)
          {
              if (!Socket.Poll(0, SelectMode.SelectRead))
              {
                  readCount = -1;

                  return ResultCode.WouldBlock;
              }

              StartSslReadOperation();

              try
              {
                  readCount = _stream.Read(buffer.Span);
              }
              catch (IOException exception) when (exception.InnerException is SocketException socketException)
              {
                  readCount = -1;

                  WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                  if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
                  {
                      return result;
                  }

                  // Surface the underlying socket error itself, keeping its original stack trace.
                  ExceptionDispatchInfo.Capture(socketException).Throw();
                  throw; // Unreachable: Throw() never returns.
              }
              finally
              {
                  EndSslReadOperation();
              }

              return ResultCode.Success;
          }

          /// <summary>Encrypts and writes <paramref name="buffer"/>.</summary>
          /// <param name="writtenCount">
          /// 0 on WouldBlock, -1 on a translated socket error, otherwise the full buffer length
          /// (.NET does not report partial writes).
          /// </param>
          /// <returns>
          /// <c>WouldBlock</c> if the socket is not writable; a translated code for known socket errors;
          /// otherwise <see cref="ResultCode.Success"/>.
          /// </returns>
          /// <exception cref="SocketException">An untranslatable socket error occurred (original stack trace preserved).</exception>
          /// <exception cref="IOException">An I/O error not caused by a socket error occurred.</exception>
          public ResultCode Write(out int writtenCount, ReadOnlyMemory<byte> buffer)
          {
              if (!Socket.Poll(0, SelectMode.SelectWrite))
              {
                  writtenCount = 0;

                  return ResultCode.WouldBlock;
              }

              StartSslOperation();

              try
              {
                  _stream.Write(buffer.Span);
              }
              catch (IOException exception) when (exception.InnerException is SocketException socketException)
              {
                  writtenCount = -1;

                  WsaError socketErrorCode = (WsaError)socketException.SocketErrorCode;

                  if (TryTranslateWinSockError(_isBlockingSocket, socketErrorCode, out ResultCode result))
                  {
                      return result;
                  }

                  // Surface the underlying socket error itself, keeping its original stack trace.
                  ExceptionDispatchInfo.Capture(socketException).Throw();
                  throw; // Unreachable: Throw() never returns.
              }
              finally
              {
                  EndSslOperation();
              }

              // .NET API doesn't provide the size written, assume all written.
              writtenCount = buffer.Length;

              return ResultCode.Success;
          }

          /// <summary>
          /// Copies the server's certificate (raw data of the remote certificate only, so the count is
          /// always 1) into <paramref name="certificates"/>. <paramref name="hostname"/> is not used.
          /// </summary>
          /// <param name="storageSize">Required size in bytes; set even when the buffer is too small.</param>
          /// <returns><c>CertBufferTooSmall</c> if <paramref name="certificates"/> cannot hold the data, else Success.</returns>
          public ResultCode GetServerCertificate(string hostname, Span<byte> certificates, out uint storageSize, out uint certificateCount)
          {
              byte[] rawCertData = _stream.RemoteCertificate.GetRawCertData();

              storageSize = (uint)rawCertData.Length;
              certificateCount = 1;

              if (rawCertData.Length > certificates.Length)
              {
                  return ResultCode.CertBufferTooSmall;
              }

              rawCertData.CopyTo(certificates);

              return ResultCode.Success;
          }

          /// <summary>
          /// Releases the TLS stream (if a handshake created one) and closes the socket's file descriptor.
          /// </summary>
          public void Dispose()
          {
              try
              {
                  // The NetworkStream was created with ownsSocket: false, so this frees only TLS/stream
                  // resources; the socket itself is closed below through the BSD context.
                  _stream?.Dispose();
              }
              finally
              {
                  _bsdContext.CloseFileDescriptor(SocketFd);
              }
          }
      }
  }