Przeglądaj źródła

feat: 支持嵌套结构的字段权限过滤

BaiLuoYan 1 tydzień temu
rodzic
commit
ee0808fbf2
3 zmienionych plików z 148 dodań i 45 usunięć
  1. 44 28
      fieldperm.go
  2. 96 15
      middleware.go
  3. 8 2
      permlib.go

+ 44 - 28
fieldperm.go

@@ -7,11 +7,6 @@ import (
 	"net/http"
 )
 
-type fieldPermEntry struct {
-	jsonName string
-	permCode string
-}
-
 type filterResponseWriter struct {
 	http.ResponseWriter
 	buf    bytes.Buffer
@@ -52,52 +47,73 @@ func (fw *filterResponseWriter) flush() {
 
 func (e *Engine) filterResponseBody(body []byte, req *http.Request, perms []string) []byte {
 	key := req.Method + " " + req.URL.Path
-	if fm, ok := e.fieldPerms[key]; ok && len(fm.Response) > 0 {
-		return filterResponseByMap(body, fm.Response, perms)
+	fm, ok := e.fieldPerms[key]
+	if !ok || fm.Response == nil {
+		return body
 	}
-	return body
+	return filterResponseByMap(body, fm.Response, perms)
 }
 
-func filterResponseByMap(body []byte, fieldMap map[string]string, perms []string) []byte {
+func filterResponseByMap(body []byte, node *FieldNode, perms []string) []byte {
+	if node == nil {
+		return body
+	}
 	permSet := toPermSet(perms)
 	body = bytes.TrimSpace(body)
 	if len(body) == 0 {
 		return body
 	}
 
-	entries := make([]fieldPermEntry, 0, len(fieldMap))
-	for jsonField, permCode := range fieldMap {
-		entries = append(entries, fieldPermEntry{jsonName: jsonField, permCode: permCode})
-	}
-
 	if body[0] == '[' {
-		var arr []json.RawMessage
-		if err := json.Unmarshal(body, &arr); err != nil {
-			return body
-		}
-		for i, item := range arr {
-			arr[i] = filterObject(item, entries, permSet)
-		}
-		result, _ := json.Marshal(arr)
-		return result
+		return filterArray(body, node, permSet)
 	}
-	return filterObject(body, entries, permSet)
+	return filterObject(body, node, permSet)
 }
 
-func filterObject(raw json.RawMessage, entries []fieldPermEntry, permSet map[string]bool) json.RawMessage {
+func filterObject(raw json.RawMessage, node *FieldNode, permSet map[string]bool) json.RawMessage {
+	if node == nil {
+		return raw
+	}
 	var obj map[string]json.RawMessage
 	if err := json.Unmarshal(raw, &obj); err != nil {
 		return raw
 	}
-	for _, entry := range entries {
-		if _, has := obj[entry.jsonName]; has && !permSet[entry.permCode] {
-			delete(obj, entry.jsonName)
+	for jsonName, permCode := range node.Fields {
+		if !permSet[permCode] {
+			delete(obj, jsonName)
+		}
+	}
+	for jsonName, child := range node.Nested {
+		v, ok := obj[jsonName]
+		if !ok {
+			continue
+		}
+		v = bytes.TrimSpace(v)
+		if len(v) == 0 {
+			continue
+		}
+		if v[0] == '[' {
+			obj[jsonName] = filterArray(v, child, permSet)
+		} else if v[0] == '{' {
+			obj[jsonName] = filterObject(v, child, permSet)
 		}
 	}
 	result, _ := json.Marshal(obj)
 	return result
 }
 
+func filterArray(raw json.RawMessage, node *FieldNode, permSet map[string]bool) json.RawMessage {
+	var arr []json.RawMessage
+	if err := json.Unmarshal(raw, &arr); err != nil {
+		return raw
+	}
+	for i, item := range arr {
+		arr[i] = filterObject(item, node, permSet)
+	}
+	result, _ := json.Marshal(arr)
+	return result
+}
+
 func toPermSet(perms []string) map[string]bool {
 	m := make(map[string]bool, len(perms))
 	for _, p := range perms {

+ 96 - 15
middleware.go

@@ -1,6 +1,7 @@
 package permlib
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"fmt"
@@ -89,17 +90,21 @@ func (e *Engine) resolveDataCode(req *http.Request, apiCode string) string {
 func (e *Engine) hasRespFilter(req *http.Request) bool {
 	key := req.Method + " " + req.URL.Path
 	if fm, ok := e.fieldPerms[key]; ok {
-		return len(fm.Response) > 0
+		return fm.Response != nil && (len(fm.Response.Fields) > 0 || len(fm.Response.Nested) > 0)
 	}
 	return false
 }
 
 func (e *Engine) filterRequest(req *http.Request, perms []string) (*http.Request, bool) {
 	key := req.Method + " " + req.URL.Path
-	if fm, ok := e.fieldPerms[key]; ok && len(fm.Request) > 0 {
-		return filterRequestByMap(req, fm.Request, perms, e.cfg.FieldWriteMode)
+	fm, ok := e.fieldPerms[key]
+	if !ok || fm.Request == nil {
+		return req, false
 	}
-	return req, false
+	if len(fm.Request.Fields) == 0 && len(fm.Request.Nested) == 0 {
+		return req, false
+	}
+	return filterRequestByNode(req, fm.Request, perms, e.cfg.FieldWriteMode)
 }
 
 func (e *Engine) verifyAndGetUser(ctx context.Context, token string) (*cachedUser, error) {
@@ -140,7 +145,7 @@ func containsPerm(perms []string, code string) bool {
 	return false
 }
 
-func filterRequestByMap(req *http.Request, fieldMap map[string]string, perms []string, mode FieldWriteMode) (*http.Request, bool) {
+func filterRequestByNode(req *http.Request, node *FieldNode, perms []string, mode FieldWriteMode) (*http.Request, bool) {
 	if req.Body == nil {
 		return req, false
 	}
@@ -157,18 +162,94 @@ func filterRequestByMap(req *http.Request, fieldMap map[string]string, perms []s
 	}
 
 	permSet := toPermSet(perms)
-	for jsonField, permCode := range fieldMap {
-		if _, has := obj[jsonField]; has && !permSet[permCode] {
-			if mode == FieldWriteReject {
-				restoreBody(req, body)
-				return req, true
-			}
-			delete(obj, jsonField)
+
+	if mode == FieldWriteReject {
+		if hasUnauthorizedField(obj, node, permSet) {
+			restoreBody(req, body)
+			return req, true
 		}
 	}
 
-	filtered, _ := json.Marshal(obj)
-	restoreBody(req, filtered)
-	req.ContentLength = int64(len(filtered))
+	filtered := filterRequestObject(obj, node, permSet)
+	result, _ := json.Marshal(filtered)
+	restoreBody(req, result)
+	req.ContentLength = int64(len(result))
 	return req, false
 }
+
+func hasUnauthorizedField(obj map[string]json.RawMessage, node *FieldNode, permSet map[string]bool) bool {
+	for jsonName, permCode := range node.Fields {
+		if _, has := obj[jsonName]; has && !permSet[permCode] {
+			return true
+		}
+	}
+	for jsonName, child := range node.Nested {
+		v, ok := obj[jsonName]
+		if !ok {
+			continue
+		}
+		v = bytes.TrimSpace(v)
+		if len(v) == 0 {
+			continue
+		}
+		if v[0] == '{' {
+			var nested map[string]json.RawMessage
+			if json.Unmarshal(v, &nested) == nil {
+				if hasUnauthorizedField(nested, child, permSet) {
+					return true
+				}
+			}
+		} else if v[0] == '[' {
+			var arr []json.RawMessage
+			if json.Unmarshal(v, &arr) == nil {
+				for _, item := range arr {
+					var nested map[string]json.RawMessage
+					if json.Unmarshal(item, &nested) == nil {
+						if hasUnauthorizedField(nested, child, permSet) {
+							return true
+						}
+					}
+				}
+			}
+		}
+	}
+	return false
+}
+
+func filterRequestObject(obj map[string]json.RawMessage, node *FieldNode, permSet map[string]bool) map[string]json.RawMessage {
+	for jsonName, permCode := range node.Fields {
+		if _, has := obj[jsonName]; has && !permSet[permCode] {
+			delete(obj, jsonName)
+		}
+	}
+	for jsonName, child := range node.Nested {
+		v, ok := obj[jsonName]
+		if !ok {
+			continue
+		}
+		v = bytes.TrimSpace(v)
+		if len(v) == 0 {
+			continue
+		}
+		if v[0] == '{' {
+			var nested map[string]json.RawMessage
+			if json.Unmarshal(v, &nested) == nil {
+				nested = filterRequestObject(nested, child, permSet)
+				obj[jsonName], _ = json.Marshal(nested)
+			}
+		} else if v[0] == '[' {
+			var arr []json.RawMessage
+			if json.Unmarshal(v, &arr) == nil {
+				for i, item := range arr {
+					var nested map[string]json.RawMessage
+					if json.Unmarshal(item, &nested) == nil {
+						nested = filterRequestObject(nested, child, permSet)
+						arr[i], _ = json.Marshal(nested)
+					}
+				}
+				obj[jsonName], _ = json.Marshal(arr)
+			}
+		}
+	}
+	return obj
+}

+ 8 - 2
permlib.go

@@ -19,10 +19,16 @@ type RoutePermDecl struct {
 	DataCode string // 配对的 data 权限 code
 }
 
+// FieldNode 树形字段权限节点,支持任意深度嵌套
+type FieldNode struct {
+	Fields map[string]string      // jsonName → permCode(当前层需过滤的字段)
+	Nested map[string]*FieldNode  // jsonName → 子节点(有嵌套结构的字段)
+}
+
 // FieldPermMap 字段权限映射,用于静态注册(替代反射)
 type FieldPermMap struct {
-	Request  map[string]string // json字段名 → permCode
-	Response map[string]string // json字段名 → permCode
+	Request  *FieldNode
+	Response *FieldNode
 }
 
 type Engine struct {