diff --git a/providers/aws/resources/aws_ec2.go b/providers/aws/resources/aws_ec2.go index 3654f5ae5..662697d79 100644 --- a/providers/aws/resources/aws_ec2.go +++ b/providers/aws/resources/aws_ec2.go @@ -764,35 +764,17 @@ func (a *mqlAwsEc2) instances() ([]interface{}, error) { return res, nil } -func (a *mqlAwsEc2) getImdsv2Instances(ctx context.Context, svc *ec2.Client, filterName string) ([]ec2types.Reservation, error) { +func (a *mqlAwsEc2) getEc2Instances(ctx context.Context, svc *ec2.Client, tags map[string]string) ([]ec2types.Reservation, error) { res := []ec2types.Reservation{} nextToken := aws.String("no_token_to_start_with") params := &ec2.DescribeInstancesInput{ - Filters: []ec2types.Filter{ - {Name: &filterName, Values: []string{"required"}}, - }, + Filters: []ec2types.Filter{}, } - for nextToken != nil { - instances, err := svc.DescribeInstances(ctx, params) - if err != nil { - return nil, err - } - nextToken = instances.NextToken - if instances.NextToken != nil { - params.NextToken = nextToken - } - res = append(res, instances.Reservations...) - } - return res, nil -} - -func (a *mqlAwsEc2) getImdsv1Instances(ctx context.Context, svc *ec2.Client, filterName string) ([]ec2types.Reservation, error) { - res := []ec2types.Reservation{} - nextToken := aws.String("no_token_to_start_with") - params := &ec2.DescribeInstancesInput{ - Filters: []ec2types.Filter{ - {Name: &filterName, Values: []string{"optional"}}, - }, + for k, v := range tags { + params.Filters = append(params.Filters, ec2types.Filter{ + Name: aws.String(fmt.Sprintf("tag:%s", k)), + Values: []string{v}, + }) } for nextToken != nil { instances, err := svc.DescribeInstances(ctx, params) @@ -826,10 +808,7 @@ func (a *mqlAwsEc2) getInstances(conn *connection.AwsConnection) []*jobpool.Job ctx := context.Background() var res []interface{} - // the value for http tokens is not available on api output i've been able to find, so here - // we make two calls to get the instances, one with the imdsv1 filter and another with the imdsv2 filter - filterName := "metadata-options.http-tokens" - imdsv2Instances, err := a.getImdsv2Instances(ctx, svc, filterName) + instances, err := a.getEc2Instances(ctx, svc, conn.Filters.Ec2DiscoveryFilters.Tags) if err != nil { if Is400AccessDeniedError(err) { log.Warn().Str("region", regionVal).Msg("error accessing region for AWS API") @@ -837,25 +816,11 @@ func (a *mqlAwsEc2) getInstances(conn *connection.AwsConnection) []*jobpool.Job } return nil, err } - res, err = a.gatherInstanceInfo(imdsv2Instances, 2, regionVal) + res, err = a.gatherInstanceInfo(instances, regionVal) if err != nil { return nil, err } - imdsv1Instances, err := a.getImdsv1Instances(ctx, svc, filterName) - if err != nil { - if Is400AccessDeniedError(err) { - log.Warn().Str("region", regionVal).Msg("error accessing region for AWS API") - return res, nil - } - return nil, err - } - imdsv1Res, err := a.gatherInstanceInfo(imdsv1Instances, 1, regionVal) - if err != nil { - return nil, err - } - res = append(res, imdsv1Res...) - return jobpool.JobResult(res), nil } tasks = append(tasks, jobpool.NewJob(f)) @@ -863,13 +828,9 @@ func (a *mqlAwsEc2) getInstances(conn *connection.AwsConnection) []*jobpool.Job return tasks } -func (a *mqlAwsEc2) gatherInstanceInfo(instances []ec2types.Reservation, imdsvVersion int, regionVal string) ([]interface{}, error) { +func (a *mqlAwsEc2) gatherInstanceInfo(instances []ec2types.Reservation, regionVal string) ([]interface{}, error) { conn := a.MqlRuntime.Connection.(*connection.AwsConnection) res := []interface{}{} - httpTokens := "required" - if imdsvVersion == 1 { - httpTokens = "optional" - } for _, reservation := range instances { for _, instance := range reservation.Instances { mqlDevices := []interface{}{} @@ -911,7 +872,7 @@ func (a *mqlAwsEc2) gatherInstanceInfo(instances []ec2types.Reservation, imdsvVe "ebsOptimized": llx.BoolDataPtr(instance.EbsOptimized), "enaSupported": llx.BoolDataPtr(instance.EnaSupport), "httpEndpoint": llx.StringData(string(instance.MetadataOptions.HttpEndpoint)), - "httpTokens": llx.StringData(httpTokens), + "httpTokens": llx.StringData(string(instance.MetadataOptions.HttpTokens)), "hypervisor": llx.StringData(string(instance.Hypervisor)), "instanceId": llx.StringDataPtr(instance.InstanceId), "instanceLifecycle": llx.StringData(string(instance.InstanceLifecycle)),