Skip to content

Commit

Permalink
feat✨: 更改逻辑
Browse files Browse the repository at this point in the history
  • Loading branch information
wcz0 committed Apr 24, 2024
1 parent ce58052 commit 2ab6d58
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 156 deletions.
104 changes: 84 additions & 20 deletions adapters/database_adapter.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,44 @@
package adapters

import (

"time"

"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/goravel/framework/facades"
"github.com/wcz0/goravel-authz/models"
)

type Rule interface {
SetPtype(value string)
GetPtype() string
SetV0(value string)
GetV0() string
SetV1(value string)
GetV1() string
SetV2(value string)
GetV2() string
SetV3(value string)
GetV3() string
SetV4(value string)
GetV4() string
SetV5(value string)
GetV5() string
// model 类型, model 值
Model() (string, string)
// 是否从缓存中获取, 缓存store, 缓存key
Cache() (bool, string, string)
// 刷新缓存方式
RefreshCache()
}

type Adapter struct {
eloquent *models.Rule
eloquent Rule
}

func NewAdapter() *Adapter {
func NewAdapter(r Rule) *Adapter {
return &Adapter{
eloquent: &models.Rule{}, // Replace models.Rule with a valid expression that represents an instance of the models.Rule type
eloquent: r,
}
}

Expand Down Expand Up @@ -41,40 +66,43 @@ func (a *Adapter) SavePolicy(model model.Model) error {

// AddPolicy adds a policy rule to the storage.
func (a *Adapter) savePolicyLine(ptype string, rule []string) error {
a.eloquent.Ptype = ptype
a.eloquent.SetPtype(ptype)
if len(rule) > 0 {
a.eloquent.V0 = rule[0]
a.eloquent.SetV0(rule[0])
}
if len(rule) > 1 {
a.eloquent.V1 = rule[1]
a.eloquent.SetV1(rule[1])
}
if len(rule) > 2 {
a.eloquent.V2 = rule[2]
a.eloquent.SetV2(rule[2])
}
if len(rule) > 3 {
a.eloquent.V3 = rule[3]
a.eloquent.SetV3(rule[3])
}
if len(rule) > 4 {
a.eloquent.V4 = rule[4]
a.eloquent.SetV4(rule[4])
}
if len(rule) > 5 {
a.eloquent.V5 = rule[5]
a.eloquent.SetV5(rule[5])
}
// Save the rule to the database
err := facades.Orm().Query().Create(&a.eloquent)
err := facades.Orm().Query().Create(a.eloquent)
if err != nil {
return err
}
return nil
}



/**
* Loads all policy rules from the storage.
*/
func (a *Adapter) LoadPolicy(model model.Model) error {
row, _ := a.eloquent.GetAllFromCache()
// var row []Rule
// 是否从缓存中获取
row, err := a.getAllFromCache()
if err != nil {
return err
}
for _, rule := range row {
err := a.loadPolicyLine(rule, model)
if err != nil {
Expand All @@ -84,8 +112,44 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
return nil
}

func (a *Adapter) loadPolicyLine(rule models.Rule, model model.Model) error {
var p = []string{rule.Ptype, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}
func (a *Adapter) getAllFromCache() ([]Rule, error) {
// 是否从缓存中获取
if ok, store, key := a.eloquent.Cache(); ok {
cache := facades.Cache().Store(store)
ttl := 5 * 60 * time.Second
result, err := cache.Remember(key, ttl, func() (any, error) {
return a.getPolicy(), nil
})
if err != nil {
return nil, err
}
return result.([]Rule), nil
} else {
// 从数据库中获取
return a.getPolicy(), nil
}
}

// func (a *Adapter) (value string) {
// a.eloquent.SetPtype(value)
// }

func (a *Adapter) getPolicy() []Rule {
var rules = []Rule{}
facades.Orm().Query().Select("ptype", "v0", "v1", "v2", "v3", "v4", "v5").Get(&rules)
return rules
}

func (a *Adapter) loadPolicyLine(rule Rule, model model.Model) error {
var p = []string{
rule.GetPtype(),
rule.GetV0(),
rule.GetV1(),
rule.GetV2(),
rule.GetV3(),
rule.GetV4(),
rule.GetV5(),
}
i := len(p) - 1
for p[i] == "" {
i--
Expand Down Expand Up @@ -118,7 +182,7 @@ func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
for i, v := range rule {
query = query.Where("v"+string(rune(i)), v)
}
_, err := query.Delete(&a.eloquent)
_, err := query.Delete(a.eloquent)
if err != nil {
return err
}
Expand All @@ -141,14 +205,14 @@ func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int,
}

// 保存删除的规则, 不知有何用意
// err := query.Get(&a.eloquent)
// err := query.Get(a.eloquent)
// if err != nil {
// return err
// }
// for _, rule := range rules {
// removeRules = append(removeRules, map[string]any{"p_type": rule.PType, "v0": rule.V0, "v1": rule.V1, "v2": rule.V2, "v3": rule.V3, "v4": rule.V4, "v5": rule.V5})
// }
_, err := query.Delete(&a.eloquent)
_, err := query.Delete(a.eloquent)
if err != nil {
return err
}
Expand Down
59 changes: 6 additions & 53 deletions config/casbin.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package config

import (
"github.com/goravel/framework/facades"
"github.com/wcz0/goravel-authz/models"
)

func init() {
Expand All @@ -10,59 +11,11 @@ func init() {
// Casbin default
"default": "basic",

"basic": map[string]any{
"model": map[string]any{
"config_type": "file",
"config_file_path": "casbin-rbac-model.conf",
"config_text": "",
},
// TODO: Casbin adapter . it is default adapter
"adapter": "",

// goravel database type
"database": map[string]any{
"connection": "mysql",
"rules_table": "casbin_rules",
},

// TODO: Casbin Logger
// "log": map[string]any{
// "enabled": false,
// "logger": "log",
// },

// store cache for goravel cache
"cache": map[string]any{
"enabled": true,
// goravel cache store
"store": "memory",
"key": "casbin",
"ttl": 60 * 60,
},
// 多个模型实现多个适配器
"models": map[string]any{
"basic": models.NewRule(),
// second adapter
"second": "",
},


// TODO: Casbin multi adapter
// 第二个 Casbin 配置, 注意!, 需要自己创建对应的数据库表
// "second": map[string]any{
// "model": map[string]any{
// "config_type": "file",
// "config_file_path": path.Config() + "casbin-rbac-model.conf",
// },


// "database": map[string]any{
// "connection": "mysql",
// "rules_table": "casbin_rules_second",
// },


// "cache": map[string]any{
// "enabled": false,
// "store": "default",
// "key": "casbin",
// "ttl": 24 * 60,
// },
// },
})
}
32 changes: 15 additions & 17 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,32 @@ type EnforcerManager struct {
}

func NewEnforcer(g string) *casbin.Enforcer {
config := facades.Config()
if g == "" {
g = facades.Config().GetString("casbin.default")
g = config.GetString("casbin.default")
}

ruleImpl := config.Get("casbin.models." + g)
var rule adapters.Rule = ruleImpl.(adapters.Rule)
configType, contextArg := rule.Model()
var m model.Model
configType := config(g, "model.config_type").(string)
if configType == "file" {
filename := config(g, "model.config_file_path").(string)
path := path.Config(filename)
model, err := model.NewModelFromFile(path)
path := path.Config(contextArg)
modelInstance, err := model.NewModelFromFile(path)
if err != nil {
panic("加载 model 文件失败")
panic("加载 model 文件失败: " + err.Error())
}
m = model
m = modelInstance
} else if configType == "text" {
model, err := model.NewModelFromString(config(g, "model.config_text").(string))
modelInstance, err := model.NewModelFromString(contextArg)
if err != nil {
panic("加载 model 文本失败")
panic("加载 model 文本失败: " + err.Error())
}
m = model
m = modelInstance
}
e, err := casbin.NewEnforcer(m, adapters.NewAdapter())
a := adapters.NewAdapter(rule)
e, err := casbin.NewEnforcer(m, a)
if err != nil {
panic("创建 enforcer 失败")
panic("创建 enforcer 失败: " + err.Error())
}
return e
}

func config(guard string, key string) any {
return facades.Config().Get("casbin."+guard+"."+key, "")
}
Loading

0 comments on commit 2ab6d58

Please sign in to comment.