diff --git a/api/_routers/03-host_detection.go b/api/_routers/03-host_detection.go index 830e48f0..8e640ea5 100644 --- a/api/_routers/03-host_detection.go +++ b/api/_routers/03-host_detection.go @@ -12,6 +12,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sebest/xff" "github.com/sirupsen/logrus" + "github.com/t2bot/matrix-media-repo/api/_responses" "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/config" @@ -28,27 +29,14 @@ func NewHostRouter(next http.Handler) *HostRouter { } func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + origHost := r.Host + origRemoteAddr := r.RemoteAddr + if r.Header.Get("X-Forwarded-Host") != "" && config.Get().General.UseForwardedHost { r.Host = r.Header.Get("X-Forwarded-Host") } r.Host = strings.Split(r.Host, ":")[0] - - var raddr string - if config.Get().General.TrustAnyForward { - raddr = r.Header.Get("X-Forwarded-For") - } else { - raddr = xff.GetRemoteAddr(r) - } - if raddr == "" { - raddr = r.RemoteAddr - } - host, _, err := net.SplitHostPort(raddr) - if err != nil { - logrus.Error(err) - sentry.CaptureException(err) - host = raddr - } - r.RemoteAddr = host + r.RemoteAddr = GetRemoteAddr(r) ignoreHost := ShouldIgnoreHost(r) isOurs := ignoreHost || util.IsServerOurs(r.Host) @@ -70,6 +58,13 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { return // don't call next handler } + logger := GetLogger(r).WithFields(logrus.Fields{ + "host": r.Host, + "remoteAddr": r.RemoteAddr, + "origHost": origHost, + "origRemoteAddr": origRemoteAddr, + }) + cfg := config.GetDomain(r.Host) if ignoreHost { dc := config.DomainConfigFrom(*config.Get()) @@ -78,6 +73,7 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx = context.WithValue(ctx, common.ContextDomainConfig, cfg) + ctx = context.WithValue(ctx, common.ContextLogger, logger) r = r.WithContext(ctx) if h.next != nil { @@ -85,6 +81,26 @@ func (h *HostRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func GetRemoteAddr(r *http.Request) string { + if config.Get().General.TrustAnyForward { + return r.Header.Get("X-Forwarded-For") + } + + raddr := xff.GetRemoteAddr(r) + if raddr == "" { + raddr = r.RemoteAddr + } + + host, _, err := net.SplitHostPort(raddr) + if err != nil { + logrus.WithField("raddr", raddr).WithError(err).Error("Invalid remote address") + sentry.CaptureException(err) + host = raddr + } + + return host +} + func GetDomainConfig(r *http.Request) *config.DomainRepoConfig { x, ok := r.Context().Value(common.ContextDomainConfig).(*config.DomainRepoConfig) if !ok { diff --git a/pipelines/pipeline_download/pipeline.go b/pipelines/pipeline_download/pipeline.go index e56715b8..c43efe76 100644 --- a/pipelines/pipeline_download/pipeline.go +++ b/pipelines/pipeline_download/pipeline.go @@ -10,6 +10,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/t2bot/go-leaky-bucket" "github.com/t2bot/go-singleflight-streams" + "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/rcontext" "github.com/t2bot/matrix-media-repo/database" @@ -76,7 +77,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do } // Check rate limits before moving on much further - limitBucket, err := limits.GetBucket(ctx, limits.GetRequestIP(ctx.Request)) + subject := limits.GetRequestIP(ctx.Request) + limitBucket, err := limits.GetBucket(ctx, subject) if err != nil { cancel() return nil, nil, err @@ -90,7 +92,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do if limitErr := limitBucket.Add(ctx.Config.Downloads.MaxSizeBytes); limitErr != nil { cancel() if errors.Is(limitErr, leaky.ErrBucketFull) { - ctx.Log.Debugf("Rate limited on MaxSizeBytes=%d/%d", ctx.Config.Downloads.MaxSizeBytes, limitBucket.Remaining()) + ctx.Log.WithField("subject", subject). + Debugf("Rate limited on MaxSizeBytes=%d/%d", ctx.Config.Downloads.MaxSizeBytes, limitBucket.Remaining()) return nil, nil, common.ErrRateLimitExceeded } return nil, nil, limitErr @@ -101,7 +104,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do if limitErr := limitBucket.Add(record.SizeBytes); limitErr != nil { cancel() if errors.Is(limitErr, leaky.ErrBucketFull) { - ctx.Log.Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) + ctx.Log.WithField("subject", subject). + Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) return nil, nil, common.ErrRateLimitExceeded } return nil, nil, limitErr diff --git a/pipelines/pipeline_thumbnail/pipeline.go b/pipelines/pipeline_thumbnail/pipeline.go index fb88816b..70780c96 100644 --- a/pipelines/pipeline_thumbnail/pipeline.go +++ b/pipelines/pipeline_thumbnail/pipeline.go @@ -9,6 +9,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/t2bot/go-leaky-bucket" sfstreams "github.com/t2bot/go-singleflight-streams" + "github.com/t2bot/matrix-media-repo/common" "github.com/t2bot/matrix-media-repo/common/rcontext" "github.com/t2bot/matrix-media-repo/database" @@ -87,7 +88,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th } // Check rate limits before moving on much further - limitBucket, err := limits.GetBucket(ctx, limits.GetRequestIP(ctx.Request)) + subject := limits.GetRequestIP(ctx.Request) + limitBucket, err := limits.GetBucket(ctx, subject) if err != nil { cancel() return nil, nil, err @@ -96,7 +98,8 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th if limitErr := limitBucket.Add(record.SizeBytes); limitErr != nil { cancel() if errors.Is(limitErr, leaky.ErrBucketFull) { - ctx.Log.Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) + ctx.Log.WithField("subject", subject). + Debugf("Rate limited on SizeBytes=%d/%d", record.SizeBytes, limitBucket.Remaining()) return nil, nil, common.ErrRateLimitExceeded } return nil, nil, limitErr