diff --git a/grab/grab.go b/grab/grab.go index 50ed522..c2d0df5 100644 --- a/grab/grab.go +++ b/grab/grab.go @@ -61,13 +61,15 @@ type Grabber interface { func NewHttpClient() *req.Client { client := req.C() client. - SetCommonHeader("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36 Edg/111.0.1661.51"). + SetCommonHeader("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/112.0.0.0 Safari/537.36 Edg/112.0.1722.34"). SetTimeout(10*time.Second). SetCommonRetryCount(3). SetCookieJar(nil). SetCommonRetryBackoffInterval(5*time.Second, 10*time.Second). SetCommonRetryHook(func(resp *req.Response, err error) { - golog.Warnf("retrying as %s", err) + if err != nil { + golog.Warnf("retrying as %s", err) + } }).SetCommonRetryCondition(func(resp *req.Response, err error) bool { if err != nil { if errors.Is(err, context.Canceled) { diff --git a/grab/oscs.go b/grab/oscs.go index e217fb2..4ca5e04 100644 --- a/grab/oscs.go +++ b/grab/oscs.go @@ -47,11 +47,11 @@ func (t *OSCSCrawler) GetPageCount(ctx context.Context, size int) (int, error) { return true } if body.Code != 200 || !body.Success { - t.log.Warnf("failed to get page count, msg: %s", body.Info) + t.log.Warnf("failed to get page count, msg: %s, retrying", body.Info) return true } if body.Data.Total <= 0 { - t.log.Warnf("invalid total size %d", body.Data.Total) + t.log.Warnf("invalid total size %d, retrying", body.Data.Total) return true } return false diff --git a/grab/ti.go b/grab/ti.go index 930aba1..f63912b 100644 --- a/grab/ti.go +++ b/grab/ti.go @@ -48,11 +48,11 @@ func (t *TiCrawler) GetPageCount(ctx context.Context, size int) (int, error) { return true } if body.Status != 10000 { - t.log.Warnf("failed to get page count, msg: %s", body.Message) + t.log.Warnf("failed to get page count, msg: %s, retrying", body.Message) return true } if body.Data.Total <= 0 { - t.log.Warnf("invalid total size 
%d", body.Data.Total) + t.log.Warnf("invalid total size %d, retrying", body.Data.Total) return true } return false diff --git a/main.go b/main.go index f600655..f5754b9 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,9 @@ func init() { } var log = golog.Child("[main]") -var Version = "v0.6.0" +var Version = "v0.7.0" + +const MaxPageBase = 3 func main() { golog.Default.SetLevel("info") @@ -128,7 +130,7 @@ func Action(c *cli.Context) error { if err != nil { return err } - grabbers, vulnCountAtLeast, err := initSources(c) + grabbers, err := initSources(c) if err != nil { return err } @@ -160,7 +162,7 @@ func Action(c *cli.Context) error { return fmt.Errorf("interval is too small, at least 1m") } - drv, err := entSql.Open("sqlite3", "file:vuln_v1.sqlite3?cache=shared&_pragma=foreign_keys(1)") + drv, err := entSql.Open("sqlite3", "file:vuln_v2.sqlite3?cache=shared&_pragma=foreign_keys(1)") if err != nil { return errors.Wrap(err, "failed opening connection to sqlite") } @@ -173,36 +175,24 @@ func Action(c *cli.Context) error { return errors.Wrap(err, "failed creating schema resources") } - count, err := dbClient.VulnInformation.Query().Count(ctx) - if err != nil { - return errors.Wrap(err, "failed creating schema resources") - } - log.Infof("local database has %d vulns", count) - if count < vulnCountAtLeast { - log.Infof("local data is outdated, init database") - eg, initCtx := errgroup.WithContext(ctx) - eg.SetLimit(len(grabbers)) - for _, grabber := range grabbers { - grabber := grabber - eg.Go(func() error { - return initData(initCtx, dbClient, grabber) - }) - } - err = eg.Wait() - if err != nil { - return errors.Wrap(err, "init data") - } - log.Infof("grabber finished successfully") + log.Infof("initialize local database..") + // fetch the first 3 pages as the baseline vulnerability data + eg, initCtx := errgroup.WithContext(ctx) + eg.SetLimit(len(grabbers)) + for _, grabber := range grabbers { + grabber := grabber + eg.Go(func() error { + return initData(initCtx, dbClient, grabber) + }) } - - // 初次启动不要推送数据, 
以免长时间没运行狂发消息 - vulns, err := collectUpdate(ctx, dbClient, grabbers) + err = eg.Wait() if err != nil { - return errors.Wrap(err, "initial collect") + return errors.Wrap(err, "init data") } + log.Infof("grabber finished successfully") localCount := dbClient.VulnInformation.Query().CountX(ctx) - log.Infof("local database has %d vulns", localCount) + log.Infof("system init finished, local database has %d vulns", localCount) if !noStartMessage { providers := make([]*grab.Provider, 0, 3) for _, p := range grabbers { @@ -220,7 +210,6 @@ func Action(c *cli.Context) error { } } - log.Infof("system init finished, found %d new vulns (skip pushing)", len(vulns)) log.Infof("ticking every %s", interval) defer func() { @@ -246,7 +235,7 @@ func Action(c *cli.Context) error { continue } - vulns, err = collectUpdate(ctx, dbClient, grabbers) + vulns, err := collectUpdate(ctx, dbClient, grabbers) if err != nil { log.Errorf("failed to get updates, %s", err) continue @@ -297,31 +286,27 @@ func Action(c *cli.Context) error { } } -func initSources(c *cli.Context) ([]grab.Grabber, int, error) { +func initSources(c *cli.Context) ([]grab.Grabber, error) { sources := c.String("sources") if os.Getenv("SOURCES") != "" { sources = os.Getenv("SOURCES") } parts := strings.Split(sources, ",") var grabs []grab.Grabber - var countApproximately int for _, part := range parts { part = strings.ToLower(strings.TrimSpace(part)) switch part { case "avd": - countApproximately += 1200 grabs = append(grabs, grab.NewAVDCrawler()) case "ti": - countApproximately += 27000 grabs = append(grabs, grab.NewTiCrawler()) case "oscs": - countApproximately += 90 * 10 grabs = append(grabs, grab.NewOSCSCrawler()) default: - return nil, 0, fmt.Errorf("invalid grab source %s", part) + return nil, fmt.Errorf("invalid grab source %s", part) } } - return grabs, countApproximately, nil + return grabs, nil } func initPusher(c *cli.Context) (push.Pusher, error) { @@ -387,10 +372,16 @@ func initData(ctx context.Context, dbClient 
*ent.Client, grabber grab.Grabber) e if err != nil { return nil } - log.Infof("%s total page: %d", source.Name, total) + if total == 0 { + return fmt.Errorf("%s got unexpected zero page", source.Name) + } + if total > MaxPageBase { + total = MaxPageBase + } + log.Infof("start grab %s, total page: %d", source.Name, total) eg, ctx := errgroup.WithContext(ctx) - eg.SetLimit(20) + eg.SetLimit(total) for i := 1; i <= total; i++ { i := i