package middleware

import (
	"bytes"
	"encoding/json"
	"html"
	"io"
	"net/http"
	"reflect"
	"strconv"
	"strings"
)

// SanitizeMiddleware HTML-escapes every string value in JSON request bodies
// before handing the request to next.
//
// Only POST, PUT and PATCH requests whose Content-Type contains
// "application/json" (compared case-insensitively) are inspected. Strings are
// escaped wherever they appear: as object values, as array elements (at any
// nesting depth) and as a bare top-level JSON string. Object keys are NOT
// escaped. Bodies that are not exactly one valid JSON value are passed through
// unchanged. Numbers are preserved verbatim (no float64 round-trip), so large
// integer IDs survive re-encoding.
//
// Note that this escapes on input: downstream handlers receive HTML-escaped
// text (e.g. "&lt;" instead of "<"), and text the client already escaped is
// escaped a second time.
func SanitizeMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if isSanitizableJSONRequest(r) {
			// NOTE(review): the whole body is buffered in memory; consider
			// bounding it with http.MaxBytesReader upstream.
			bodyBytes, err := io.ReadAll(r.Body)
			if err != nil {
				http.Error(w, "Unable to read request body", http.StatusBadRequest)
				return
			}
			// Restore the original bytes so next can still read the body when
			// it cannot be sanitized (e.g. invalid JSON).
			r.Body = io.NopCloser(bytes.NewReader(bodyBytes))

			if cleaned, ok := sanitizeJSONBody(bodyBytes); ok {
				r.Body = io.NopCloser(bytes.NewReader(cleaned))
				r.ContentLength = int64(len(cleaned))
				// Keep an explicit Content-Length header consistent with the
				// rewritten body for handlers that read the header directly.
				if r.Header.Get("Content-Length") != "" {
					r.Header.Set("Content-Length", strconv.Itoa(len(cleaned)))
				}
			}
		}
		next.ServeHTTP(w, r)
	})
}

// isSanitizableJSONRequest reports whether r is a body-carrying method
// (POST, PUT, PATCH) with a JSON Content-Type.
func isSanitizableJSONRequest(r *http.Request) bool {
	switch r.Method {
	case http.MethodPost, http.MethodPut, http.MethodPatch:
		// Media types are case-insensitive (RFC 9110), so normalise first.
		contentType := strings.ToLower(r.Header.Get("Content-Type"))
		return strings.Contains(contentType, "application/json")
	default:
		return false
	}
}

// sanitizeJSONBody decodes body as a single JSON value, HTML-escapes every
// string in it and re-encodes it. ok is false when body is not exactly one
// valid JSON value (or cannot be re-encoded); the caller then keeps the
// original body.
func sanitizeJSONBody(body []byte) (cleaned []byte, ok bool) {
	dec := json.NewDecoder(bytes.NewReader(body))
	// Keep numbers as json.Number so re-encoding reproduces them exactly
	// instead of rounding large integers through float64.
	dec.UseNumber()

	var data interface{}
	if err := dec.Decode(&data); err != nil {
		return nil, false
	}
	// Reject trailing content, matching json.Unmarshal's strictness.
	if _, err := dec.Token(); err != io.EOF {
		return nil, false
	}

	// A top-level string is held by value and cannot be escaped in place.
	if s, isString := data.(string); isString {
		data = html.EscapeString(s)
	} else {
		sanitize(data)
	}

	out, err := json.Marshal(data)
	if err != nil {
		return nil, false
	}
	return out, true
}

// sanitize recursively HTML-escapes, in place, every string stored as a map
// value or slice element reachable from data. A pointer passed as data is
// dereferenced once. Values held in interfaces are unwrapped, and each
// escaped string keeps its concrete type (plain string, json.Number or any
// other named string type). Map keys, struct fields and arrays are not
// touched, and a bare string passed as data cannot be modified because it is
// passed by value.
func sanitize(data interface{}) {
	val := reflect.ValueOf(data)
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
	}

	switch val.Kind() {
	case reflect.Map:
		// MapKeys returns a snapshot, so updating entries while looping is safe.
		for _, key := range val.MapKeys() {
			v := val.MapIndex(key)
			if v.Kind() == reflect.Interface {
				v = v.Elem()
			}
			switch v.Kind() {
			case reflect.String:
				val.SetMapIndex(key, escapeStringValue(v))
			case reflect.Map, reflect.Slice:
				sanitize(v.Interface())
			}
		}
	case reflect.Slice:
		for i := 0; i < val.Len(); i++ {
			// Slice elements are addressable, so they can be replaced in place;
			// the change is visible through the caller's slice header.
			elem := val.Index(i)
			v := elem
			if v.Kind() == reflect.Interface {
				v = v.Elem()
			}
			switch v.Kind() {
			case reflect.String:
				if elem.CanSet() {
					elem.Set(escapeStringValue(v))
				}
			case reflect.Map, reflect.Slice:
				sanitize(v.Interface())
			}
		}
	}
}

// escapeStringValue returns an HTML-escaped copy of s (whose Kind must be
// reflect.String), converted back to s's own type so it stays assignable to
// the map value or slice element it came from.
func escapeStringValue(s reflect.Value) reflect.Value {
	return reflect.ValueOf(html.EscapeString(s.String())).Convert(s.Type())
}