rmfakecloud-proxy/main.go

177 lines
3.9 KiB
Go
Raw Normal View History

2021-02-09 17:39:44 +01:00
//go:generate go run generate/versioninfo.go
2018-08-18 00:03:05 +08:00
// secure is a super simple TLS termination proxy
2018-08-17 00:39:01 +08:00
package main
import (
2018-08-18 00:03:05 +08:00
"context"
2018-08-17 00:39:01 +08:00
"flag"
"fmt"
2021-03-21 10:17:58 +01:00
"gopkg.in/yaml.v3"
"io/ioutil"
2018-08-18 00:03:05 +08:00
"log"
2018-08-17 00:39:01 +08:00
"net/http"
2018-08-18 00:03:05 +08:00
"net/http/httputil"
"net/url"
2018-08-17 00:39:01 +08:00
"os"
2018-08-18 00:03:05 +08:00
"os/signal"
"path/filepath"
"strings"
2018-08-17 00:39:01 +08:00
"syscall"
)
2021-03-21 10:17:58 +01:00
// Config holds the proxy's runtime settings. The YAML key for each field,
// used when a config file is passed with -c, is given by its struct tag.
type Config struct {
	// CertFile is the path of the TLS certificate passed to ListenAndServeTLS.
	CertFile string `yaml:"certfile"`

	// KeyFile is the path of the TLS private key passed to ListenAndServeTLS.
	KeyFile string `yaml:"keyfile"`

	// Upstream is the URL that incoming requests are forwarded to.
	Upstream string `yaml:"upstream"`

	// Addr is the host:port the TLS listener binds to.
	Addr string `yaml:"addr"`
}
2018-08-17 00:39:01 +08:00
// version is bound to the -version flag by getConfig; when it is set the
// program prints its version string and exits.
var version bool

// configFile is bound to the -c flag by getConfig and names an optional
// YAML configuration file.
var configFile string
2021-03-21 10:17:58 +01:00
func getConfig() (config *Config, err error) {
cfg := Config{}
flag.StringVar(&configFile, "c", "", "config file")
flag.StringVar(&cfg.Addr, "addr", ":443", "listen address")
flag.StringVar(&cfg.CertFile, "cert", "", "path to cert file")
flag.StringVar(&cfg.KeyFile, "key", "", "path to key file")
2018-08-18 01:31:45 +08:00
flag.BoolVar(&version, "version", false, "print version string and exit")
2018-08-18 00:03:05 +08:00
flag.Usage = func() {
fmt.Fprintf(flag.CommandLine.Output(),
2021-03-21 10:17:58 +01:00
"usage: %s -c [config.yml] [-addr host:port] -cert certfile -key keyfile [-version] upstream\n",
2018-08-18 00:03:05 +08:00
filepath.Base(os.Args[0]))
flag.PrintDefaults()
fmt.Fprintln(flag.CommandLine.Output(), " upstream string\n \tupstream url")
}
2018-08-17 00:39:01 +08:00
flag.Parse()
2018-08-18 01:31:45 +08:00
if version {
fmt.Fprintln(flag.CommandLine.Output(), Version)
os.Exit(0)
}
2021-03-21 10:17:58 +01:00
if configFile != "" {
var data []byte
data, err = ioutil.ReadFile(configFile)
if err != nil {
return
}
err = yaml.Unmarshal(data, &cfg)
if err != nil {
return nil, fmt.Errorf("cant parse config, %v", err)
}
return &cfg, nil
}
2018-08-18 00:03:05 +08:00
if flag.NArg() == 1 {
2021-03-21 10:17:58 +01:00
cfg.Upstream = flag.Arg(0)
2018-08-18 00:03:05 +08:00
} else {
flag.Usage()
os.Exit(2)
}
2021-03-21 10:17:58 +01:00
return &cfg, nil
}
// singleJoiningSlash concatenates a and b so that exactly one "/" separates
// them. It removes at most one trailing slash from a and one leading slash
// from b, then inserts a single slash between the two.
func singleJoiningSlash(a, b string) string {
	head := strings.TrimSuffix(a, "/")
	tail := strings.TrimPrefix(b, "/")
	return head + "/" + tail
}
// joinURLPath joins the paths of a and b with exactly one slash between them.
// It returns the decoded path and, when either URL carries a raw (escaped)
// path, the matching escaped path. When neither URL has a RawPath, rawpath is
// "".
func joinURLPath(a, b *url.URL) (path, rawpath string) {
	if a.RawPath == "" && b.RawPath == "" {
		return singleJoiningSlash(a.Path, b.Path), ""
	}

	// Decide on the slash using the escaped forms: an escaped "%2F" at the
	// boundary is not a separator, even though it decodes to "/" in Path.
	escA := a.EscapedPath()
	escB := b.EscapedPath()
	trailing := strings.HasSuffix(escA, "/")
	leading := strings.HasPrefix(escB, "/")

	if trailing && leading {
		return a.Path + b.Path[1:], escA + escB[1:]
	}
	if !trailing && !leading {
		return a.Path + "/" + b.Path, escA + "/" + escB
	}
	return a.Path + b.Path, escA + escB
}
2021-03-21 10:17:58 +01:00
func _main() error {
cfg, err := getConfig()
if err != nil {
return err
}
upstream, err := url.Parse(cfg.Upstream)
2018-08-17 00:39:01 +08:00
if err != nil {
return fmt.Errorf("invalid upstream address: %v", err)
}
upstreamQuery := upstream.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = upstream.Scheme
req.Host = upstream.Host
req.URL.Host = upstream.Host
req.URL.Path, req.URL.RawPath = joinURLPath(upstream, req.URL)
if upstreamQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = upstreamQuery + req.URL.RawQuery
} else {
req.URL.RawQuery = upstreamQuery + "&" + req.URL.RawQuery
}
if _, ok := req.Header["User-Agent"]; !ok {
// explicitly disable User-Agent so it's not set to default value
req.Header.Set("User-Agent", "")
}
}
2018-08-17 00:39:01 +08:00
srv := http.Server{
Handler: &httputil.ReverseProxy{
Director: director,
},
Addr: cfg.Addr,
2018-08-17 00:39:01 +08:00
}
2018-08-18 00:03:05 +08:00
done := make(chan struct{})
2018-08-17 00:39:01 +08:00
go func() {
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
fmt.Println(<-sig)
if err := srv.Shutdown(context.Background()); err != nil {
2018-08-18 00:03:05 +08:00
fmt.Printf("Shutdown: %v", err)
2018-08-17 00:39:01 +08:00
}
2018-08-18 00:03:05 +08:00
close(done)
2018-08-17 00:39:01 +08:00
}()
log.Printf("cert-file=%s key-file=%s listen-addr=%s upstream-url=%s", cfg.CertFile, cfg.KeyFile, srv.Addr, upstream.String())
2021-03-21 10:17:58 +01:00
if err := srv.ListenAndServeTLS(cfg.CertFile, cfg.KeyFile); err != http.ErrServerClosed {
2018-08-17 00:39:01 +08:00
return fmt.Errorf("ListenAndServeTLS: %v", err)
}
2018-08-18 00:03:05 +08:00
<-done
2018-08-17 00:39:01 +08:00
return nil
}
// main runs the proxy and exits with status 1 if it fails.
func main() {
	if err := _main(); err != nil {
		// Report the failure on stderr, where log and flag usage output
		// already go, so it is not mixed into stdout.
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}