Skip to content

Commit

Permalink
Merge pull request #98 from cerberauth/no-valid-writer-value
Browse files Browse the repository at this point in the history
fix: manage when no valid token is provided
  • Loading branch information
emmanuelgautier authored May 7, 2024
2 parents d698481 + dc18e22 commit b849141
Show file tree
Hide file tree
Showing 21 changed files with 232 additions and 60 deletions.
4 changes: 4 additions & 0 deletions internal/auth/bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (ss *BearerSecurityScheme) GetCookies() []*http.Cookie {
return []*http.Cookie{}
}

func (ss *BearerSecurityScheme) HasValidValue() bool {
return ss.ValidValue != nil
}

func (ss *BearerSecurityScheme) GetValidValue() interface{} {
return *ss.ValidValue
}
Expand Down
19 changes: 19 additions & 0 deletions internal/auth/bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,25 @@ func TestBearerSecurityScheme_GetCookies(t *testing.T) {
assert.Empty(t, cookies)
}

func TestBearerSecurityScheme_HasValidValue(t *testing.T) {
name := "token"
value := "abc123"
ss := auth.NewAuthorizationBearerSecurityScheme(name, &value)

result := ss.HasValidValue()

assert.True(t, result)
}

func TestBearerSecurityScheme_HasValidValueFalse(t *testing.T) {
name := "token"
ss := auth.NewAuthorizationBearerSecurityScheme(name, nil)

result := ss.HasValidValue()

assert.False(t, result)
}

func TestBearerSecurityScheme_GetValidValue(t *testing.T) {
name := "token"
value := "abc123"
Expand Down
4 changes: 4 additions & 0 deletions internal/auth/jwt_bearer.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ func (ss *JWTBearerSecurityScheme) GetCookies() []*http.Cookie {
return []*http.Cookie{}
}

func (ss *JWTBearerSecurityScheme) HasValidValue() bool {
return ss.ValidValue != nil
}

func (ss *JWTBearerSecurityScheme) GetValidValue() interface{} {
return *ss.ValidValue
}
Expand Down
36 changes: 27 additions & 9 deletions internal/auth/jwt_bearer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@ import (
"testing"

"github.com/cerberauth/vulnapi/internal/auth"
"github.com/cerberauth/vulnapi/jwt"
"github.com/stretchr/testify/assert"
)

const fakeJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"

func TestNewAuthorizationJWTBearerSecurityScheme(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)

assert.NoError(t, err)
Expand All @@ -34,7 +33,7 @@ func TestNewAuthorizationJWTBearerSecuritySchemeWithInvalidJWT(t *testing.T) {

func TestJWTBearerSecurityScheme_GetHeaders(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
attackValue := "xyz789"
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
ss.SetAttackValue(attackValue)
Expand All @@ -49,17 +48,36 @@ func TestJWTBearerSecurityScheme_GetHeaders(t *testing.T) {

func TestJWTBearerSecurityScheme_GetCookies(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
cookies := ss.GetCookies()

assert.NoError(t, err)
assert.Empty(t, cookies)
}

func TestJWTBearerSecurityScheme_HasValidValue(t *testing.T) {
name := "token"
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
hasValidValue := ss.HasValidValue()

assert.NoError(t, err)
assert.True(t, hasValidValue)
}

func TestJWTBearerSecurityScheme_HasValidValue_WhenNoValue(t *testing.T) {
name := "token"
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, nil)
hasValidValue := ss.HasValidValue()

assert.NoError(t, err)
assert.False(t, hasValidValue)
}

func TestJWTBearerSecurityScheme_GetValidValue(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
validValue := ss.GetValidValue()

Expand All @@ -69,7 +87,7 @@ func TestJWTBearerSecurityScheme_GetValidValue(t *testing.T) {

func TestJWTBearerSecurityScheme_GetValidValueWriter(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
writer := ss.GetValidValueWriter()

Expand All @@ -79,7 +97,7 @@ func TestJWTBearerSecurityScheme_GetValidValueWriter(t *testing.T) {

func TestJWTBearerSecurityScheme_SetAttackValue(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
attackValue := "xyz789"
ss.SetAttackValue(attackValue)
Expand All @@ -90,7 +108,7 @@ func TestJWTBearerSecurityScheme_SetAttackValue(t *testing.T) {

func TestJWTBearerSecurityScheme_GetAttackValue(t *testing.T) {
name := "token"
value := fakeJWT
value := jwt.FakeJWT
ss, err := auth.NewAuthorizationJWTBearerSecurityScheme(name, &value)
attackValue := "xyz789"
ss.SetAttackValue(attackValue)
Expand Down
5 changes: 5 additions & 0 deletions internal/auth/security_scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type SecurityScheme interface {
GetHeaders() http.Header
GetCookies() []*http.Cookie
GetValidValue() interface{}
HasValidValue() bool
GetValidValueWriter() interface{}
SetAttackValue(v interface{})
GetAttackValue() interface{}
Expand All @@ -38,6 +39,10 @@ func (ss *NoAuthSecurityScheme) GetCookies() []*http.Cookie {
return []*http.Cookie{}
}

func (ss *NoAuthSecurityScheme) HasValidValue() bool {
return false
}

func (ss *NoAuthSecurityScheme) GetValidValue() interface{} {
return ""
}
Expand Down
3 changes: 3 additions & 0 deletions jwt/const.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package jwt

const FakeJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"
6 changes: 4 additions & 2 deletions scan/best_practices/http_cookies.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@ const (
)

func HTTPCookiesScanHandler(operation *request.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(HTTPCookiesScanID, HTTPCookiesScanName)
if securityScheme.HasValidValue() {
securityScheme.SetAttackValue(securityScheme.GetValidValue())
}

securityScheme.SetAttackValue(securityScheme.GetValidValue())
attempt, err := scan.ScanURL(operation, &securityScheme)
r := report.NewScanReport(HTTPCookiesScanID, HTTPCookiesScanName)
r.AddScanAttempt(attempt).End()
if err != nil {
return r, err
Expand Down
6 changes: 4 additions & 2 deletions scan/best_practices/http_headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,12 @@ func CheckCORSAllowOrigin(operation *request.Operation, headers http.Header, r *
}

func HTTPHeadersBestPracticesScanHandler(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(HTTPHeadersScanID, HTTPHeadersScanName)
if ss.HasValidValue() {
ss.SetAttackValue(ss.GetValidValue())
}

ss.SetAttackValue(ss.GetValidValue())
vsa, err := scan.ScanURL(operation, &ss)
r := report.NewScanReport(HTTPHeadersScanID, HTTPHeadersScanName)
r.AddScanAttempt(vsa).End()
if err != nil {
return r, err
Expand Down
7 changes: 5 additions & 2 deletions scan/best_practices/http_trace_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ const (
)

func HTTPTraceMethodScanHandler(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(HTTPTraceScanID, HTTPTraceScanName)
if ss.HasValidValue() {
ss.SetAttackValue(ss.GetValidValue())
}

newOperation := operation.Clone()
newOperation.Method = "TRACE"

ss.SetAttackValue(ss.GetValidValue())
vsa, err := scan.ScanURL(newOperation, &ss)
r := report.NewScanReport(HTTPTraceScanID, HTTPTraceScanName)
r.AddScanAttempt(vsa).End()
if err != nil {
return r, err
Expand Down
6 changes: 4 additions & 2 deletions scan/discover/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ func newGetGraphqlIntrospectionRequest(endpoint *url.URL) (*http.Request, error)
}

func GraphqlIntrospectionScanHandler(operation *request.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(GraphqlIntrospectionScanID, GraphqlIntrospectionScanName)
securityScheme.SetAttackValue(securityScheme.GetValidValue())
if securityScheme.HasValidValue() {
securityScheme.SetAttackValue(securityScheme.GetValidValue())
}

base := ExtractBaseURL(operation.Request.URL)

r := report.NewScanReport(GraphqlIntrospectionScanID, GraphqlIntrospectionScanName)
for _, path := range potentialGraphQLEndpoints {
newRequest, err := newPostGraphqlIntrospectionRequest(base.ResolveReference(&url.URL{Path: path}))
if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions scan/discover/server_signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ func checkSignatureHeader(operation *request.Operation, headers map[string][]str
}

func ServerSignatureScanHandler(operation *request.Operation, securityScheme auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(DiscoverServerSignatureScanID, DiscoverServerSignatureScanName)
if securityScheme.HasValidValue() {
securityScheme.SetAttackValue(securityScheme.GetValidValue())
}

securityScheme.SetAttackValue(securityScheme.GetValidValue())
vsa, err := scan.ScanURL(operation, &securityScheme)
r := report.NewScanReport(DiscoverServerSignatureScanID, DiscoverServerSignatureScanName)
r.AddScanAttempt(vsa).End()
if err != nil {
return r, err
Expand Down
15 changes: 10 additions & 5 deletions scan/jwt/alg_none.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,21 @@ const (
)

func AlgNoneJwtScanHandler(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(AlgNoneJwtScanID, AlgNoneJwtScanName)
if !ShouldBeScanned(ss) {
return r, nil
return nil, nil
}

valueWriter := ss.GetValidValueWriter().(*jwt.JWTWriter)
if valueWriter.Token.Method.Alg() == jwtlib.SigningMethodNone.Alg() {
return r, nil
var valueWriter *jwt.JWTWriter
if ss.HasValidValue() {
valueWriter = ss.GetValidValueWriter().(*jwt.JWTWriter)
if valueWriter.Token.Method.Alg() == jwtlib.SigningMethodNone.Alg() {
return nil, nil
}
} else {
valueWriter, _ = jwt.NewJWTWriter(jwt.FakeJWT)
}

r := report.NewScanReport(AlgNoneJwtScanID, AlgNoneJwtScanName)
newToken, err := valueWriter.WithAlgNone()
if err != nil {
return r, err
Expand Down
23 changes: 20 additions & 3 deletions scan/jwt/alg_none_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"net/http"
"testing"

"github.com/cerberauth/vulnapi/internal/auth"
Expand All @@ -10,19 +11,35 @@ import (
"github.com/stretchr/testify/assert"
)

func TestAlgNoneJwtScanHandlerWithoutJwt(t *testing.T) {
func TestAlgNoneJwtScanHandlerWithoutSecurityScheme(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()

securityScheme := auth.NewNoAuthSecurityScheme()
operation := request.NewOperation("http://localhost:8080/", "GET", nil, nil, nil)

httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(405, nil))
httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil))

report, err := jwt.AlgNoneJwtScanHandler(operation, securityScheme)

assert.NoError(t, err)
assert.Equal(t, 0, httpmock.GetTotalCallCount())
assert.Nil(t, report)
}

func TestAlgNoneJwtScanHandlerWithoutJWT(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()

securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil)
operation := request.NewOperation("http://localhost:8080/", "GET", nil, nil, nil)

httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil))

report, err := jwt.AlgNoneJwtScanHandler(operation, securityScheme)

assert.NoError(t, err)
assert.Equal(t, 1, httpmock.GetTotalCallCount())
assert.False(t, report.HasVulnerabilityReport())
}

Expand All @@ -34,7 +51,7 @@ func TestAlgNoneJwtScanHandler(t *testing.T) {
securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token)
operation := request.NewOperation("http://localhost:8080/", "GET", nil, nil, nil)

httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(401, nil))
httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil))

report, err := jwt.AlgNoneJwtScanHandler(operation, securityScheme)

Expand Down
13 changes: 9 additions & 4 deletions scan/jwt/blank_secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,18 @@ const (
)

func BlankSecretScanHandler(operation *request.Operation, ss auth.SecurityScheme) (*report.ScanReport, error) {
r := report.NewScanReport(BlankSecretVulnerabilityScanID, BlankSecretVulnerabilityScanName)
if !ShouldBeScanned(ss) {
r.End()
return r, nil
return nil, nil
}

var valueWriter *jwt.JWTWriter
if ss.HasValidValue() {
valueWriter = ss.GetValidValueWriter().(*jwt.JWTWriter)
} else {
valueWriter, _ = jwt.NewJWTWriter(jwt.FakeJWT)
}

valueWriter := ss.GetValidValueWriter().(*jwt.JWTWriter)
r := report.NewScanReport(BlankSecretVulnerabilityScanID, BlankSecretVulnerabilityScanName)
newToken, err := valueWriter.SignWithKey([]byte(""))
if err != nil {
return r, err
Expand Down
21 changes: 19 additions & 2 deletions scan/jwt/blank_secret_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"net/http"
"testing"

"github.com/cerberauth/vulnapi/internal/auth"
Expand All @@ -10,14 +11,30 @@ import (
"github.com/stretchr/testify/assert"
)

func TestBlankSecretScanHandlerWithoutJwt(t *testing.T) {
func TestBlankSecretScanHandlerWithoutSecurityScheme(t *testing.T) {
securityScheme := auth.NewNoAuthSecurityScheme()
operation := request.NewOperation("http://localhost:8080/", "GET", nil, nil, nil)

report, err := jwt.BlankSecretScanHandler(operation, securityScheme)

assert.NoError(t, err)
assert.Equal(t, 0, httpmock.GetTotalCallCount())
assert.Nil(t, report)
}

func TestBlankSecretScanHandlerWithoutJWT(t *testing.T) {
httpmock.Activate()
defer httpmock.DeactivateAndReset()

securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", nil)
operation := request.NewOperation("http://localhost:8080/", "GET", nil, nil, nil)

httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil))

report, err := jwt.BlankSecretScanHandler(operation, securityScheme)

assert.NoError(t, err)
assert.Equal(t, 1, httpmock.GetTotalCallCount())
assert.False(t, report.HasVulnerabilityReport())
}

Expand All @@ -29,7 +46,7 @@ func TestBlankSecretScanHandler(t *testing.T) {
securityScheme, _ := auth.NewAuthorizationJWTBearerSecurityScheme("token", &token)
operation := request.NewOperation("http://localhost:8080/", "GET", nil, nil, nil)

httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(401, nil))
httpmock.RegisterResponder(operation.Method, operation.Request.URL.String(), httpmock.NewBytesResponder(http.StatusUnauthorized, nil))

report, err := jwt.BlankSecretScanHandler(operation, securityScheme)

Expand Down
Loading

0 comments on commit b849141

Please sign in to comment.