Skip to content

Commit

Permalink
feat: speed up initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
zema1 committed Apr 14, 2023
1 parent e413ac4 commit 7db5a63
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 45 deletions.
6 changes: 4 additions & 2 deletions grab/grab.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ type Grabber interface {
func NewHttpClient() *req.Client {
client := req.C()
client.
SetCommonHeader("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.1661.51").
SetCommonHeader("User-Agent", "ozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36 Edg/112.0.1722.34").
SetTimeout(10*time.Second).
SetCommonRetryCount(3).
SetCookieJar(nil).
SetCommonRetryBackoffInterval(5*time.Second, 10*time.Second).
SetCommonRetryHook(func(resp *req.Response, err error) {
golog.Warnf("retrying as %s", err)
if err != nil {
golog.Warnf("retrying as %s", err)
}
}).SetCommonRetryCondition(func(resp *req.Response, err error) bool {
if err != nil {
if errors.Is(err, context.Canceled) {
Expand Down
4 changes: 2 additions & 2 deletions grab/oscs.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ func (t *OSCSCrawler) GetPageCount(ctx context.Context, size int) (int, error) {
return true
}
if body.Code != 200 || !body.Success {
t.log.Warnf("failed to get page count, msg: %s", body.Info)
t.log.Warnf("failed to get page count, msg: %s, retrying", body.Info)
return true
}
if body.Data.Total <= 0 {
t.log.Warnf("invalid total size %d", body.Data.Total)
t.log.Warnf("invalid total size %d, retrying", body.Data.Total)
return true
}
return false
Expand Down
4 changes: 2 additions & 2 deletions grab/ti.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ func (t *TiCrawler) GetPageCount(ctx context.Context, size int) (int, error) {
return true
}
if body.Status != 10000 {
t.log.Warnf("failed to get page count, msg: %s", body.Message)
t.log.Warnf("failed to get page count, msg: %s, retrying", body.Message)
return true
}
if body.Data.Total <= 0 {
t.log.Warnf("invalid total size %d", body.Data.Total)
t.log.Warnf("invalid total size %d, retrying", body.Data.Total)
return true
}
return false
Expand Down
69 changes: 30 additions & 39 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ func init() {
}

var log = golog.Child("[main]")
var Version = "v0.6.0"
var Version = "v0.7.0"

const MaxPageBase = 3

func main() {
golog.Default.SetLevel("info")
Expand Down Expand Up @@ -128,7 +130,7 @@ func Action(c *cli.Context) error {
if err != nil {
return err
}
grabbers, vulnCountAtLeast, err := initSources(c)
grabbers, err := initSources(c)
if err != nil {
return err
}
Expand Down Expand Up @@ -160,7 +162,7 @@ func Action(c *cli.Context) error {
return fmt.Errorf("interval is too small, at least 1m")
}

drv, err := entSql.Open("sqlite3", "file:vuln_v1.sqlite3?cache=shared&_pragma=foreign_keys(1)")
drv, err := entSql.Open("sqlite3", "file:vuln_v2.sqlite3?cache=shared&_pragma=foreign_keys(1)")
if err != nil {
return errors.Wrap(err, "failed opening connection to sqlite")
}
Expand All @@ -173,36 +175,24 @@ func Action(c *cli.Context) error {
return errors.Wrap(err, "failed creating schema resources")
}

count, err := dbClient.VulnInformation.Query().Count(ctx)
if err != nil {
return errors.Wrap(err, "failed creating schema resources")
}
log.Infof("local database has %d vulns", count)
if count < vulnCountAtLeast {
log.Infof("local data is outdated, init database")
eg, initCtx := errgroup.WithContext(ctx)
eg.SetLimit(len(grabbers))
for _, grabber := range grabbers {
grabber := grabber
eg.Go(func() error {
return initData(initCtx, dbClient, grabber)
})
}
err = eg.Wait()
if err != nil {
return errors.Wrap(err, "init data")
}
log.Infof("grabber finished successfully")
log.Infof("initialize local database..")
// 抓取前3页作为基准漏洞数据
eg, initCtx := errgroup.WithContext(ctx)
eg.SetLimit(len(grabbers))
for _, grabber := range grabbers {
grabber := grabber
eg.Go(func() error {
return initData(initCtx, dbClient, grabber)
})
}

// 初次启动不要推送数据, 以免长时间没运行狂发消息
vulns, err := collectUpdate(ctx, dbClient, grabbers)
err = eg.Wait()
if err != nil {
return errors.Wrap(err, "initial collect")
return errors.Wrap(err, "init data")
}
log.Infof("grabber finished successfully")

localCount := dbClient.VulnInformation.Query().CountX(ctx)
log.Infof("local database has %d vulns", localCount)
log.Infof("system init finished, local database has %d vulns", localCount)
if !noStartMessage {
providers := make([]*grab.Provider, 0, 3)
for _, p := range grabbers {
Expand All @@ -220,7 +210,6 @@ func Action(c *cli.Context) error {
}
}

log.Infof("system init finished, found %d new vulns (skip pushing)", len(vulns))
log.Infof("ticking every %s", interval)

defer func() {
Expand All @@ -246,7 +235,7 @@ func Action(c *cli.Context) error {
continue
}

vulns, err = collectUpdate(ctx, dbClient, grabbers)
vulns, err := collectUpdate(ctx, dbClient, grabbers)
if err != nil {
log.Errorf("failed to get updates, %s", err)
continue
Expand Down Expand Up @@ -297,31 +286,27 @@ func Action(c *cli.Context) error {
}
}

func initSources(c *cli.Context) ([]grab.Grabber, int, error) {
func initSources(c *cli.Context) ([]grab.Grabber, error) {
sources := c.String("sources")
if os.Getenv("SOURCES") != "" {
sources = os.Getenv("SOURCES")
}
parts := strings.Split(sources, ",")
var grabs []grab.Grabber
var countApproximately int
for _, part := range parts {
part = strings.ToLower(strings.TrimSpace(part))
switch part {
case "avd":
countApproximately += 1200
grabs = append(grabs, grab.NewAVDCrawler())
case "ti":
countApproximately += 27000
grabs = append(grabs, grab.NewTiCrawler())
case "oscs":
countApproximately += 90 * 10
grabs = append(grabs, grab.NewOSCSCrawler())
default:
return nil, 0, fmt.Errorf("invalid grab source %s", part)
return nil, fmt.Errorf("invalid grab source %s", part)
}
}
return grabs, countApproximately, nil
return grabs, nil
}

func initPusher(c *cli.Context) (push.Pusher, error) {
Expand Down Expand Up @@ -387,10 +372,16 @@ func initData(ctx context.Context, dbClient *ent.Client, grabber grab.Grabber) e
if err != nil {
return nil
}
log.Infof("%s total page: %d", source.Name, total)
if total == 0 {
return fmt.Errorf("%s got unexpected zero page", source.Name)
}
if total > MaxPageBase {
total = MaxPageBase
}
log.Infof("start grab %s, total page: %d", source.Name, total)

eg, ctx := errgroup.WithContext(ctx)
eg.SetLimit(20)
eg.SetLimit(total)

for i := 1; i <= total; i++ {
i := i
Expand Down

0 comments on commit 7db5a63

Please sign in to comment.