middleware.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. package permlib
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "net/http"
  8. "strings"
  9. )
  10. func (e *Engine) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
  11. return func(w http.ResponseWriter, req *http.Request) {
  12. authHeader := req.Header.Get("Authorization")
  13. if authHeader == "" {
  14. e.cfg.Callbacks.OnError(w, req, 401, 401, "未登录")
  15. return
  16. }
  17. tokenStr := strings.TrimPrefix(authHeader, "Bearer ")
  18. if tokenStr == authHeader {
  19. e.cfg.Callbacks.OnError(w, req, 401, 401, "token格式错误")
  20. return
  21. }
  22. info, err := e.verifyAndGetUser(req.Context(), tokenStr)
  23. if err != nil {
  24. e.cfg.Callbacks.OnError(w, req, 401, 401, "token无效或已过期")
  25. return
  26. }
  27. apiCode := e.resolvePermCode(req)
  28. if apiCode != "" && !containsPerm(info.Perms, apiCode) {
  29. e.cfg.Callbacks.OnError(w, req, 403, 403, fmt.Sprintf("无权访问: %s", apiCode))
  30. return
  31. }
  32. user := &UserInfo{
  33. UserId: info.UserId,
  34. Username: info.Username,
  35. ProductCode: info.ProductCode,
  36. MemberType: info.MemberType,
  37. Perms: info.Perms,
  38. }
  39. ctx := withUser(req.Context(), user)
  40. req, rejected := e.filterRequest(req, info.Perms)
  41. if rejected {
  42. e.cfg.Callbacks.OnError(w, req, 403, 403, "包含无权写入的字段")
  43. return
  44. }
  45. dataCode := e.resolveDataCode(req, apiCode)
  46. hasDataPerm := dataCode == "" || containsPerm(info.Perms, dataCode)
  47. var wrappedNext http.HandlerFunc
  48. if e.hasRespFilter(req) {
  49. wrappedNext = func(w http.ResponseWriter, req *http.Request) {
  50. rw := newFilterResponseWriter(w, req, info.Perms, e)
  51. next(rw, req)
  52. rw.flush()
  53. }
  54. } else {
  55. wrappedNext = next
  56. }
  57. e.cfg.Callbacks.OnSuccess(w, req.WithContext(ctx), wrappedNext, user, hasDataPerm)
  58. }
  59. }
  60. func (e *Engine) resolvePermCode(req *http.Request) string {
  61. key := req.Method + " " + req.URL.Path
  62. if decl, ok := e.routePerms[key]; ok {
  63. return decl.PermCode
  64. }
  65. return ""
  66. }
  67. func (e *Engine) resolveDataCode(req *http.Request, apiCode string) string {
  68. key := req.Method + " " + req.URL.Path
  69. if decl, ok := e.routePerms[key]; ok {
  70. return decl.DataCode
  71. }
  72. if apiCode != "" {
  73. return apiToDataCode(apiCode)
  74. }
  75. return ""
  76. }
  77. func (e *Engine) hasRespFilter(req *http.Request) bool {
  78. key := req.Method + " " + req.URL.Path
  79. if fm, ok := e.fieldPerms[key]; ok {
  80. return fm.Response != nil && (len(fm.Response.Fields) > 0 || len(fm.Response.Nested) > 0)
  81. }
  82. return false
  83. }
  84. func (e *Engine) filterRequest(req *http.Request, perms []string) (*http.Request, bool) {
  85. key := req.Method + " " + req.URL.Path
  86. fm, ok := e.fieldPerms[key]
  87. if !ok || fm.Request == nil {
  88. return req, false
  89. }
  90. if len(fm.Request.Fields) == 0 && len(fm.Request.Nested) == 0 {
  91. return req, false
  92. }
  93. return filterRequestByNode(req, fm.Request, perms, e.cfg.FieldWriteMode)
  94. }
  95. func (e *Engine) verifyAndGetUser(ctx context.Context, token string) (*cachedUser, error) {
  96. if u, ok := e.cache.get(token); ok {
  97. return u, nil
  98. }
  99. v, err, _ := e.cache.sf.Do(token, func() (interface{}, error) {
  100. resp, err := e.client.verifyToken(ctx, token)
  101. if err != nil {
  102. return nil, err
  103. }
  104. if !resp.Valid {
  105. return nil, fmt.Errorf("token验证失败")
  106. }
  107. u := &cachedUser{
  108. UserId: resp.UserId,
  109. Username: resp.Username,
  110. ProductCode: resp.ProductCode,
  111. MemberType: resp.MemberType,
  112. Perms: resp.Perms,
  113. }
  114. e.cache.set(token, u)
  115. return u, nil
  116. })
  117. if err != nil {
  118. return nil, err
  119. }
  120. return v.(*cachedUser), nil
  121. }
  122. func containsPerm(perms []string, code string) bool {
  123. for _, p := range perms {
  124. if p == code {
  125. return true
  126. }
  127. }
  128. return false
  129. }
  130. func filterRequestByNode(req *http.Request, node *FieldNode, perms []string, mode FieldWriteMode) (*http.Request, bool) {
  131. if req.Body == nil {
  132. return req, false
  133. }
  134. body, err := readBody(req)
  135. if err != nil || len(body) == 0 {
  136. return req, false
  137. }
  138. var obj map[string]json.RawMessage
  139. if err := json.Unmarshal(body, &obj); err != nil {
  140. restoreBody(req, body)
  141. return req, false
  142. }
  143. permSet := toPermSet(perms)
  144. if mode == FieldWriteReject {
  145. if hasUnauthorizedField(obj, node, permSet) {
  146. restoreBody(req, body)
  147. return req, true
  148. }
  149. }
  150. filtered := filterRequestObject(obj, node, permSet)
  151. result, _ := json.Marshal(filtered)
  152. restoreBody(req, result)
  153. req.ContentLength = int64(len(result))
  154. return req, false
  155. }
  156. func hasUnauthorizedField(obj map[string]json.RawMessage, node *FieldNode, permSet map[string]bool) bool {
  157. for jsonName, permCode := range node.Fields {
  158. if _, has := obj[jsonName]; has && !permSet[permCode] {
  159. return true
  160. }
  161. }
  162. for jsonName, child := range node.Nested {
  163. v, ok := obj[jsonName]
  164. if !ok {
  165. continue
  166. }
  167. v = bytes.TrimSpace(v)
  168. if len(v) == 0 {
  169. continue
  170. }
  171. if v[0] == '{' {
  172. var nested map[string]json.RawMessage
  173. if json.Unmarshal(v, &nested) == nil {
  174. if hasUnauthorizedField(nested, child, permSet) {
  175. return true
  176. }
  177. }
  178. } else if v[0] == '[' {
  179. var arr []json.RawMessage
  180. if json.Unmarshal(v, &arr) == nil {
  181. for _, item := range arr {
  182. var nested map[string]json.RawMessage
  183. if json.Unmarshal(item, &nested) == nil {
  184. if hasUnauthorizedField(nested, child, permSet) {
  185. return true
  186. }
  187. }
  188. }
  189. }
  190. }
  191. }
  192. return false
  193. }
  194. func filterRequestObject(obj map[string]json.RawMessage, node *FieldNode, permSet map[string]bool) map[string]json.RawMessage {
  195. for jsonName, permCode := range node.Fields {
  196. if _, has := obj[jsonName]; has && !permSet[permCode] {
  197. delete(obj, jsonName)
  198. }
  199. }
  200. for jsonName, child := range node.Nested {
  201. v, ok := obj[jsonName]
  202. if !ok {
  203. continue
  204. }
  205. v = bytes.TrimSpace(v)
  206. if len(v) == 0 {
  207. continue
  208. }
  209. if v[0] == '{' {
  210. var nested map[string]json.RawMessage
  211. if json.Unmarshal(v, &nested) == nil {
  212. nested = filterRequestObject(nested, child, permSet)
  213. obj[jsonName], _ = json.Marshal(nested)
  214. }
  215. } else if v[0] == '[' {
  216. var arr []json.RawMessage
  217. if json.Unmarshal(v, &arr) == nil {
  218. for i, item := range arr {
  219. var nested map[string]json.RawMessage
  220. if json.Unmarshal(item, &nested) == nil {
  221. nested = filterRequestObject(nested, child, permSet)
  222. arr[i], _ = json.Marshal(nested)
  223. }
  224. }
  225. obj[jsonName], _ = json.Marshal(arr)
  226. }
  227. }
  228. }
  229. return obj
  230. }