The default httputil.NewSingleHostReverseProxy implementation does not rewrite the Host header in forwarded requests, so the upstream server receives requests whose Host header is still the original reMarkable domain. In setups where a reverse proxy is used in front of rmfakecloud (as suggested [here](https://github.com/ddvk/rmfakecloud/blob/master/docs/https.md)), this can confuse that upstream HTTP server, especially if it is configured to serve several websites (in which case the Host header is what tells requests apart). This PR replaces the call to NewSingleHostReverseProxy with an implementation that rewrites the Host header (by assigning `req.Host`). This is essentially a copy/paste of the [original implementation](https://cs.opensource.google/go/go/+/refs/tags/go1.17.1:src/net/http/httputil/reverseproxy.go;drc=b7a85e0003cedb1b48a1fd3ae5b746ec6330102e;l=143) with one added line that performs the rewrite. I don’t know if there’s a cleaner way to do this, and copying the code may introduce licensing issues, since the original source is BSD-licensed.
176 lines
3.9 KiB
Go
176 lines
3.9 KiB
Go
//go:generate go run generate/versioninfo.go
|
|
|
|
// secure is a super simple TLS termination proxy
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"flag"
|
|
"fmt"
|
|
"gopkg.in/yaml.v3"
|
|
"io/ioutil"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
)
|
|
|
|
// Config holds the proxy settings. getConfig fills it from command-line
// flags and, when -c is given, from a YAML file whose keys are the yaml tags
// below.
type Config struct {
	// CertFile is the TLS certificate path passed to ListenAndServeTLS.
	CertFile string `yaml:"certfile"`
	// KeyFile is the TLS private-key path passed to ListenAndServeTLS.
	KeyFile string `yaml:"keyfile"`
	// Upstream is the URL that incoming requests are forwarded to.
	Upstream string `yaml:"upstream"`
	// Addr is the listen address for the TLS server (-addr flag default ":443").
	Addr string `yaml:"addr"`
}
|
|
|
|
// Targets of command-line flags registered in getConfig.
var (
	// version is set by -version; when true, getConfig prints Version and exits.
	version bool
	// configFile is the YAML config path given with -c; empty means none.
	configFile string
)
|
|
|
|
func getConfig() (config *Config, err error) {
|
|
cfg := Config{}
|
|
flag.StringVar(&configFile, "c", "", "config file")
|
|
flag.StringVar(&cfg.Addr, "addr", ":443", "listen address")
|
|
flag.StringVar(&cfg.CertFile, "cert", "", "path to cert file")
|
|
flag.StringVar(&cfg.KeyFile, "key", "", "path to key file")
|
|
flag.BoolVar(&version, "version", false, "print version string and exit")
|
|
|
|
flag.Usage = func() {
|
|
fmt.Fprintf(flag.CommandLine.Output(),
|
|
"usage: %s -c [config.yml] [-addr host:port] -cert certfile -key keyfile [-version] upstream\n",
|
|
filepath.Base(os.Args[0]))
|
|
flag.PrintDefaults()
|
|
fmt.Fprintln(flag.CommandLine.Output(), " upstream string\n \tupstream url")
|
|
}
|
|
flag.Parse()
|
|
|
|
if version {
|
|
fmt.Fprintln(flag.CommandLine.Output(), Version)
|
|
os.Exit(0)
|
|
}
|
|
|
|
if configFile != "" {
|
|
var data []byte
|
|
data, err = ioutil.ReadFile(configFile)
|
|
|
|
if err != nil {
|
|
return
|
|
}
|
|
err = yaml.Unmarshal(data, &cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cant parse config, %v", err)
|
|
}
|
|
return &cfg, nil
|
|
}
|
|
|
|
if flag.NArg() == 1 {
|
|
cfg.Upstream = flag.Arg(0)
|
|
} else {
|
|
flag.Usage()
|
|
os.Exit(2)
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
// singleJoiningSlash concatenates a and b so that exactly one "/" sits at
// the seam: a duplicated slash is collapsed, and a missing one is inserted.
func singleJoiningSlash(a, b string) string {
	endsWithSlash := strings.HasSuffix(a, "/")
	startsWithSlash := strings.HasPrefix(b, "/")

	if endsWithSlash && startsWithSlash {
		// Both sides contribute a slash; drop the one from b.
		return a + strings.TrimPrefix(b, "/")
	}
	if endsWithSlash || startsWithSlash {
		// Exactly one slash already present.
		return a + b
	}
	return a + "/" + b
}
|
|
|
|
func joinURLPath(a, b *url.URL) (path, rawpath string) {
|
|
if a.RawPath == "" && b.RawPath == "" {
|
|
return singleJoiningSlash(a.Path, b.Path), ""
|
|
}
|
|
// Same as singleJoiningSlash, but uses EscapedPath to determine
|
|
// whether a slash should be added
|
|
apath := a.EscapedPath()
|
|
bpath := b.EscapedPath()
|
|
|
|
aslash := strings.HasSuffix(apath, "/")
|
|
bslash := strings.HasPrefix(bpath, "/")
|
|
|
|
switch {
|
|
case aslash && bslash:
|
|
return a.Path + b.Path[1:], apath + bpath[1:]
|
|
case !aslash && !bslash:
|
|
return a.Path + "/" + b.Path, apath + "/" + bpath
|
|
}
|
|
return a.Path + b.Path, apath + bpath
|
|
}
|
|
|
|
func _main() error {
|
|
cfg, err := getConfig()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
upstream, err := url.Parse(cfg.Upstream)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid upstream address: %v", err)
|
|
}
|
|
|
|
upstreamQuery := upstream.RawQuery
|
|
director := func(req *http.Request) {
|
|
req.URL.Scheme = upstream.Scheme
|
|
req.Host = upstream.Host
|
|
req.URL.Host = upstream.Host
|
|
req.URL.Path, req.URL.RawPath = joinURLPath(upstream, req.URL)
|
|
if upstreamQuery == "" || req.URL.RawQuery == "" {
|
|
req.URL.RawQuery = upstreamQuery + req.URL.RawQuery
|
|
} else {
|
|
req.URL.RawQuery = upstreamQuery + "&" + req.URL.RawQuery
|
|
}
|
|
if _, ok := req.Header["User-Agent"]; !ok {
|
|
// explicitly disable User-Agent so it's not set to default value
|
|
req.Header.Set("User-Agent", "")
|
|
}
|
|
}
|
|
|
|
srv := http.Server{
|
|
Handler: &httputil.ReverseProxy{
|
|
Director: director,
|
|
},
|
|
Addr: cfg.Addr,
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
sig := make(chan os.Signal, 1)
|
|
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
|
|
fmt.Println(<-sig)
|
|
|
|
if err := srv.Shutdown(context.Background()); err != nil {
|
|
fmt.Printf("Shutdown: %v", err)
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
log.Printf("cert-file=%s key-file=%s listen-addr=%s upstream-url=%s", cfg.CertFile, cfg.KeyFile, srv.Addr, upstream.String())
|
|
if err := srv.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile); err != http.ErrServerClosed {
|
|
return fmt.Errorf("ListenAndServeTLS: %v", err)
|
|
}
|
|
|
|
<-done
|
|
return nil
|
|
}
|
|
|
|
func main() {
|
|
err := _main()
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
os.Exit(1)
|
|
}
|
|
}
|