package permlib import ( "bytes" "context" "encoding/json" "fmt" "net/http" "strings" ) func (e *Engine) authMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { authHeader := req.Header.Get("Authorization") if authHeader == "" { e.cfg.Callbacks.OnError(w, req, 401, 401, "未登录") return } tokenStr := strings.TrimPrefix(authHeader, "Bearer ") if tokenStr == authHeader { e.cfg.Callbacks.OnError(w, req, 401, 401, "token格式错误") return } info, err := e.verifyAndGetUser(req.Context(), tokenStr) if err != nil { e.cfg.Callbacks.OnError(w, req, 401, 401, "token无效或已过期") return } apiCode := e.resolvePermCode(req) if apiCode != "" && !containsPerm(info.Perms, apiCode) { e.cfg.Callbacks.OnError(w, req, 403, 403, fmt.Sprintf("无权访问: %s", apiCode)) return } user := &UserInfo{ UserId: info.UserId, Username: info.Username, ProductCode: info.ProductCode, MemberType: info.MemberType, Perms: info.Perms, } ctx := withUser(req.Context(), user) req, rejected := e.filterRequest(req, info.Perms) if rejected { e.cfg.Callbacks.OnError(w, req, 403, 403, "包含无权写入的字段") return } dataCode := e.resolveDataCode(req, apiCode) hasDataPerm := dataCode == "" || containsPerm(info.Perms, dataCode) var wrappedNext http.HandlerFunc if e.hasRespFilter(req) { wrappedNext = func(w http.ResponseWriter, req *http.Request) { rw := newFilterResponseWriter(w, req, info.Perms, e) next(rw, req) rw.flush() } } else { wrappedNext = next } e.cfg.Callbacks.OnSuccess(w, req.WithContext(ctx), wrappedNext, user, hasDataPerm) } } func (e *Engine) resolvePermCode(req *http.Request) string { key := req.Method + " " + req.URL.Path if decl, ok := e.routePerms[key]; ok { return decl.PermCode } return "" } func (e *Engine) resolveDataCode(req *http.Request, apiCode string) string { key := req.Method + " " + req.URL.Path if decl, ok := e.routePerms[key]; ok { return decl.DataCode } if apiCode != "" { return apiToDataCode(apiCode) } return "" } func (e *Engine) hasRespFilter(req *http.Request) bool { key := req.Method + " " + req.URL.Path if fm, ok := e.fieldPerms[key]; ok { return fm.Response != nil && (len(fm.Response.Fields) > 0 || len(fm.Response.Nested) > 0) } return false } func (e *Engine) filterRequest(req *http.Request, perms []string) (*http.Request, bool) { key := req.Method + " " + req.URL.Path fm, ok := e.fieldPerms[key] if !ok || fm.Request == nil { return req, false } if len(fm.Request.Fields) == 0 && len(fm.Request.Nested) == 0 { return req, false } return filterRequestByNode(req, fm.Request, perms, e.cfg.FieldWriteMode) } func (e *Engine) verifyAndGetUser(ctx context.Context, token string) (*cachedUser, error) { if u, ok := e.cache.get(token); ok { return u, nil } v, err, _ := e.cache.sf.Do(token, func() (interface{}, error) { resp, err := e.client.verifyToken(ctx, token) if err != nil { return nil, err } if !resp.Valid { return nil, fmt.Errorf("token验证失败") } u := &cachedUser{ UserId: resp.UserId, Username: resp.Username, ProductCode: resp.ProductCode, MemberType: resp.MemberType, Perms: resp.Perms, } e.cache.set(token, u) return u, nil }) if err != nil { return nil, err } return v.(*cachedUser), nil } func containsPerm(perms []string, code string) bool { for _, p := range perms { if p == code { return true } } return false } func filterRequestByNode(req *http.Request, node *FieldNode, perms []string, mode FieldWriteMode) (*http.Request, bool) { if req.Body == nil { return req, false } body, err := readBody(req) if err != nil || len(body) == 0 { return req, false } var obj map[string]json.RawMessage if err := json.Unmarshal(body, &obj); err != nil { restoreBody(req, body) return req, false } permSet := toPermSet(perms) if mode == FieldWriteReject { if hasUnauthorizedField(obj, node, permSet) { restoreBody(req, body) return req, true } } filtered := filterRequestObject(obj, node, permSet) result, _ := json.Marshal(filtered) restoreBody(req, result) req.ContentLength = int64(len(result)) return req, false } func hasUnauthorizedField(obj map[string]json.RawMessage, node *FieldNode, permSet map[string]bool) bool { for jsonName, permCode := range node.Fields { if _, has := obj[jsonName]; has && !permSet[permCode] { return true } } for jsonName, child := range node.Nested { v, ok := obj[jsonName] if !ok { continue } v = bytes.TrimSpace(v) if len(v) == 0 { continue } if v[0] == '{' { var nested map[string]json.RawMessage if json.Unmarshal(v, &nested) == nil { if hasUnauthorizedField(nested, child, permSet) { return true } } } else if v[0] == '[' { var arr []json.RawMessage if json.Unmarshal(v, &arr) == nil { for _, item := range arr { var nested map[string]json.RawMessage if json.Unmarshal(item, &nested) == nil { if hasUnauthorizedField(nested, child, permSet) { return true } } } } } } return false } func filterRequestObject(obj map[string]json.RawMessage, node *FieldNode, permSet map[string]bool) map[string]json.RawMessage { for jsonName, permCode := range node.Fields { if _, has := obj[jsonName]; has && !permSet[permCode] { delete(obj, jsonName) } } for jsonName, child := range node.Nested { v, ok := obj[jsonName] if !ok { continue } v = bytes.TrimSpace(v) if len(v) == 0 { continue } if v[0] == '{' { var nested map[string]json.RawMessage if json.Unmarshal(v, &nested) == nil { nested = filterRequestObject(nested, child, permSet) obj[jsonName], _ = json.Marshal(nested) } } else if v[0] == '[' { var arr []json.RawMessage if json.Unmarshal(v, &arr) == nil { for i, item := range arr { var nested map[string]json.RawMessage if json.Unmarshal(item, &nested) == nil { nested = filterRequestObject(nested, child, permSet) arr[i], _ = json.Marshal(nested) } } obj[jsonName], _ = json.Marshal(arr) } } } return obj }