Add time and size limit to remote requests

This commit is contained in:
Junegunn Choi 2022-12-22 20:44:49 +09:00
parent fd1f7665a7
commit 1a9761736e
No known key found for this signature in database
GPG Key ID: 254BC280FEF9C627

View File

@ -8,12 +8,15 @@ import (
"net" "net"
"strconv" "strconv"
"strings" "strings"
"time"
) )
const ( const (
crlf = "\r\n" crlf = "\r\n"
httpOk = "HTTP/1.1 200 OK" + crlf httpOk = "HTTP/1.1 200 OK" + crlf
httpBadRequest = "HTTP/1.1 400 Bad Request" + crlf httpBadRequest = "HTTP/1.1 400 Bad Request" + crlf
httpReadTimeout = 10 * time.Second
maxContentLength = 1024 * 1024
) )
func startHttpServer(port int, channel chan []*action) error { func startHttpServer(port int, channel chan []*action) error {
@ -52,14 +55,13 @@ func startHttpServer(port int, channel chan []*action) error {
// * --listen with net/http: 5.7MB // * --listen with net/http: 5.7MB
// * --listen w/o net/http: 3.3MB // * --listen w/o net/http: 3.3MB
func handleHttpRequest(conn net.Conn, channel chan []*action) string { func handleHttpRequest(conn net.Conn, channel chan []*action) string {
line := 0
headerRead := false
contentLength := 0 contentLength := 0
body := "" body := ""
bad := func(message string) string { bad := func(message string) string {
message += "\n" message += "\n"
return httpBadRequest + fmt.Sprintf("Content-Length: %d%s", len(message), crlf+crlf+message) return httpBadRequest + fmt.Sprintf("Content-Length: %d%s", len(message), crlf+crlf+message)
} }
conn.SetReadDeadline(time.Now().Add(httpReadTimeout))
scanner := bufio.NewScanner(conn) scanner := bufio.NewScanner(conn)
scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) { scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
found := bytes.Index(data, []byte(crlf)) found := bytes.Index(data, []byte(crlf))
@ -73,31 +75,41 @@ func handleHttpRequest(conn net.Conn, channel chan []*action) string {
return 0, nil, nil return 0, nil, nil
}) })
section := 0
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
if line == 0 && !strings.HasPrefix(text, "POST / HTTP") { switch section {
case 0:
if !strings.HasPrefix(text, "POST / HTTP") {
return bad("invalid request method") return bad("invalid request method")
} }
section++
case 1:
if text == crlf { if text == crlf {
headerRead = true if contentLength == 0 {
return bad("content-length header missing")
}
section++
continue
} }
if !headerRead {
pair := strings.SplitN(text, ":", 2) pair := strings.SplitN(text, ":", 2)
if len(pair) == 2 && strings.ToLower(pair[0]) == "content-length" { if len(pair) == 2 && strings.ToLower(pair[0]) == "content-length" {
length, err := strconv.Atoi(strings.TrimSpace(pair[1])) length, err := strconv.Atoi(strings.TrimSpace(pair[1]))
if err != nil { if err != nil || length <= 0 || length > maxContentLength {
return bad("invalid content length") return bad("invalid content length")
} }
contentLength = length contentLength = length
} }
} else if contentLength <= 0 { case 2:
break
} else {
body += text body += text
} }
line++
} }
if len(body) < contentLength {
return bad("incomplete request")
}
body = body[:contentLength]
errorMessage := "" errorMessage := ""
actions := parseSingleActionList(strings.Trim(string(body), "\r\n"), func(message string) { actions := parseSingleActionList(strings.Trim(string(body), "\r\n"), func(message string) {
errorMessage = message errorMessage = message