From 3cf829718522405a89c5ae932ff757e8fa502ba8 Mon Sep 17 00:00:00 2001 From: Silke Hofstra Date: Wed, 28 Aug 2024 18:00:01 +0200 Subject: [PATCH 1/3] Fix parsing of X-Forwarded-For header Solve a parsing error when `trustAnyForwardedAddress` is enabled. This is caused by the `X-Forwarded-For` header containing an IP address, and not a host+port. --- api/_routers/03-host_detection.go | 39 +++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/api/_routers/03-host_detection.go b/api/_routers/03-host_detection.go index 830e48f0..5771a5e2 100644 --- a/api/_routers/03-host_detection.go +++ b/api/_routers/03-host_detection.go @@ -12,6 +12,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sebest/xff" "github.com/sirupsen/logrus" + "github.com/t2bot/matrix-media-repo/api/_responses" "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/config" @@ -32,23 +33,7 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Host = r.Header.Get("X-Forwarded-Host") } r.Host = strings.Split(r.Host, ":")[0] - - var raddr string - if config.Get().General.TrustAnyForward { - raddr = r.Header.Get("X-Forwarded-For") - } else { - raddr = xff.GetRemoteAddr(r) - } - if raddr == "" { - raddr = r.RemoteAddr - } - host, _, err := net.SplitHostPort(raddr) - if err != nil { - logrus.Error(err) - sentry.CaptureException(err) - host = raddr - } - r.RemoteAddr = host + r.RemoteAddr = GetRemoteAddr(r) ignoreHost := ShouldIgnoreHost(r) isOurs := ignoreHost || util.IsServerOurs(r.Host) @@ -85,6 +70,26 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func GetRemoteAddr(r *http.Request) string { + if config.Get().General.TrustAnyForward { + return r.Header.Get("X-Forwarded-For") + } + + raddr := xff.GetRemoteAddr(r) + if raddr == "" { + raddr = r.RemoteAddr + } + + host, _, err := net.SplitHostPort(raddr) + if err != nil { + logrus.WithField("raddr", raddr).WithError(err).Error("Invalid remote address") + sentry.CaptureException(err) + host = raddr + } + + return host +} + func GetDomainConfig(r *http.Request) *config.DomainRepoConfig { x, ok := r.Context().Value(common.ContextDomainConfig).(*config.DomainRepoConfig) if !ok { From 20cc568ee0ce51ec498073b2ff34fce16983caaa Mon Sep 17 00:00:00 2001 From: Silke Hofstra Date: Wed, 28 Aug 2024 18:26:11 +0200 Subject: [PATCH 2/3] Add resolved remote address/host to log entries The current logs contain only the original host and remote address. This makes it difficult to see who actually performed the request without correlating other logs. Add these values to the log context in addition to the original values. --- api/_routers/03-host_detection.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/api/_routers/03-host_detection.go b/api/_routers/03-host_detection.go index 5771a5e2..8e640ea5 100644 --- a/api/_routers/03-host_detection.go +++ b/api/_routers/03-host_detection.go @@ -29,6 +29,9 @@ func NewHostRouter(next http.Handler) *HostRouter { } func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + origHost := r.Host + origRemoteAddr := r.RemoteAddr + if r.Header.Get("X-Forwarded-Host") != "" && config.Get().General.UseForwardedHost { r.Host = r.Header.Get("X-Forwarded-Host") } @@ -55,6 +58,13 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { return // don't call next handler } + logger := GetLogger(r).WithFields(logrus.Fields{ + "host": r.Host, + "remoteAddr": r.RemoteAddr, + "origHost": origHost, + "origRemoteAddr": origRemoteAddr, + }) + cfg := config.GetDomain(r.Host) if ignoreHost { dc := config.DomainConfigFrom(*config.Get()) @@ -63,6 +73,7 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx = context.WithValue(ctx, common.ContextDomainConfig, cfg) + ctx = context.WithValue(ctx, common.ContextLogger, logger) r = r.WithContext(ctx) if h.next != nil { From e1b1c7a74cf3f1b521f116b9281b480da841f461 Mon Sep 17 00:00:00 2001 From: Silke Hofstra Date: Wed, 28 Aug 2024 18:26:28 +0200 Subject: [PATCH 3/3] Log rate limit subject Add the subject that is rate limited to the logged entry. --- pipelines/pipeline_download/pipeline.go | 10 +++++++--- pipelines/pipeline_thumbnail/pipeline.go | 7 +++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index e56715b8..c43efe76 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -10,6 +10,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/t2bot/go-leaky-bucket" "github.com/t2bot/go-singleflight-streams" + "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/rcontext" "github.com/t2bot/matrix-media-repo/database" @@ -76,7 +77,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do } // Check rate limits before moving on much further - limitBucket, err := limits.GetBucket(ctx, limits.GetRequestIP(ctx.Request)) + subject := limits.GetRequestIP(ctx.Request) + limitBucket, err := limits.GetBucket(ctx, subject) if err != nil { cancel() return nil, nil, err @@ -90,7 +92,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do if limitErr := limitBucket.Add(ctx.Config.Downloads.MaxSizeBytes); limitErr != nil { cancel() if errors.Is(limitErr, leaky.ErrBucketFull) { - ctx.Log.Debugf("Rate limited on MaxSizeBytes=%d/%d", ctx.Config.Downloads.MaxSizeBytes, limitBucket.Remaining()) + ctx.Log.WithField("subject", subject). + Debugf("Rate limited on MaxSizeBytes=%d/%d", ctx.Config.Downloads.MaxSizeBytes, limitBucket.Remaining()) return nil, nil, common.ErrRateLimitExceeded } return nil, nil, limitErr @@ -101,7 +104,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do if limitErr := limitBucket.Add(record.SizeBytes); limitErr != nil { cancel() if errors.Is(limitErr, leaky.ErrBucketFull) { - ctx.Log.Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) + ctx.Log.WithField("subject", subject). + Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) return nil, nil, common.ErrRateLimitExceeded } return nil, nil, limitErr diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go index fb88816b..70780c96 100644 --- a/pipelines/pipeline_thumbnail/pipeline.go +++ b/pipelines/pipeline_thumbnail/pipeline.go @@ -9,6 +9,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/t2bot/go-leaky-bucket" sfstreams "github.com/t2bot/go-singleflight-streams" + "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/rcontext" "github.com/t2bot/matrix-media-repo/database" @@ -87,7 +88,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th } // Check rate limits before moving on much further - limitBucket, err := limits.GetBucket(ctx, limits.GetRequestIP(ctx.Request)) + subject := limits.GetRequestIP(ctx.Request) + limitBucket, err := limits.GetBucket(ctx, subject) if err != nil { cancel() return nil, nil, err @@ -96,7 +98,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th if limitErr := limitBucket.Add(record.SizeBytes); limitErr != nil { cancel() if errors.Is(limitErr, leaky.ErrBucketFull) { - ctx.Log.Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) + ctx.Log.WithField("subject", subject). + Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) return nil, nil, common.ErrRateLimitExceeded } return nil, nil, limitErr