You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

671 lines
21 KiB

3 years ago
3 years ago
3 years ago
  1. using Apewer.Network;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Net;
  6. using System.Net.Security;
  7. using System.Net.Sockets;
  8. using System.Security.Authentication;
  9. using System.Security.Cryptography.X509Certificates;
  10. using System.Text;
  11. using System.Threading;
  namespace Apewer.Web
  {

      /// <summary>HTTP connection served over one accepted socket, optionally wrapped in TLS.</summary>
      public sealed class MiniConnection
      {

          // Timeouts, in milliseconds.
          const int RequestHeadTimeout = 15 * 1000;
          const int RequestBodyTimeout = 3600 * 1000;
          const int ResponseHeadTimeout = 15 * 1000;
          const int ResponseBodyTimeout = 3600 * 1000;
          const int StreamReadTimeout = 3 * 1000;
          const int StreamWriteTimeout = 3 * 1000;

          // Buffer sizes, in bytes.
          const int BufferSize = 8192;
          const int HeadMax = 32768;

          // Completion callback shared by every asynchronous request-head read.
          static readonly AsyncCallback HeadReadCallback = ar => ((MiniConnection)ar.AsyncState).ProcessHead(ar);

          // State kept for the lifetime of the connection.
          Timer _timer;
          int _reuses = 0;

          // Per-request state, reset by BeginRead.
          MiniContext _context;
          ArrayBuilder<byte> _head = null;
          byte[] _buffer;
          MiniReader _instream = null;
          MiniWriter _outstream = null;
          bool _headsent = false;
          bool _bodysent = false;
          bool _forceClose = false;
          bool _chunked = false;
          int _headread = 0; // first offset in _head not yet examined for the CRLFCRLF terminator

          /// <summary>Wraps an accepted socket and, when the server has a certificate, performs the TLS handshake.</summary>
          /// <param name="server">Server that accepted the socket.</param>
          /// <param name="socket">Accepted socket.</param>
          /// <remarks>Throws when the TLS handshake fails; the network stream is disposed in that case.</remarks>
          internal MiniConnection(MiniServer server, Socket socket)
          {
              _server = server;
              _socket = socket;

              // The stream does not own the socket (ownsSocket: false); CloseSocket closes it explicitly.
              _netstream = new NetworkStream(socket, false);
              _netstream.ReadTimeout = StreamReadTimeout;
              _netstream.WriteTimeout = StreamWriteTimeout;
              _local = socket.LocalEndPoint as IPEndPoint;
              _remote = socket.RemoteEndPoint as IPEndPoint;
              _stream = _netstream;

              // HTTPS: layer TLS over the network stream.
              var certificate = server.SslCertificate;
              if (certificate != null)
              {
                  _sslstream = CreateSslStream(_netstream, certificate, server.SslProtocols);
                  _stream = _sslstream;
              }
          }

          /// <summary>
          /// Arms the connection watchdog: when <paramref name="duration"/> milliseconds elapse, the socket is closed.
          /// A non-positive duration (the default) disarms and disposes the watchdog.
          /// </summary>
          void Timeout(int duration = System.Threading.Timeout.Infinite)
          {
              const int Infinite = System.Threading.Timeout.Infinite;
              if (duration > 0)
              {
                  // Pass the duration as the due time so the very first call actually arms the timer.
                  if (_timer == null) _timer = new Timer(OnTimeout, null, duration, Infinite);
                  else _timer.Change(duration, Infinite);
              }
              else if (_timer != null)
              {
                  _timer.Change(Infinite, Infinite);
                  _timer.Dispose();
                  _timer = null;
              }
          }

          // Watchdog callback. It runs on a thread-pool thread, where an unhandled exception would
          // terminate the process, so it must never throw.
          void OnTimeout(object state)
          {
              try { CloseSocket(); } catch { }
          }

          #region socket

          MiniServer _server = null;
          Socket _socket = null;
          Stream _stream;
          NetworkStream _netstream;
          SslStream _sslstream;
          IPEndPoint _local;
          IPEndPoint _remote;

          /// <summary>Server that accepted this connection.</summary>
          public MiniServer Server { get => _server; }

          /// <summary>Local network endpoint.</summary>
          public IPEndPoint LocalEndPoint { get => _local; }

          /// <summary>Remote network endpoint.</summary>
          public IPEndPoint RemoteEndPoint { get => _remote; }

          /// <summary>Stream carrying the HTTP traffic: the TLS stream when HTTPS is enabled, else the network stream.</summary>
          internal Stream Stream { get => _stream; }

          // Tears down the TLS stream, the network stream and the socket; every step is best-effort.
          void CloseSocket()
          {
              Timeout();
              if (_sslstream != null)
              {
                  try { _sslstream.Close(); } catch { }
                  try { _sslstream.Dispose(); } catch { }
              }
              if (_netstream != null)
              {
                  try { _netstream.Close(); } catch { }
                  try { _netstream.Dispose(); } catch { }
              }
              if (_socket != null)
              {
                  try { _socket.LingerState = new LingerOption(false, 0); } catch { }
                  try { _socket.Shutdown(SocketShutdown.Both); } catch { }
                  try { _socket.Disconnect(false); } catch { }
                  try { _socket.Close(1); } catch { }
                  _socket = null;
              }
          }

          /// <summary>Finishes the current response, then closes the socket or starts reading the next request.</summary>
          /// <param name="force">When true, the socket is always closed.</param>
          public void Close(bool force = false)
          {
              if (_socket == null) return;

              // Remember which request is being finished. _outstream.Close() is defined outside this file
              // and may call back into this connection; if it already completed the request (closed the
              // socket or started the next request), stop here instead of finishing it a second time.
              var context = _context;

              SendHead();
              if (!_bodysent)
              {
                  _bodysent = true;
                  _outstream?.Close();
              }
              if (_socket == null || !ReferenceEquals(context, _context)) return;

              var close = force || _forceClose;
              if (!close && !_context.Request.KeepAlive) close = true;
              if (!close && TextUtility.Lower(_context.Response.Headers.GetValue("connection", true)) == "close") close = true;
              if (close)
              {
                  CloseSocket();
                  return;
              }

              _reuses++;
              BeginRead();
          }

          #endregion

          #region request

          // Resets the per-request state and starts reading the next request head.
          internal void BeginRead()
          {
              _context = new MiniContext(this);
              _head = new ArrayBuilder<byte>(BufferSize);
              if (_buffer == null) _buffer = new byte[BufferSize];
              _instream = null;
              _outstream = null;
              _headsent = false;
              _bodysent = false;
              _forceClose = false;
              _chunked = _server.Chunked;
              _headread = 0;

              try
              {
                  Timeout(RequestHeadTimeout);
                  if (_server.SynchronousIO)
                  {
                      if (SyncRead()) ProcessHead();
                  }
                  else
                  {
                      _stream.BeginRead(_buffer, 0, BufferSize, HeadReadCallback, this);
                  }
              }
              catch
              {
                  Timeout();
                  CloseSocket();
              }
          }

          // Blocking read of the request head. Returns true when the head was parsed; false when the
          // request was rejected (head too large) and the connection has already been closed.
          // Throws when the peer closes the connection before the head is complete.
          bool SyncRead()
          {
              while (true)
              {
                  var nread = _stream.Read(_buffer, 0, BufferSize);
                  if (nread < 1) throw new InvalidDataException("未接收到完整的请求头。");
                  _head.Add(_buffer, 0, nread);

                  if (ReadHead()) return true;

                  if (_head.Count > HeadMax)
                  {
                      Send(400);
                      Close(true);
                      return false;
                  }
              }
          }

          // Scans the buffered bytes for the end of the request head (CRLFCRLF).
          // Returns true once the head has been found and parsed into _context.Request, false if more data is needed.
          // Throws (after answering 400) when the request line is malformed.
          bool ReadHead()
          {
              var count = _head.Count;
              if (count < 5) return false;

              var array = _head.Origin;
              var end = count - 3;
              for (var i = _headread; i < end; i++)
              {
                  if (array[i] != 13 || array[i + 1] != 10 || array[i + 2] != 13 || array[i + 3] != 10) continue;

                  ParseHead(Encoding.ASCII.GetString(array, 0, i));
                  AttachBufferedBody(array, i + 4, count - i - 4);
                  return true;
              }

              // Every start offset below `end` has been examined; resume from there on the next read so a
              // terminator split across two reads is still found without rescanning the whole head.
              _headread = end;
              return false;
          }

          // Parses the request line and the header fields from the head text (terminator excluded).
          void ParseHead(string text)
          {
              var first = true;
              foreach (var line in text.Split('\r', '\n'))
              {
                  if (line.IsEmpty()) continue;
                  if (first)
                  {
                      first = false;
                      ParseRequestLine(line);
                      continue;
                  }

                  // "Name: value"; lines without a name before the colon are ignored.
                  var colon = line.IndexOf(':');
                  if (colon < 1) continue;
                  var name = TextUtility.Trim(line.Substring(0, colon));
                  var value = TextUtility.Trim(line.Substring(colon + 1));
                  if (!string.IsNullOrEmpty(name) && !string.IsNullOrEmpty(value))
                  {
                      _context.Request.Headers.Add(name, value);
                  }
              }
          }

          // Parses "METHOD target VERSION", e.g. "GET /index.html HTTP/1.1"; answers 400 and throws when malformed.
          void ParseRequestLine(string line)
          {
              var segs = new List<string>(3);
              foreach (var seg in line.Split(' '))
              {
                  if (!string.IsNullOrEmpty(seg)) segs.Add(seg);
              }
              if (segs.Count != 3)
              {
                  Send(400);
                  Close(true);
                  throw new InvalidDataException("请求的报文无效。");
              }

              // Use the filtered segments: runs of spaces produce empty entries in the raw split.
              _context.Request.Method = segs[0];
              _context.Request.Path = segs[1];
              _context.Request.Version = segs[2];
          }

          // Creates the request reader and hands it any body bytes that arrived together with the head,
          // limited to the declared Content-Length so bytes beyond the body are not exposed as body data.
          void AttachBufferedBody(byte[] array, int offset, int remains)
          {
              var reader = GetRequestStream();
              if (remains < 1) return;

              var contentLength = _context.Request.ContentLength;
              if (contentLength <= 0) return;

              var take = (int)Math.Min((long)remains, contentLength);
              var body = new byte[take];
              Buffer.BlockCopy(array, offset, body, 0, take);
              reader.RemainsBytes = body;
          }

          // Completion of an asynchronous head read: accumulates bytes until the head is complete,
          // then processes the request.
          void ProcessHead(IAsyncResult ar)
          {
              if (_socket == null) return;
              try
              {
                  int nread;
                  try
                  {
                      nread = _stream.EndRead(ar);
                  }
                  catch
                  {
                      // Read failed: answer 400 and close the socket.
                      Send(400);
                      Close(true);
                      return;
                  }

                  // Zero bytes: the peer closed the connection before a complete head arrived.
                  if (nread < 1)
                  {
                      CloseSocket();
                      return;
                  }

                  _head.Add(_buffer, 0, nread);
                  if (ReadHead())
                  {
                      ProcessHead();
                      return;
                  }

                  // Head too large.
                  if (_head.Count > HeadMax)
                  {
                      Send(400);
                      Close(true);
                      return;
                  }

                  // Head still incomplete: keep reading.
                  _stream.BeginRead(_buffer, 0, BufferSize, HeadReadCallback, this);
              }
              catch
              {
                  // This runs on an I/O completion thread; never let an exception escape from here.
                  Timeout();
                  CloseSocket();
              }
          }

          // Validates the parsed request head, derives connection options, then invokes the server handler.
          void ProcessHead()
          {
              // No watchdog while the request is being processed.
              Timeout();

              if (_head.Count < 1)
              {
                  // Nothing was received, so there is no request to answer.
                  CloseSocket();
                  return;
              }

              // HTTP method.
              var method = ParseMethod(_context.Request.Method);
              if (method == HttpMethod.NULL)
              {
                  Send(405);
                  Close(true);
                  return;
              }

              // Protocol version.
              switch (TextUtility.Upper(_context.Request.Version))
              {
                  case "HTTP/1":
                  case "HTTP/1.0":
                      _context.Request.Http11 = false;
                      break;
                  case "HTTP/1.1":
                      _context.Request.Http11 = true;
                      break;
                  default:
                      Send(505);
                      Close(true);
                      return;
              }

              // Keep-alive is only requested by HTTP/1.1 clients sending an explicit "Connection: keep-alive".
              _context.Request.KeepAlive = _context.Request.Http11 && TextUtility.Lower(_context.Request.Headers.GetValue("connection", true)) == "keep-alive";

              // Compression.
              if (_server.Compression)
              {
                  ParseAcceptEncoding(_context.Request.Headers.GetValue("accept-encoding", true), out var gzip, out var brotli);
                  if (gzip) _context.Request.Gzip = true;
                  if (brotli) _context.Request.Brotli = true;
              }

              // URL.
              var url = BuildRequestUrl();
              if (url == null)
              {
                  Send(400);
                  Close(true);
                  return;
              }
              _context.Request.Url = url;

              // Handler.
              var handler = _server.Handler;
              if (handler == null)
              {
                  Send(501);
                  Close(true);
                  return;
              }

              handler.Invoke(_context);
              SendHead();
              SendBody();
          }

          // Builds the absolute request URL from the request target and the Host header, falling back to
          // the local endpoint. Returns null when no valid URL can be formed.
          Uri BuildRequestUrl()
          {
              var path = _context.Request.Path;
              if (string.IsNullOrEmpty(path)) return null;

              // Absolute-form target ("GET http://host/path HTTP/1.1") is already a URL.
              if (path.StartsWith("http://", StringComparison.OrdinalIgnoreCase) || path.StartsWith("https://", StringComparison.OrdinalIgnoreCase))
              {
                  return Uri.TryCreate(path, UriKind.Absolute, out var absolute) ? absolute : null;
              }

              var host = _context.Request.Headers.GetValue("host", true);
              if (string.IsNullOrEmpty(host))
              {
                  var local = LocalEndPoint;
                  if (local != null)
                  {
                      var address = local.Address.ToString();
                      // IPv6 literals must be bracketed inside a URL.
                      if (local.AddressFamily == AddressFamily.InterNetworkV6) address = "[" + address + "]";
                      host = address + ":" + local.Port;
                  }
              }

              var scheme = _sslstream != null ? "https" : "http";
              return Uri.TryCreate($"{scheme}://{host}{path}", UriKind.Absolute, out var url) ? url : null;
          }

          /// <summary>Sends an interim "100 Continue" response so the client proceeds with the request body.</summary>
          /// <remarks>Does nothing for HTTP/1.0 clients (RFC 7231 forbids 1xx for them) or once the final head was sent.</remarks>
          internal void SendContinue()
          {
              if (_socket == null || _headsent) return;
              if (!_context.Request.Http11) return;

              Timeout(ResponseHeadTimeout);
              var bytes = Encoding.ASCII.GetBytes("HTTP/1.1 100 Continue\r\n\r\n");
              // Write through _stream so the response is encrypted when TLS is active.
              _stream.Write(bytes, 0, bytes.Length);

              // The client sends the body next; give it the request-body timeout.
              Timeout(RequestBodyTimeout);
          }

          /// <summary>Returns the request body reader, creating it (and arming the body timeout) on first use.</summary>
          internal MiniReader GetRequestStream()
          {
              if (_instream != null) return _instream;
              var length = GetContentLength(_context.Request.Headers);
              _instream = new MiniReader(this, length);
              Timeout(RequestBodyTimeout);
              return _instream;
          }

          #endregion

          #region response

          /// <summary>Whether the response body is sent with chunked transfer encoding (HTTP/1.1 requests only).</summary>
          public bool Chunked { get => _chunked && _context.Request.Http11; }

          /// <summary>Writes the status line and response headers; only the first call has an effect.</summary>
          internal void SendHead()
          {
              if (_headsent) return;
              _headsent = true;

              _instream?.Close();

              var lines = new ArrayBuilder<string>(32);
              var response = _context.Response;
              var headers = response.Headers;

              // Redirect target; empty means no redirect.
              var location = response.Location;
              if (string.IsNullOrEmpty(location)) location = null;

              // The response uses the same protocol version as the request.
              var isHttp11 = _context.Request.Http11;
              var version = isHttp11 ? "1.1" : "1.0";

              // Status line. A redirect keeps an explicit 3xx status (301, 303, 307, ...) and otherwise becomes 302.
              var status = response.Status;
              if (status == 0) status = 200;
              if (location != null && (status < 300 || status > 399)) status = 302;
              var statusDesc = NetworkUtility.HttpStatusDescription(status);
              if (statusDesc.IsEmpty()) statusDesc = "OK";
              lines.Add($"HTTP/{version} {status} {statusDesc}");

              // Connection management.
              // NOTE(review): forceClose starts as true and nothing below resets it, so every response closes
              // the connection and the "keep-alive" branch is unreachable. Before enabling keep-alive (starting
              // from false for HTTP/1.1), confirm that MiniReader drains unread request bodies and that every
              // response is framed by Content-Length or chunked encoding.
              var forceClose = true;
              var keepAlive = false;
              if (isHttp11)
              {
                  if (status == 400 || status == 408 || status == 411 || status == 413 || status == 414 || status == 500 || status == 503) forceClose = true;
                  if (_reuses > 128) forceClose = true;

                  // Request side: the client asked for keep-alive.
                  keepAlive = _context.Request.KeepAlive;
                  if (!keepAlive) forceClose = true;

                  // Response side: Response.KeepAlive must allow it.
                  if (!response.KeepAlive) forceClose = true;

                  if (keepAlive)
                  {
                      if (forceClose)
                      {
                          lines.Add("Connection: close");
                      }
                      else
                      {
                          lines.Add("Connection: keep-alive");
                          lines.Add($"Keep-Alive: timeout=15, max={128 - _reuses}");
                      }
                  }
                  if (Chunked) lines.Add("Transfer-Encoding: chunked");
              }

              // Redirect.
              if (location != null) lines.Add($"Location: {location}");

              // Content-Type: the dedicated property takes precedence over a custom header.
              var contentType = response.ContentType;
              if (string.IsNullOrEmpty(contentType)) contentType = headers.GetValue("content-type", true);
              if (!string.IsNullOrEmpty(contentType)) lines.Add("Content-Type:" + contentType);

              // Content-Length, when known.
              var contentLength = response.ContentLength;
              if (contentLength >= 0) lines.Add("Content-length:" + contentLength);

              // Custom headers, skipping those already written above.
              foreach (var header in headers)
              {
                  var key = header.Key;
                  if (key.IsEmpty()) continue;
                  var lower = key.Lower();
                  if (lower == "content-length" || lower == "content-type") continue;
                  if (keepAlive && (lower == "connection" || lower == "keep-alive")) continue;
                  if (location != null && lower == "location") continue;
                  var value = header.Value;
                  if (value.IsEmpty()) continue;
                  lines.Add(key + ":" + value);
              }

              // Write through _stream so the head is encrypted when TLS is active.
              Timeout(ResponseHeadTimeout);
              var text = string.Join("\r\n", lines.Export()) + "\r\n\r\n";
              var bytes = TextUtility.Bytes(text);
              if (_socket != null) _stream.Write(bytes, 0, bytes.Length);
              _forceClose = forceClose;
          }

          /// <summary>Completes the response body and finishes the request; only the first call has an effect.</summary>
          internal void SendBody()
          {
              if (_bodysent) return;
              _bodysent = true;
              _outstream?.Close();
              Close(_forceClose);
          }

          /// <summary>Returns the response body writer, sending the head and arming the body timeout on first use.</summary>
          internal MiniWriter GetResponseStream()
          {
              if (_outstream != null) return _outstream;
              SendHead();
              _outstream = new MiniWriter(this);
              Timeout(ResponseBodyTimeout);
              return _outstream;
          }

          /// <summary>Sets the response status code and closes the response.</summary>
          internal void Send(int status)
          {
              var response = _context.Response;
              response.Status = status;
              response.Close();
          }

          #endregion

          #region static

          // Maps a request-line method token to HttpMethod; unknown or empty tokens yield HttpMethod.NULL.
          static HttpMethod ParseMethod(string method)
          {
              if (string.IsNullOrEmpty(method)) return HttpMethod.NULL;
              switch (TextUtility.Upper(method))
              {
                  case "OPTIONS": return HttpMethod.OPTIONS;
                  case "POST": return HttpMethod.POST;
                  case "GET": return HttpMethod.GET;
                  case "CONNECT": return HttpMethod.CONNECT;
                  case "DELETE": return HttpMethod.DELETE;
                  case "HEAD": return HttpMethod.HEAD;
                  case "PATCH": return HttpMethod.PATCH;
                  case "PUT": return HttpMethod.PUT;
                  case "TRACE": return HttpMethod.TRACE;
                  default: return HttpMethod.NULL;
              }
          }

          // Parses an Accept-Encoding value such as "gzip, deflate, br;q=0.8" and reports whether gzip and
          // Brotli are acceptable. Codings carrying q=0 are treated as refused.
          static void ParseAcceptEncoding(string value, out bool gzip, out bool brotli)
          {
              gzip = false;
              brotli = false;
              if (string.IsNullOrEmpty(value)) return;

              foreach (var item in value.Split(','))
              {
                  var parts = item.Split(';');
                  var coding = parts[0].Trim().ToLowerInvariant();
                  if (coding.Length == 0 || !HasPositiveQuality(parts)) continue;
                  switch (coding)
                  {
                      case "gzip":
                      case "x-gzip":
                          gzip = true;
                          break;
                      case "br":
                      case "brotli":
                          brotli = true;
                          break;
                  }
              }
          }

          // True unless one of the parameters (parts[1..]) is a q-value equal to zero.
          static bool HasPositiveQuality(string[] parts)
          {
              for (var i = 1; i < parts.Length; i++)
              {
                  var parameter = parts[i].Trim();
                  if (parameter.Length < 2 || (parameter[0] != 'q' && parameter[0] != 'Q') || parameter[1] != '=') continue;
                  double quality;
                  if (double.TryParse(parameter.Substring(2), System.Globalization.NumberStyles.Float, System.Globalization.CultureInfo.InvariantCulture, out quality) && quality <= 0) return false;
              }
              return true;
          }

          // Wraps the stream in TLS and performs the server handshake. A client certificate is requested,
          // but any certificate (or none) is accepted. On failure the TLS stream, and with it the inner
          // stream, is disposed before the exception propagates.
          static SslStream CreateSslStream(Stream stream, X509Certificate certificate, SslProtocols protocols)
          {
              var sslStream = new SslStream(stream, false, ApproveAll);
              try
              {
                  sslStream.AuthenticateAsServer(certificate, true, protocols, false);
              }
              catch
              {
                  try { sslStream.Dispose(); } catch { }
                  throw;
              }
              return sslStream;
          }

          /// <summary>Certificate validation callback that accepts every certificate and ignores all errors.</summary>
          static bool ApproveAll(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors errors)
          {
              return true;
          }

          /// <summary>Certificate selection callback that returns the first non-null local certificate, or null.</summary>
          static X509Certificate ApproveFirst(object sender, string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
          {
              if (localCertificates != null)
              {
                  for (var i = 0; i < localCertificates.Count; i++)
                  {
                      var certificate = localCertificates[i];
                      if (certificate != null) return certificate;
                  }
              }
              return null;
          }

          // Returns the Content-Length header as a number, or -1 when it is absent or not a canonical integer.
          static long GetContentLength(StringPairs headers)
          {
              if (headers != null)
              {
                  var value = headers.GetValue("Content-Length", true);
                  if (!string.IsNullOrEmpty(value))
                  {
                      var num = value.Int64();
                      if (num.ToString() == value) return num;
                  }
              }
              return -1L;
          }

          #endregion

      }

  }