Skip to content

Commit 8e92368

Browse files
authored
1 parent 2aeb2ef commit 8e92368

4 files changed

+61
-3
lines changed

flate.go

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package saml
2+
3+
import (
4+
"compress/flate"
5+
"fmt"
6+
"io"
7+
)
8+
9+
const flateUncompressLimit = 10 * 1024 * 1024 // 10MB
10+
11+
func newSaferFlateReader(r io.Reader) io.ReadCloser {
12+
return &saferFlateReader{r: flate.NewReader(r)}
13+
}
14+
15+
type saferFlateReader struct {
16+
r io.ReadCloser
17+
count int
18+
}
19+
20+
func (r *saferFlateReader) Read(p []byte) (n int, err error) {
21+
if r.count+len(p) > flateUncompressLimit {
22+
return 0, fmt.Errorf("flate: uncompress limit exceeded (%d bytes)", flateUncompressLimit)
23+
}
24+
n, err = r.r.Read(p)
25+
r.count += n
26+
return n, err
27+
}
28+
29+
func (r *saferFlateReader) Close() error {
30+
return r.r.Close()
31+
}

identity_provider.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package saml
22

33
import (
44
"bytes"
5-
"compress/flate"
65
"crypto"
76
"crypto/tls"
87
"crypto/x509"
@@ -363,7 +362,7 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques
363362
if err != nil {
364363
return nil, fmt.Errorf("cannot decode request: %s", err)
365364
}
366-
req.RequestBuffer, err = ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest)))
365+
req.RequestBuffer, err = ioutil.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest)))
367366
if err != nil {
368367
return nil, fmt.Errorf("cannot decompress request: %s", err)
369368
}

identity_provider_test.go

+28
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package saml
22

33
import (
4+
"bytes"
5+
"compress/flate"
46
"crypto"
57
"crypto/rsa"
68
"crypto/x509"
@@ -1013,3 +1015,29 @@ func TestIDPNoDestination(t *testing.T) {
10131015
err = req.MakeResponse()
10141016
assert.Check(t, err)
10151017
}
1018+
1019+
func TestIDPRejectDecompressionBomb(t *testing.T) {
1020+
test := NewIdentifyProviderTest(t)
1021+
test.IDP.SessionProvider = &mockSessionProvider{
1022+
GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session {
1023+
fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s",
1024+
req.RelayState, req.RequestBuffer)
1025+
return nil
1026+
},
1027+
}
1028+
1029+
//w := httptest.NewRecorder()
1030+
1031+
data := bytes.Repeat([]byte("a"), 768*1024*1024)
1032+
var compressed bytes.Buffer
1033+
w, _ := flate.NewWriter(&compressed, flate.BestCompression)
1034+
w.Write(data)
1035+
w.Close()
1036+
encoded := base64.StdEncoding.EncodeToString(compressed.Bytes())
1037+
1038+
r, _ := http.NewRequest("GET", "/dontcare?"+url.Values{
1039+
"SAMLRequest": {encoded},
1040+
}.Encode(), nil)
1041+
_, err := NewIdpAuthnRequest(&test.IDP, r)
1042+
assert.Error(t, err, "cannot decompress request: flate: uncompress limit exceeded (10485760 bytes)")
1043+
}

service_provider.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
15241524
}
15251525
retErr.Response = string(rawResponseBuf)
15261526

1527-
gr, err := ioutil.ReadAll(flate.NewReader(bytes.NewBuffer(rawResponseBuf)))
1527+
gr, err := ioutil.ReadAll(newSaferFlateReader(bytes.NewBuffer(rawResponseBuf)))
15281528
if err != nil {
15291529
retErr.PrivateErr = err
15301530
return retErr

0 commit comments

Comments
 (0)