Rewrite the Host header in forwarded requests
The default httputil.NewSingleHostReverseProxy implementation does not rewrite the Host header in forwarded requests. So, the upstream server receives requests with the Host header set as the original reMarkable domain. In where a reverse proxy is used in front of rmfakecloud (as suggested [here](https://github.com/ddvk/rmfakecloud/blob/master/docs/https.md)), this can make the HTTP server confused especially if it is configured to serve several websites (in which case the Host header is used to differentiate requests). This PR replaces the call to NewSingleHostReverseProxy with an implementation that rewrites the Host header (by assigning `req.Host`). This is essentially a copy/paste of the [original implementation](https://cs.opensource.google/go/go/+/refs/tags/go1.17.1:src/net/http/httputil/reverseproxy.go;drc=b7a85e0003cedb1b48a1fd3ae5b746ec6330102e;l=143) but with a new line added that does the rewrite. I don’t know if there’s a cleaner way to do this, and this may introduce licensing issues since the original source is BSD-licensed.
This commit is contained in:
parent
65969a4697
commit
9ef79c874c
1 changed files with 57 additions and 5 deletions
62
main.go
62
main.go
|
@ -16,6 +16,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
|
@ -77,21 +78,72 @@ func getConfig() (config *Config, err error) {
|
|||
return &cfg, nil
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
||||
func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
if a.RawPath == "" && b.RawPath == "" {
|
||||
return singleJoiningSlash(a.Path, b.Path), ""
|
||||
}
|
||||
// Same as singleJoiningSlash, but uses EscapedPath to determine
|
||||
// whether a slash should be added
|
||||
apath := a.EscapedPath()
|
||||
bpath := b.EscapedPath()
|
||||
|
||||
aslash := strings.HasSuffix(apath, "/")
|
||||
bslash := strings.HasPrefix(bpath, "/")
|
||||
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a.Path + b.Path[1:], apath + bpath[1:]
|
||||
case !aslash && !bslash:
|
||||
return a.Path + "/" + b.Path, apath + "/" + bpath
|
||||
}
|
||||
return a.Path + b.Path, apath + bpath
|
||||
}
|
||||
|
||||
func _main() error {
|
||||
cfg, err := getConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := url.Parse(cfg.Upstream)
|
||||
upstream, err := url.Parse(cfg.Upstream)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid upstream address: %v", err)
|
||||
}
|
||||
|
||||
rp := httputil.NewSingleHostReverseProxy(u)
|
||||
upstreamQuery := upstream.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
req.URL.Scheme = upstream.Scheme
|
||||
req.Host = upstream.Host
|
||||
req.URL.Host = upstream.Host
|
||||
req.URL.Path, req.URL.RawPath = joinURLPath(upstream, req.URL)
|
||||
if upstreamQuery == "" || req.URL.RawQuery == "" {
|
||||
req.URL.RawQuery = upstreamQuery + req.URL.RawQuery
|
||||
} else {
|
||||
req.URL.RawQuery = upstreamQuery + "&" + req.URL.RawQuery
|
||||
}
|
||||
if _, ok := req.Header["User-Agent"]; !ok {
|
||||
// explicitly disable User-Agent so it's not set to default value
|
||||
req.Header.Set("User-Agent", "")
|
||||
}
|
||||
}
|
||||
|
||||
srv := http.Server{
|
||||
Handler: rp,
|
||||
Addr: cfg.Addr,
|
||||
Handler: &httputil.ReverseProxy{
|
||||
Director: director,
|
||||
},
|
||||
Addr: cfg.Addr,
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
@ -106,7 +158,7 @@ func _main() error {
|
|||
close(done)
|
||||
}()
|
||||
|
||||
log.Printf("cert-file=%s key-file=%s listen-addr=%s upstream-url=%s", cfg.CertFile, cfg.KeyFile, srv.Addr, u.String())
|
||||
log.Printf("cert-file=%s key-file=%s listen-addr=%s upstream-url=%s", cfg.CertFile, cfg.KeyFile, srv.Addr, upstream.String())
|
||||
if err := srv.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile); err != http.ErrServerClosed {
|
||||
return fmt.Errorf("ListenAndServeTLS: %v", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue