|
|
using Apewer.Network; using System; using System.Collections.Generic; using System.IO; using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Security.Authentication; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading;
namespace Apewer.Web {
/// <summary>Socket 连接。</summary>
public sealed class MiniConnection {
// Timeouts, in milliseconds. The Request*/Response* values are armed through Timeout(int); the Stream* values
// are assigned to the NetworkStream's ReadTimeout/WriteTimeout in the constructor.
const int RequestHeadTimeout = 15 * 1000;
const int RequestBodyTimeout = 3600 * 1000;
const int ResponseHeadTimeout = 15 * 1000;
const int ResponseBodyTimeout = 3600 * 1000;
const int StreamReadTimeout = 3 * 1000;
const int StreamWriteTimeout = 3 * 1000;
// Buffering, in bytes: size of the read buffer, and the request-head size above which the request is answered with 400.
const int BufferSize = 8192;
const int HeadMax = 32768;
// State kept across requests on the same connection.
Timer _timer;       // one-shot timer whose callback closes the socket (see Timeout)
int _reuses = 0;    // how many times this connection has been reused for a further request (incremented in Close)
// Per-request state, re-initialised by BeginRead() for every request.
MiniContext _context;
ArrayBuilder<byte> _head = null;   // raw request-head bytes received so far
byte[] _buffer;                    // read buffer; allocated once and then reused
MiniReader _instream = null;       // request-body reader created by GetRequestStream()
MiniWriter _outstream = null;      // response-body writer created by GetResponseStream()
bool _headsent = false;            // the response head has been sent
bool _bodysent = false;            // the response body has been completed
bool _forceClose = false;          // set by SendHead(); SendBody() passes it to Close()
bool _chunked = false;             // copied from MiniServer.Chunked
int _headread = 0;                 // scan position used by ReadHead() when looking for the end of the head
/// <summary>Wraps an accepted socket in a network stream and, when the server has a certificate, a TLS stream.</summary>
/// <param name="server">The server that accepted the socket.</param>
/// <param name="socket">The accepted client socket.</param>
internal MiniConnection(MiniServer server, Socket socket)
{
    _server = server;
    _socket = socket;

    // The network stream does not own the socket; CloseSocket() shuts the socket down separately.
    _netstream = new NetworkStream(socket, false)
    {
        ReadTimeout = StreamReadTimeout,
        WriteTimeout = StreamWriteTimeout,
    };
    _local = socket.LocalEndPoint as IPEndPoint;
    _remote = socket.RemoteEndPoint as IPEndPoint;
    _stream = _netstream;

    // Plain TCP unless a certificate is configured.
    var certificate = server.SslCertificate;
    if (certificate == null) return;

    _sslstream = CreateSslStream(_netstream, certificate, server.SslProtocols);
    _stream = _sslstream;
}
/// <summary>
/// Arms, re-arms or disarms the connection timer. When the timer fires, its callback closes the socket.
/// </summary>
/// <param name="duration">
/// Milliseconds until the socket is closed. Zero or a negative value (including the default, Infinite)
/// stops and disposes the timer.
/// </param>
void Timeout(int duration = System.Threading.Timeout.Infinite)
{
    const int Infinite = System.Threading.Timeout.Infinite;

    if (duration > 0)
    {
        var timer = _timer;
        if (timer == null)
        {
            // Arm the new timer right away. Creating it with an infinite due time meant that the first timeout
            // after every reset (the normal case, since Timeout() disposes the timer) never fired.
            _timer = new Timer(obj => CloseSocket(), null, duration, Infinite);
            return;
        }

        try { timer.Change(duration, Infinite); }
        catch (ObjectDisposedException) { /* disposed concurrently by the disarm path below; the socket is being closed */ }
        return;
    }

    // Disarm. Take the timer out of the field atomically: the timer callback (CloseSocket -> Timeout()) runs on a
    // thread-pool thread and may race with the connection's own thread.
    var current = Interlocked.Exchange(ref _timer, null);
    if (current != null)
    {
        current.Change(Infinite, Infinite);
        current.Dispose();
    }
}
#region socket
// Socket, streams and endpoints; assigned in the constructor.
MiniServer _server = null;
Socket _socket = null;      // reset to null by CloseSocket()
Stream _stream;             // _sslstream when a certificate is configured, otherwise _netstream
NetworkStream _netstream;
SslStream _sslstream;
IPEndPoint _local;
IPEndPoint _remote;

/// <summary>The server that accepted this connection.</summary>
public MiniServer Server => _server;

/// <summary>Local network endpoint.</summary>
public IPEndPoint LocalEndPoint => _local;

/// <summary>Remote network endpoint.</summary>
public IPEndPoint RemoteEndPoint => _remote;

/// <summary>The connection's data stream: the TLS stream when a certificate is configured, otherwise the network stream.</summary>
internal Stream Stream => _stream;
/// <summary>
/// Stops the timer, then closes the TLS stream, the network stream and the socket. Every step is best-effort:
/// failures are ignored so the remaining steps still run.
/// </summary>
void CloseSocket()
{
    void Quietly(Action step)
    {
        try { step(); } catch { }
    }

    Timeout();

    var ssl = _sslstream;
    if (ssl != null)
    {
        Quietly(ssl.Close);
        Quietly(ssl.Dispose);
    }

    var net = _netstream;
    if (net != null)
    {
        Quietly(net.Close);
        Quietly(net.Dispose);
    }

    var socket = _socket;
    if (socket == null) return;

    // Drop unsent data instead of lingering, then shut down and close.
    Quietly(() => socket.LingerState = new LingerOption(false, 0));
    Quietly(() => socket.Shutdown(SocketShutdown.Both));
    Quietly(() => socket.Disconnect(false));
    Quietly(() => socket.Close(1));
    _socket = null;
}
/// <summary>
/// Finishes the current response (head and body), then either closes the socket or, for a reusable
/// keep-alive connection, starts reading the next request.
/// </summary>
/// <param name="force">true to close the socket regardless of keep-alive.</param>
public void Close(bool force = false)
{
    if (_socket == null) return;

    var context = _context;
    if (context == null) { CloseSocket(); return; }

    SendHead();

    // Complete the body here instead of calling SendBody(): SendBody() calls back into Close(), and letting both
    // the nested and the outer call decide the connection's fate could start two reads for the next request
    // (or read from a socket the nested call had already closed).
    if (!_bodysent)
    {
        _bodysent = true;
        _outstream?.Close();
    }

    // A nested call (for example while closing the response writer) may already have closed the socket or
    // moved on to the next request.
    if (_socket == null || !ReferenceEquals(_context, context)) return;

    // SendHead() records whether the response requires the connection to be closed.
    if (_forceClose) force = true;
    if (!force && !context.Request.KeepAlive) force = true;
    if (!force && TextUtility.Lower(context.Response.Headers.GetValue("connection", true)) == "close") force = true;
    if (force) { CloseSocket(); return; }

    _reuses++;
    BeginRead();
}
#endregion
#region request
/// <summary>
/// Resets all per-request state and starts reading the next request head, either synchronously (then processing
/// it immediately) or asynchronously. Any failure closes the socket.
/// </summary>
internal void BeginRead()
{
    _context = new MiniContext(this);
    _head = new ArrayBuilder<byte>(BufferSize);
    _buffer = _buffer ?? new byte[BufferSize];
    _instream = null;
    _outstream = null;
    _headsent = false;
    _bodysent = false;
    _forceClose = false;
    _chunked = _server.Chunked;
    _headread = 0;

    try
    {
        Timeout(RequestHeadTimeout);

        if (!_server.SynchronousIO)
        {
            _stream.BeginRead(_buffer, 0, BufferSize, result => ((MiniConnection)result.AsyncState).ProcessHead(result), this);
            return;
        }

        SyncRead();
        ProcessHead();
    }
    catch
    {
        Timeout();
        CloseSocket();
    }
}
/// <summary>Reads synchronously until a complete request head has been received and parsed by ReadHead().</summary>
/// <exception cref="EndOfStreamException">The peer closed the connection before the head was complete.</exception>
/// <exception cref="InvalidDataException">
/// The head exceeded HeadMax; a 400 response has already been sent and the connection closed.
/// </exception>
void SyncRead()
{
    while (true)
    {
        var nread = _stream.Read(_buffer, 0, BufferSize);
        if (nread < 1) throw new EndOfStreamException("未接收到完整的请求头。");

        _head.Add(_buffer, 0, nread);
        if (ReadHead()) return;

        if (_head.Count > HeadMax)
        {
            Send(400);
            Close(true);
            // Throw instead of returning normally so the caller (BeginRead) does not go on to ProcessHead()
            // for a request that has already been rejected.
            throw new InvalidDataException("请求头过长。");
        }
    }
}
/// <summary>
/// Looks for the blank line (CR LF CR LF) that ends the request head in the bytes received so far. Once it is
/// found, parses the request line and header fields into the current request and hands any bytes that followed
/// the head to the request-body reader.
/// </summary>
/// <returns>true when the head is complete and has been parsed; false when more data is needed.</returns>
/// <exception cref="InvalidDataException">
/// The request line is malformed; a 400 response has already been sent and the connection closed.
/// </exception>
bool ReadHead()
{
    var count = _head.Count;
    if (count < 5) return false;

    var array = _head.Origin;
    var end = count - 3;

    // Positions before _headread were fully checked by an earlier call, so the scan resumes there
    // (backing up a few bytes as a safety margin) instead of rescanning the whole buffer.
    var terminator = -1;
    for (var i = Math.Max(0, _headread - 3); i < end; i++)
    {
        if (array[i] == 13 && array[i + 1] == 10 && array[i + 2] == 13 && array[i + 3] == 10)
        {
            terminator = i;
            break;
        }
    }
    if (terminator < 0)
    {
        _headread = end;
        return false;
    }

    ParseHeadText(Encoding.ASCII.GetString(array, 0, terminator));

    // Bytes after the terminator are the beginning of the request body.
    var remains = count - terminator - 4;
    var reader = GetRequestStream();
    if (remains > 0 && _context.Request.ContentLength > 0)
    {
        var body = new byte[remains];
        Buffer.BlockCopy(array, terminator + 4, body, 0, remains);
        reader.RemainsBytes = body;
    }
    return true;
}

/// <summary>Parses the request line and header fields of a complete request head (without the trailing blank line).</summary>
void ParseHeadText(string text)
{
    var request = _context.Request;
    var first = true;
    foreach (var line in text.Split('\r', '\n'))
    {
        if (line.IsEmpty()) continue;

        if (first)
        {
            first = false;

            // Request line, e.g. "GET /index.html HTTP/1.1"; runs of spaces are tolerated.
            var segs = new List<string>(3);
            foreach (var seg in line.Split(' '))
            {
                if (!string.IsNullOrEmpty(seg)) segs.Add(seg);
            }
            if (segs.Count != 3)
            {
                Send(400);
                Close(true);
                throw new InvalidDataException("请求的报文无效。");
            }

            // Take the parts from the filtered list; the raw split result contains empty entries whenever the
            // line has consecutive spaces.
            request.Method = segs[0];
            request.Path = segs[1];
            request.Version = segs[2];
            continue;
        }

        // Header field "Name: value"; fields with an empty name or value are skipped.
        var colon = line.IndexOf(':');
        if (colon < 1) continue;
        var name = TextUtility.Trim(line.Substring(0, colon));
        var value = TextUtility.Trim(line.Substring(colon + 1));
        if (!string.IsNullOrEmpty(name) && !string.IsNullOrEmpty(value)) request.Headers.Add(name, value);
    }
}
/// <summary>
/// Completion callback for an asynchronous read of the request head. Appends the received bytes, parses the head
/// with ReadHead() once it is complete and processes the request; otherwise keeps reading.
/// </summary>
/// <param name="ar">The pending read started on _stream.</param>
void ProcessHead(IAsyncResult ar)
{
    if (_socket == null) return;

    try
    {
        int nread;
        try
        {
            nread = _stream.EndRead(ar);
        }
        catch
        {
            // The read failed: answer 400 and close the socket.
            Send(400);
            Close(true);
            return;
        }

        // The peer closed the connection before a complete head arrived (the synchronous path also just closes).
        if (nread < 1) { CloseSocket(); return; }

        _head.Add(_buffer, 0, nread);

        // The head must actually be parsed here; a short read does not mean the head is complete.
        if (ReadHead()) { ProcessHead(); return; }

        if (_head.Count > HeadMax) { Send(400); Close(true); return; }

        // Head still incomplete: keep reading.
        _stream.BeginRead(_buffer, 0, BufferSize, ar2 => ((MiniConnection)ar2.AsyncState).ProcessHead(ar2), this);
    }
    catch
    {
        // An exception escaping an I/O callback on a thread-pool thread would typically terminate the process;
        // mirror the synchronous path in BeginRead, which closes the socket on any failure.
        CloseSocket();
    }
}
/// <summary>
/// Validates the parsed request head (method, HTTP version), derives keep-alive, compression and URL information,
/// then invokes the server handler and completes the response.
/// </summary>
void ProcessHead()
{
    // Disable the timer while the request is being processed.
    Timeout();

    // Nothing was received.
    if (_head.Count < 1) { Close(true); return; }

    // NOTE: "Expect: 100-continue" is not handled in this method (see SendContinue).

    var request = _context.Request;

    // HTTP method.
    if (ParseMethod(request.Method) == HttpMethod.NULL) { Send(405); Close(true); return; }

    // HTTP version.
    switch (TextUtility.Upper(request.Version))
    {
        case "HTTP/1":
        case "HTTP/1.0":
            request.Http11 = false;
            break;
        case "HTTP/1.1":
            request.Http11 = true;
            break;
        default:
            Send(505);
            Close(true);
            return;
    }

    // Keep-alive is only considered for HTTP/1.1 requests that ask for it explicitly.
    request.KeepAlive = request.Http11 && TextUtility.Lower(request.Headers.GetValue("connection", true)) == "keep-alive";

    // Compression.
    if (_server.Compression) DetectCompression(request.Headers.GetValue("accept-encoding", true));

    // URL; an unparsable Host/request target is a bad request rather than an unhandled UriFormatException.
    var url = BuildRequestUrl(request.Headers.GetValue("host", true), request.Path);
    if (url == null) { Send(400); Close(true); return; }
    request.Url = url;

    // Handler.
    var handler = _server.Handler;
    if (handler == null) { Send(501); Close(true); return; }

    // Invoke.
    Timeout();
    handler.Invoke(_context);
    SendHead();
    SendBody();
}

/// <summary>Sets Request.Gzip / Request.Brotli from the codings listed in an Accept-Encoding header value.</summary>
void DetectCompression(string headerValue)
{
    if (string.IsNullOrEmpty(headerValue)) return;

    var request = _context.Request;
    foreach (var item in headerValue.Split(','))
    {
        // Each item is "coding[;q=weight]", possibly padded with whitespace, e.g. " br;q=0.8".
        var parts = item.Split(';');
        if (IsRefusedCoding(parts)) continue;

        switch (parts[0].Trim().ToLowerInvariant())
        {
            case "gzip":
                request.Gzip = true;
                break;
            // "br" is the registered content-coding token for Brotli; "brotli" is still accepted.
            case "br":
            case "brotli":
                request.Brotli = true;
                break;
        }
    }
}

/// <summary>True when an Accept-Encoding item carries a zero quality value ("q=0"), which means "not acceptable".</summary>
static bool IsRefusedCoding(string[] parts)
{
    for (var i = 1; i < parts.Length; i++)
    {
        var parameter = parts[i].Trim();
        if (!parameter.StartsWith("q=", StringComparison.OrdinalIgnoreCase)) continue;

        var weight = parameter.Substring(2).Trim();
        return weight.Length > 0 && weight.Trim('0', '.').Length == 0;
    }
    return false;
}

/// <summary>
/// Builds the absolute request URL from the Host header (or, when it is missing, the local endpoint) and the
/// request target. Returns null when no valid absolute URI can be formed.
/// </summary>
Uri BuildRequestUrl(string host, string path)
{
    // TLS connections are https.
    var scheme = _sslstream == null ? "http" : "https";

    if (string.IsNullOrEmpty(host))
    {
        var local = LocalEndPoint;
        if (local == null) return null;

        // IPv6 literals must be bracketed inside a URI authority.
        var address = local.Address.ToString();
        if (local.AddressFamily == AddressFamily.InterNetworkV6) address = "[" + address + "]";
        host = address + ":" + local.Port;
    }

    return Uri.TryCreate($"{scheme}://{host}{path}", UriKind.Absolute, out var url) ? url : null;
}
/// <summary>Sends an interim "100 Continue" status line matching the request's HTTP version.</summary>
internal void SendContinue()
{
    Timeout(ResponseHeadTimeout);

    var text = _context.Request.Http11 ? "HTTP/1.1 100 Continue\r\n\r\n" : "HTTP/1.0 100 Continue\r\n\r\n";
    var bytes = Encoding.ASCII.GetBytes(text);

    if (_socket == null) return;

    // Write through _stream rather than the raw socket: with TLS enabled, raw socket writes bypass the encryption
    // and corrupt the TLS record stream.
    _stream.Write(bytes, 0, bytes.Length);
}
/// <summary>
/// Returns the request-body reader, creating it on first use from the request's Content-Length (-1 when the header
/// is absent or not a canonical integer) and arming the request-body timeout at that point.
/// </summary>
internal MiniReader GetRequestStream()
{
    if (_instream == null)
    {
        var length = -1L;
        var declared = _context.Request.Headers.GetValue("Content-Length", true);
        if (!string.IsNullOrEmpty(declared))
        {
            var parsed = declared.Int64();
            // Accept the value only if it round-trips through ToString() unchanged.
            if (parsed.ToString() == declared) length = parsed;
        }

        _instream = new MiniReader(this, length);
        Timeout(RequestBodyTimeout);
    }
    return _instream;
}
#endregion
#region response
/// <summary>
/// Whether the response body is chunked: chunking is enabled on the server (copied in BeginRead) and the request
/// is HTTP/1.1.
/// </summary>
public bool Chunked
{
    get
    {
        if (!_chunked) return false;
        return _context.Request.Http11;
    }
}
/// <summary>
/// Builds and sends the status line and response headers for the current request. Only the first call per request
/// has an effect. Records in _forceClose whether the connection must be closed after the response.
/// </summary>
internal void SendHead()
{
    if (_headsent) return;
    _headsent = true;

    // Close the request-body reader before responding.
    _instream?.Close();

    var request = _context.Request;
    var response = _context.Response;
    var headers = response.Headers;
    var lines = new ArrayBuilder<string>(32);

    // A non-empty Location turns the response into a 302 redirect.
    var location = response.Location;
    if (string.IsNullOrEmpty(location)) location = null;

    // Status line.
    var isHttp11 = request.Http11;
    var version = isHttp11 ? "1.1" : "1.0";
    var status = response.Status;
    if (status == 0) status = 200;
    if (location != null) status = 302;
    var statusDesc = NetworkUtility.HttpStatusDescription(status);
    if (statusDesc.IsEmpty()) statusDesc = "OK";
    lines.Add($"HTTP/{version} {status} {statusDesc}");

    // Connection management (HTTP/1.1 only).
    // NOTE(review): forceClose starts out true and is never reset to false, so every connection is closed after one
    // response and the status/reuse checks below have no effect (keep-alive is effectively disabled). Initialising it
    // to false would enable keep-alive; test that together with Close()/BeginRead() before changing it.
    var forceClose = true;
    var chunked = false;
    if (isHttp11)
    {
        if (status == 400 || status == 408 || status == 411 || status == 413 || status == 414 || status == 500 || status == 503) forceClose = true;
        if (_reuses > 128) forceClose = true;
        if (!request.KeepAlive) forceClose = true;
        if (!response.KeepAlive) forceClose = true;
        // Honour an explicit "Connection: close" set by the handler, as Close() does.
        if (TextUtility.Lower(headers.GetValue("connection", true)) == "close") forceClose = true;

        // Announce the connection's fate; an HTTP/1.1 server that is going to close must say "close".
        if (forceClose)
        {
            lines.Add("Connection: close");
        }
        else
        {
            lines.Add("Connection: keep-alive");
            lines.Add($"Keep-Alive: timeout=15, max={128 - _reuses}");
        }

        if (Chunked)
        {
            lines.Add("Transfer-Encoding: chunked");
            chunked = true;
        }
    }

    // Redirect.
    if (location != null) lines.Add($"Location: {location}");

    // Content type: the dedicated property wins over a custom header.
    var contentType = response.ContentType;
    if (string.IsNullOrEmpty(contentType)) contentType = headers.GetValue("content-type", true);
    if (!string.IsNullOrEmpty(contentType)) lines.Add("Content-Type: " + contentType);

    // Content length; a message with Transfer-Encoding must not also carry Content-Length.
    var contentLength = response.ContentLength;
    if (contentLength >= 0 && !chunked) lines.Add("Content-Length: " + contentLength);

    // Custom headers, except those already emitted above (avoids duplicates such as two Location headers).
    foreach (var header in headers)
    {
        var key = header.Key;
        if (key.IsEmpty()) continue;

        var lower = key.Lower();
        if (lower == "content-length" || lower == "content-type" || lower == "location") continue;
        if (isHttp11 && (lower == "connection" || lower == "keep-alive")) continue;
        if (chunked && lower == "transfer-encoding") continue;

        var value = header.Value;
        if (value.IsEmpty()) continue;

        lines.Add(key + ": " + value);
    }

    // Send. Record the close decision first so it is kept even if the write fails.
    Timeout(ResponseHeadTimeout);
    var bytes = TextUtility.Bytes(string.Join("\r\n", lines.Export()) + "\r\n\r\n");
    _forceClose = forceClose;

    // Write through _stream, not the raw socket, so the head is encrypted when TLS is in use.
    if (_socket != null) _stream.Write(bytes, 0, bytes.Length);
}
/// <summary>
/// Completes the response body once per request: closes the response writer if one was created, then calls Close()
/// with the close decision recorded by SendHead().
/// </summary>
internal void SendBody()
{
    if (_bodysent) return;
    _bodysent = true;

    var writer = _outstream;
    if (writer != null) writer.Close();

    Close(_forceClose);
}
/// <summary>
/// Returns the response-body writer, creating it on first use. Creation sends the response head first and then
/// arms the response-body timeout.
/// </summary>
internal MiniWriter GetResponseStream()
{
    if (_outstream == null)
    {
        // The head must be on the wire before any body bytes.
        SendHead();
        _outstream = new MiniWriter(this);
        Timeout(ResponseBodyTimeout);
    }
    return _outstream;
}
/// <summary>
/// Sends a bare status code: sets it on the response and closes the response (intended to also close the socket
/// connection; MiniResponse.Close is not visible here).
/// </summary>
/// <param name="status">HTTP status code.</param>
internal void Send(int status)
{
    _context.Response.Status = status;
    _context.Response.Close();
}
#endregion
#region static
/// <summary>Maps a request-line method token (case-insensitive) to <see cref="HttpMethod"/>.</summary>
/// <returns>The matching method, or HttpMethod.NULL for an empty or unknown token.</returns>
static HttpMethod ParseMethod(string method)
{
    if (string.IsNullOrEmpty(method)) return HttpMethod.NULL;

    // Match the whole token; a substring match would accept e.g. "PROPPATCH" as PATCH or "XGETX" as GET.
    switch (TextUtility.Upper(method))
    {
        case "OPTIONS": return HttpMethod.OPTIONS;
        case "POST": return HttpMethod.POST;
        case "GET": return HttpMethod.GET;
        case "CONNECT": return HttpMethod.CONNECT;
        case "DELETE": return HttpMethod.DELETE;
        case "HEAD": return HttpMethod.HEAD;
        case "PATCH": return HttpMethod.PATCH;
        case "PUT": return HttpMethod.PUT;
        case "TRACE": return HttpMethod.TRACE;
        default: return HttpMethod.NULL;
    }
}
/// <summary>Wraps <paramref name="stream"/> in an SslStream and performs the server-side TLS handshake.</summary>
/// <param name="stream">The underlying network stream; it is closed together with the returned SslStream.</param>
/// <param name="certificate">Server certificate.</param>
/// <param name="protocols">Allowed TLS protocol versions.</param>
/// <returns>The authenticated SslStream.</returns>
static SslStream CreateSslStream(Stream stream, X509Certificate certificate, SslProtocols protocols)
{
    // Client certificates are never validated: any certificate, or none, is accepted.
    var sslStream = new SslStream(stream, false, (sender, clientCertificate, chain, errors) => true);
    try
    {
        // NOTE(review): clientCertificateRequired is true, so browsers may prompt the user for a client certificate
        // even though it is never checked; confirm this is intended.
        sslStream.AuthenticateAsServer(certificate, true, protocols, false);
        return sslStream;
    }
    catch
    {
        // Do not leak the SslStream when the handshake fails (leaveInnerStreamOpen is false, so this also closes `stream`).
        sslStream.Dispose();
        throw;
    }
}
/// <summary>Certificate validation callback that accepts every certificate and ignores all policy errors.</summary>
static bool ApproveAll(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors errors) => true;
/// <summary>Local certificate selection callback: returns the first non-null certificate offered, or null.</summary>
static X509Certificate ApproveFirst(object sender, string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
{
    if (localCertificates == null) return null;

    foreach (X509Certificate candidate in localCertificates)
    {
        if (candidate != null) return candidate;
    }
    return null;
}
/// <summary>
/// Reads the Content-Length header. Returns -1 when the headers are null, the header is missing or empty, or the
/// value does not round-trip through ToString() unchanged.
/// </summary>
static long GetContentLength(StringPairs headers)
{
    var raw = headers?.GetValue("Content-Length", true);
    if (string.IsNullOrEmpty(raw)) return -1L;

    var number = raw.Int64();
    return number.ToString() == raw ? number : -1L;
}
#endregion
}
}
|