diff --git a/README.md b/README.md index e69de29..244e220 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,148 @@ +# Finger + +Webfinger server written in Go. + +## Features +- 🍰 Easy YAML configuration +- 🪶 Single 8MB binary / 0% idle CPU / 4MB idle RAM +- ⚡️ Sub millisecond responses at 10,000 request per second +- 🐳 10MB Docker image + +## Install + +Via `go install`: + +```bash +go install git.maronato.dev/maronato/finger@latest +``` + +Via Docker: + +```bash +docker run --name finger / + -p 8080:8080 / + git.maronato.dev/maronato/finger +``` + +## Usage + +If you installed it using `go install`, run +```bash +finger serve +``` +To start the server on port `8080`. Your resources will be queryable via `locahost:8080/.well-known/webfinger?resource=` + +If you're using Docker, the use the same command in the install section. + +By default, no resources will be exposed. You can create resources via a `fingers.yml` file. It should contain a collection of resources as keys and their attributes as their objects. + +Some default URN aliases are provided via the built-in mapping ([`urns.yml`](./urns.yml)). You can replace that with your own or use URNs directly in the `fingers.yml` file. + +Here's an example: +```yaml +# fingers.yml + +# Resources go in the root of the file. Email address will have the acct: +# prefix added automatically. +alice@example.com: + # "avatar" is an alias of "http://webfinger.net/rel/avatar" + # (see urns.yml for more) + avatar: "https://example.com/alice-pic" + + # If the value is a URI, it'll be exposed as a webfinger link + openid: "https://sso.example.com/" + + # If the value of the attribute is not a URI, it will be exposed as a + # webfinger property + name: "Alice Doe" + + # You can also specify URN's directly instead of the aliases + http://webfinger.net/rel/profile-page: "https://example.com/user/alice" + +bob@example.com: + name: Bob Foo + openid: "https://sso.example.com/" + +# Resources can also be URIs +https://example.com/user/charlie: + name: Charlie Baz + profile: https://example.com/user/charlie +``` + +### Example queries +
+Query Alice
GET http://localhost:8080/.well-known/webfinger?resource=acct:alice@example.com
+ +```json +{ + "subject": "acct:alice@example.com", + "links": [ + { + "rel": "avatar", + "href": "https://example.com/alice-pic" + }, + { + "rel": "openid", + "href": "https://sso.example.com/" + }, + { + "rel": "name", + "href": "Alice Doe" + }, + { + "rel": "http://webfinger.net/rel/profile-page", + "href": "https://example.com/user/alice" + } + ] +} +``` +
+ + +
+Query Bob
GET http://localhost:8080/.well-known/webfinger?resource=acct:bob@example.com
+ +```json +{ + "subject": "acct:bob@example.com", + "links": [ + { + "rel": "name", + "href": "Bob Foo" + }, + { + "rel": "openid", + "href": "https://sso.example.com/" + } + ] +} +``` +
+ + +
+Query Charlie
GET http://localhost:8080/.well-known/webfinger?resource=https://example.com/user/charlie
+ +```JSON +{ + "subject": "https://example.com/user/charlie", + "links": [ + { + "rel": "name", + "href": "Charlie Baz" + }, + { + "rel": "profile", + "href": "https://example.com/user/charlie" + } + ] +} +``` +
+ +## Configs +Here are the config options available. You can change them via command line flags or environment variables: + +| CLI flag | Env variable | Default | Description | +| -------- | ------------ | ------- | ----------- | +| fdsfds | gsfgfs | fgfsdgf | gdfsgdf | diff --git a/cmd/cmd.go b/cmd/cmd.go new file mode 100644 index 0000000..a309341 --- /dev/null +++ b/cmd/cmd.go @@ -0,0 +1,98 @@ +package cmd + +import ( + "context" + "errors" + "fmt" + "os" + "os/signal" + "syscall" + + "git.maronato.dev/maronato/finger/internal/config" + "github.com/peterbourgon/ff/v4" + "github.com/peterbourgon/ff/v4/ffhelp" +) + +func Run(version string) error { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Allow graceful shutdown + trapSignalsCrossPlatform(cancel) + + cfg := &config.Config{} + + // Create a new root command + subcommands := []*ff.Command{ + newServerCmd(cfg), + newHealthcheckCmd(cfg), + } + cmd := newRootCmd(version, cfg, subcommands) + + // Parse and run + if err := cmd.ParseAndRun(ctx, os.Args[1:], ff.WithEnvVarPrefix("WF")); err != nil { + if errors.Is(err, ff.ErrHelp) || errors.Is(err, ff.ErrNoExec) { + fmt.Fprintf(os.Stderr, "\n%s\n", ffhelp.Command(cmd)) + + return nil + } + + return fmt.Errorf("error running command: %w", err) + } + + return nil +} + +// https://github.com/caddyserver/caddy/blob/fbb0ecfa322aa7710a3448453fd3ae40f037b8d1/sigtrap.go#L37 +// trapSignalsCrossPlatform captures SIGINT or interrupt (depending +// on the OS), which initiates a graceful shutdown. A second SIGINT +// or interrupt will forcefully exit the process immediately. +func trapSignalsCrossPlatform(cancel context.CancelFunc) { + go func() { + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGINT) + + for i := 0; true; i++ { + <-shutdown + + if i > 0 { + fmt.Printf("\nForce quit\n") //nolint:forbidigo // We want to print to stdout + os.Exit(1) + } + + fmt.Printf("\nGracefully shutting down. Press Ctrl+C again to force quit\n") //nolint:forbidigo // We want to print to stdout + cancel() + } + }() +} + +// NewRootCmd parses the command line flags and returns a config.Config struct. +func newRootCmd(version string, cfg *config.Config, subcommands []*ff.Command) *ff.Command { + fs := ff.NewFlagSet(appName) + + for _, cmd := range subcommands { + cmd.Flags = ff.NewFlagSet(cmd.Name).SetParent(fs) + } + + cmd := &ff.Command{ + Name: appName, + Usage: fmt.Sprintf("%s [flags]", appName), + ShortHelp: fmt.Sprintf("(%s) A webfinger server", version), + Flags: fs, + Subcommands: subcommands, + } + + // Use 0.0.0.0 as the default host if on docker + defaultHost := "localhost" + if os.Getenv("ENV_DOCKER") == "true" { + defaultHost = "0.0.0.0" + } + + fs.BoolVar(&cfg.Debug, 'd', "debug", "Enable debug logging") + fs.StringVar(&cfg.Host, 'h', "host", defaultHost, "Host to listen on") + fs.StringVar(&cfg.Port, 'p', "port", "8080", "Port to listen on") + fs.StringVar(&cfg.URNPath, 'u', "urn-file", "urns.yml", "Path to the URNs file") + fs.StringVar(&cfg.FingerPath, 'f', "finger-file", "fingers.yml", "Path to the fingers file") + + return cmd +} diff --git a/cmd/healthcheck.go b/cmd/healthcheck.go new file mode 100644 index 0000000..9ec16b5 --- /dev/null +++ b/cmd/healthcheck.go @@ -0,0 +1,53 @@ +package cmd + +import ( + "context" + "fmt" + "net/http" + "net/url" + "time" + + "git.maronato.dev/maronato/finger/internal/config" + "github.com/peterbourgon/ff/v4" +) + +func newHealthcheckCmd(cfg *config.Config) *ff.Command { + return &ff.Command{ + Name: "healthcheck", + Usage: "healthcheck [flags]", + ShortHelp: "Check if the server is running", + Exec: func(ctx context.Context, args []string) error { + // Create a new client + client := &http.Client{ + Timeout: 5 * time.Second, //nolint:gomnd // We want to use a constant + } + + // Create a new request + reqURL := url.URL{ + Scheme: "http", + Host: cfg.GetAddr(), + Path: "/healthz", + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + // Send the request + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error sending request: %w", err) + } + + defer resp.Body.Close() + + // Check the response + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server returned status %d", resp.StatusCode) //nolint:goerr113 // We want to return an error + } + + return nil + }, + } +} diff --git a/cmd/serve.go b/cmd/serve.go new file mode 100644 index 0000000..544376a --- /dev/null +++ b/cmd/serve.go @@ -0,0 +1,49 @@ +package cmd + +import ( + "context" + "fmt" + "os" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/server" + "git.maronato.dev/maronato/finger/internal/webfinger" + "github.com/peterbourgon/ff/v4" +) + +const appName = "finger" + +func newServerCmd(cfg *config.Config) *ff.Command { + return &ff.Command{ + Name: "serve", + Usage: "serve [flags]", + ShortHelp: "Start the webfinger server", + Exec: func(ctx context.Context, args []string) error { + // Create a logger and add it to the context + l := log.NewLogger(os.Stderr, cfg) + ctx = log.WithLogger(ctx, l) + + // Read the webfinger files + r := webfinger.NewFingerReader() + err := r.ReadFiles(cfg) + if err != nil { + return fmt.Errorf("error reading finger files: %w", err) + } + + webfingers, err := r.ReadFingerFile(ctx) + if err != nil { + return fmt.Errorf("error parsing finger files: %w", err) + } + + l.Info(fmt.Sprintf("Loaded %d webfingers", len(webfingers))) + + // Start the server + if err := server.StartServer(ctx, cfg, webfingers); err != nil { + return fmt.Errorf("error running server: %w", err) + } + + return nil + }, + } +} diff --git a/go.mod b/go.mod index 64a2775..21733cd 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21.0 require ( github.com/peterbourgon/ff/v4 v4.0.0-alpha.3 - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 golang.org/x/sync v0.3.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index 02ae10a..884cb2e 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/peterbourgon/ff/v4 v4.0.0-alpha.3 h1:fpyiFVEJvxIFljxM4l5ANSk/UGlM1gyU+hPAr9jhB7M= github.com/peterbourgon/ff/v4 v4.0.0-alpha.3/go.mod h1:H/13DK46DKXy7EaIxPhk2Y0EC8aubKm35nBjBe8AAGc= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..7a6755d --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,67 @@ +package config + +import ( + "errors" + "fmt" + "net" + "net/url" +) + +const ( + // DefaultHost is the default host to listen on. + DefaultHost = "localhost" + // DefaultPort is the default port to listen on. + DefaultPort = "8080" + // DefaultURNPath is the default file path to the URN alias file. + DefaultURNPath = "urns.yml" + // DefaultFingerPath is the default file path to the webfinger definition file. + DefaultFingerPath = "finger.yml" +) + +// ErrInvalidConfig is returned when the config is invalid. +var ErrInvalidConfig = errors.New("invalid config") + +type Config struct { + Debug bool + Host string + Port string + URNPath string + FingerPath string +} + +func NewConfig() *Config { + return &Config{ + Host: DefaultHost, + Port: DefaultPort, + URNPath: DefaultURNPath, + FingerPath: DefaultFingerPath, + } +} + +func (c *Config) GetAddr() string { + return net.JoinHostPort(c.Host, c.Port) +} + +func (c *Config) Validate() error { + if c.Host == "" { + return fmt.Errorf("%w: host is empty", ErrInvalidConfig) + } + + if c.Port == "" { + return fmt.Errorf("%w: port is empty", ErrInvalidConfig) + } + + if _, err := url.Parse(c.GetAddr()); err != nil { + return fmt.Errorf("%w: %w", ErrInvalidConfig, err) + } + + if c.URNPath == "" { + return fmt.Errorf("%w: urn path is empty", ErrInvalidConfig) + } + + if c.FingerPath == "" { + return fmt.Errorf("%w: finger path is empty", ErrInvalidConfig) + } + + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..85e35fe --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,124 @@ +package config_test + +import ( + "testing" + + "git.maronato.dev/maronato/finger/internal/config" +) + +func TestConfig_GetAddr(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + want string + }{ + { + name: "default", + cfg: config.NewConfig(), + want: "localhost:8080", + }, + { + name: "custom", + cfg: &config.Config{ + Host: "example.com", + Port: "1234", + }, + want: "example.com:1234", + }, + } + + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := tc.cfg.GetAddr() + if got != tc.want { + t.Errorf("Config.GetAddr() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestConfig_Validate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + wantErr bool + }{ + { + name: "default", + cfg: config.NewConfig(), + wantErr: false, + }, + { + name: "empty host", + cfg: &config.Config{ + Host: "", + Port: "1234", + }, + wantErr: true, + }, + { + name: "empty port", + cfg: &config.Config{ + Host: "example.com", + Port: "", + }, + wantErr: true, + }, + { + name: "invalid addr", + cfg: &config.Config{ + Host: "example.com", + Port: "invalid", + }, + wantErr: true, + }, + { + name: "empty urn path", + cfg: &config.Config{ + Host: "example.com", + Port: "1234", + URNPath: "", + }, + wantErr: true, + }, + { + name: "empty finger path", + cfg: &config.Config{ + Host: "example.com", + Port: "1234", + URNPath: "urns.yml", + FingerPath: "", + }, + wantErr: true, + }, + { + name: "valid", + cfg: &config.Config{ + Host: "example.com", + Port: "1234", + URNPath: "urns.yml", + FingerPath: "finger.yml", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + tc := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := tc.cfg.Validate() + if (err != nil) != tc.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tc.wantErr) + } + }) + } +} diff --git a/internal/log/log.go b/internal/log/log.go new file mode 100644 index 0000000..bd31ea1 --- /dev/null +++ b/internal/log/log.go @@ -0,0 +1,42 @@ +package log + +import ( + "context" + "io" + "log/slog" + + "git.maronato.dev/maronato/finger/internal/config" +) + +type loggerCtxKey struct{} + +// NewLogger creates a new logger with the given debug level. +func NewLogger(w io.Writer, cfg *config.Config) *slog.Logger { + level := slog.LevelInfo + addSource := false + + if cfg.Debug { + level = slog.LevelDebug + addSource = true + } + + return slog.New( + slog.NewJSONHandler(w, &slog.HandlerOptions{ + Level: level, + AddSource: addSource, + }), + ) +} + +func FromContext(ctx context.Context) *slog.Logger { + l, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger) + if !ok { + panic("logger not found in context") + } + + return l +} + +func WithLogger(ctx context.Context, l *slog.Logger) context.Context { + return context.WithValue(ctx, loggerCtxKey{}, l) +} diff --git a/internal/log/log_test.go b/internal/log/log_test.go new file mode 100644 index 0000000..ffee3db --- /dev/null +++ b/internal/log/log_test.go @@ -0,0 +1,95 @@ +package log_test + +import ( + "context" + "strings" + "testing" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" +) + +func assertPanic(t *testing.T, f func()) { + t.Helper() + + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + + // Call the function + f() +} + +func TestNewLogger(t *testing.T) { + t.Parallel() + + t.Run("defaults to info level", func(t *testing.T) { + t.Parallel() + + cfg := config.NewConfig() + + w := &strings.Builder{} + l := log.NewLogger(w, cfg) + + // It shouldn't log debug messages + l.Debug("test") + + if w.String() != "" { + t.Error("logger logged debug message") + } + + // It should log info messages + l.Info("test") + + if w.String() == "" { + t.Error("logger did not log info message") + } + }) + + t.Run("logs debug messages if debug is enabled", func(t *testing.T) { + t.Parallel() + + cfg := config.NewConfig() + cfg.Debug = true + + w := &strings.Builder{} + l := log.NewLogger(w, cfg) + + // It should log debug messages + l.Debug("test") + + if w.String() == "" { + t.Error("logger did not log debug message") + } + }) +} + +func TestFromContext(t *testing.T) { + t.Parallel() + + ctx := context.Background() + cfg := config.NewConfig() + l := log.NewLogger(nil, cfg) + + t.Run("panics if no logger in context", func(t *testing.T) { + t.Parallel() + + assertPanic(t, func() { + log.FromContext(ctx) + }) + }) + + t.Run("returns logger from context", func(t *testing.T) { + t.Parallel() + + ctx = log.WithLogger(ctx, l) + + l2 := log.FromContext(ctx) + + if l2 == nil { + t.Error("logger is nil") + } + }) +} diff --git a/internal/middleware/log.go b/internal/middleware/log.go new file mode 100644 index 0000000..6630a21 --- /dev/null +++ b/internal/middleware/log.go @@ -0,0 +1,44 @@ +package middleware + +import ( + "log/slog" + "net/http" + "time" + + "git.maronato.dev/maronato/finger/internal/log" +) + +func RequestLogger(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + l := log.FromContext(ctx) + + start := time.Now() + + // Wrap the response writer + wrapped := WrapResponseWriter(w) + + // Call the next handler + next.ServeHTTP(wrapped, r) + + status := wrapped.Status() + + // Log the request + lg := l.With( + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.Int("status", status), + slog.String("remote", r.RemoteAddr), + slog.Duration("duration", time.Since(start)), + ) + + switch { + case status >= http.StatusInternalServerError: + lg.Error("Server error") + case status >= http.StatusBadRequest: + lg.Info("Client error") + default: + lg.Info("Request completed") + } + }) +} diff --git a/internal/middleware/log_test.go b/internal/middleware/log_test.go new file mode 100644 index 0000000..bfd0c3c --- /dev/null +++ b/internal/middleware/log_test.go @@ -0,0 +1,44 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/middleware" +) + +func TestRequestLogger(t *testing.T) { + t.Parallel() + + ctx := context.Background() + cfg := config.NewConfig() + + stdout := &strings.Builder{} + + l := log.NewLogger(stdout, cfg) + ctx = log.WithLogger(ctx, l) + + w := httptest.NewRecorder() + r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody) + + if stdout.String() != "" { + t.Error("logger logged before request") + } + + middleware.RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })).ServeHTTP(w, r) + + if w.Code != http.StatusOK { + t.Error("status is not 200") + } + + if stdout.String() == "" { + t.Error("logger did not log request") + } +} diff --git a/internal/middleware/recover.go b/internal/middleware/recover.go new file mode 100644 index 0000000..282ae87 --- /dev/null +++ b/internal/middleware/recover.go @@ -0,0 +1,27 @@ +package middleware + +import ( + "log/slog" + "net/http" + + "git.maronato.dev/maronato/finger/internal/log" +) + +func Recoverer(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + l := log.FromContext(ctx) + + defer func() { + err := recover() + if err != nil { + l.Error("Panic", slog.Any("error", err)) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + + return + } + }() + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/middleware/recover_test.go b/internal/middleware/recover_test.go new file mode 100644 index 0000000..4586a07 --- /dev/null +++ b/internal/middleware/recover_test.go @@ -0,0 +1,76 @@ +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/middleware" +) + +func assertNoPanic(t *testing.T, f func()) { + t.Helper() + + defer func() { + if r := recover(); r != nil { + t.Error("function panicked") + } + }() + + f() +} + +func TestRecoverer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + ctx = log.WithLogger(ctx, l) + + t.Run("handles panics", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody) + + h := middleware.Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("test") + })) + + assertNoPanic(t, func() { + h.ServeHTTP(w, r) + }) + + if w.Code != http.StatusInternalServerError { + t.Error("status is not 500") + } + + if w.Body.String() != "Internal Server Error\n" { + t.Error("response body is not 'Internal Server Error'") + } + }) + + t.Run("handles successful requests", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody) + + h := middleware.Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + assertNoPanic(t, func() { + h.ServeHTTP(w, r) + }) + + if w.Code != http.StatusOK { + t.Error("status is not 200") + } + }) +} diff --git a/internal/middleware/wrapper.go b/internal/middleware/wrapper.go new file mode 100644 index 0000000..dbb1ace --- /dev/null +++ b/internal/middleware/wrapper.go @@ -0,0 +1,42 @@ +package middleware + +import ( + "fmt" + "net/http" +) + +type ResponseWrapper struct { + http.ResponseWriter + + status int +} + +func WrapResponseWriter(w http.ResponseWriter) *ResponseWrapper { + return &ResponseWrapper{w, 0} +} + +func (w *ResponseWrapper) WriteHeader(code int) { + w.status = code + w.ResponseWriter.WriteHeader(code) +} + +func (w *ResponseWrapper) Status() int { + return w.status +} + +func (w *ResponseWrapper) Write(b []byte) (int, error) { + if w.status == 0 { + w.status = http.StatusOK + } + + size, err := w.ResponseWriter.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing response: %w", err) + } + + return size, nil +} + +func (w *ResponseWrapper) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} diff --git a/internal/middleware/wrapper_test.go b/internal/middleware/wrapper_test.go new file mode 100644 index 0000000..81a5e32 --- /dev/null +++ b/internal/middleware/wrapper_test.go @@ -0,0 +1,97 @@ +package middleware_test + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "git.maronato.dev/maronato/finger/internal/middleware" +) + +func TestWrapResponseWriter(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + wrapped := middleware.WrapResponseWriter(w) + + if wrapped == nil { + t.Error("wrapper is nil") + } +} + +func TestResponseWrapper_Status(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + wrapped := middleware.WrapResponseWriter(w) + + if wrapped.Status() != 0 { + t.Error("status is not 0") + } + + wrapped.WriteHeader(http.StatusOK) + + if wrapped.Status() != http.StatusOK { + t.Error("status is not 200") + } +} + +type FailWriter struct{} + +func (w *FailWriter) Write(b []byte) (int, error) { + return 0, fmt.Errorf("error") +} + +func (w *FailWriter) Header() http.Header { + return http.Header{} +} + +func (w *FailWriter) WriteHeader(_ int) {} + +func TestResponseWrapper_Write(t *testing.T) { + t.Parallel() + + t.Run("writes success messages", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + wrapped := middleware.WrapResponseWriter(w) + + size, err := wrapped.Write([]byte("test")) + if err != nil { + t.Errorf("error writing response: %v", err) + } + + if size != 4 { + t.Error("size is not 4") + } + + if wrapped.Status() != http.StatusOK { + t.Error("status is not 200") + } + }) + + t.Run("returns error on fail write", func(t *testing.T) { + t.Parallel() + + w := &FailWriter{} + wrapped := middleware.WrapResponseWriter(w) + + _, err := wrapped.Write([]byte("test")) + if err == nil { + t.Error("error is nil") + } + }) +} + +func TestResponseWrapper_Unwrap(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + wrapped := middleware.WrapResponseWriter(w) + + if wrapped.Unwrap() != w { + t.Error("unwrapped response is not the same") + } +} diff --git a/internal/server/healthcheck.go b/internal/server/healthcheck.go new file mode 100644 index 0000000..b75a214 --- /dev/null +++ b/internal/server/healthcheck.go @@ -0,0 +1,13 @@ +package server + +import ( + "net/http" + + "git.maronato.dev/maronato/finger/internal/config" +) + +func HealthCheckHandler(_ *config.Config) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) +} diff --git a/internal/server/healthcheck_test.go b/internal/server/healthcheck_test.go new file mode 100644 index 0000000..e58497c --- /dev/null +++ b/internal/server/healthcheck_test.go @@ -0,0 +1,40 @@ +package server_test + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/server" +) + +func TestHealthcheckHandler(t *testing.T) { + t.Parallel() + + ctx := context.Background() + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + // Create a new request + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/healthz", http.NoBody) + + // Create a new recorder + rec := httptest.NewRecorder() + + // Create a new handler + h := server.HealthCheckHandler(cfg) + + // Serve the request + h.ServeHTTP(rec, req) + + // Check the status code + if rec.Code != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, rec.Code) + } +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..5503632 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,100 @@ +package server + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/http" + "time" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/middleware" + "git.maronato.dev/maronato/finger/internal/webfinger" + "golang.org/x/sync/errgroup" +) + +const ( + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + ReadTimeout = 5 * time.Second + // WriteTimeout is the maximum duration before timing out + // writes of the response. + WriteTimeout = 10 * time.Second + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. + IdleTimeout = 30 * time.Second + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. + ReadHeaderTimeout = 2 * time.Second + // RequestTimeout is the maximum duration for the entire + // request. + RequestTimeout = 7 * 24 * time.Hour +) + +func StartServer(ctx context.Context, cfg *config.Config, webfingers webfinger.WebFingers) error { + l := log.FromContext(ctx) + + // Create the server mux + mux := http.NewServeMux() + mux.Handle("/.well-known/webfinger", WebfingerHandler(cfg, webfingers)) + mux.Handle("/healthz", HealthCheckHandler(cfg)) + + // Create a new server + srv := &http.Server{ + Addr: cfg.GetAddr(), + BaseContext: func(_ net.Listener) context.Context { + return ctx + }, + Handler: middleware.RequestLogger( + middleware.Recoverer( + http.TimeoutHandler(mux, RequestTimeout, "request timed out"), + ), + ), + ReadHeaderTimeout: ReadHeaderTimeout, + ReadTimeout: ReadTimeout, + WriteTimeout: WriteTimeout, + IdleTimeout: IdleTimeout, + } + + // Create the errorgroup that will manage the server execution + eg, egCtx := errgroup.WithContext(ctx) + + // Start the server + eg.Go(func() error { + l.Info("Starting server", slog.String("addr", srv.Addr)) + + // Use the global context for the server + srv.BaseContext = func(_ net.Listener) context.Context { + return egCtx + } + + return srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup + }) + // Gracefully shutdown the server when the context is done + eg.Go(func() error { + // Wait for the context to be done + <-egCtx.Done() + + l.Info("Shutting down server") + // Disable the cancel since we don't wan't to force + // the server to shutdown if the context is canceled. + noCancelCtx := context.WithoutCancel(egCtx) + + return srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup + }) + + // Log when the server is fully shutdown + srv.RegisterOnShutdown(func() { + l.Info("Server shutdown complete") + }) + + // Wait for the server to exit and check for errors that + // are not caused by the context being canceled. + if err := eg.Wait(); err != nil && ctx.Err() == nil { + return fmt.Errorf("server exited with error: %w", err) + } + + return nil +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..a901137 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,206 @@ +package server_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "reflect" + "strings" + "sync" + "testing" + "time" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/server" + "git.maronato.dev/maronato/finger/internal/webfinger" +) + +func getPortGenerator() func() int { + lock := &sync.Mutex{} + port := 8080 + + return func() int { + lock.Lock() + defer lock.Unlock() + + port++ + + return port + } +} + +func TestStartServer(t *testing.T) { + t.Parallel() + + portGenerator := getPortGenerator() + + t.Run("starts and shuts down", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + // Use a new port + cfg.Port = fmt.Sprint(portGenerator()) + + // Start the server + err := server.StartServer(ctx, cfg, nil) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + }) + + t.Run("fails to start", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + // Use a new port + cfg.Port = fmt.Sprint(portGenerator()) + + // Use invalid host + cfg.Host = "google.com" + + // Start the server + err := server.StartServer(ctx, cfg, nil) + if err == nil { + t.Errorf("expected error, got nil") + } + }) + + t.Run("serves webfinger", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + // Use a new port + cfg.Port = fmt.Sprint(portGenerator()) + + resource := "acct:user@example.com" + webfingers := webfinger.WebFingers{ + resource: &webfinger.WebFinger{ + Subject: resource, + Properties: map[string]string{ + "http://webfinger.net/rel/name": "John Doe", + }, + }, + } + + go func() { + // Start the server + err := server.StartServer(ctx, cfg, webfingers) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + }() + + // Wait for the server to start + time.Sleep(time.Millisecond * 50) + + // Create a new client + c := http.Client{} + + // Create a new request + r, _ := http.NewRequestWithContext(ctx, + http.MethodGet, + "http://"+cfg.GetAddr()+"/.well-known/webfinger?resource=acct:user@example.com", + http.NoBody, + ) + + // Send the request + resp, err := c.Do(r) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + defer resp.Body.Close() + + // Check the status code + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + + // Check the response body + fingerGot := &webfinger.WebFinger{} + + // Decode the response body + if err := json.NewDecoder(resp.Body).Decode(fingerGot); err != nil { + t.Errorf("error decoding json: %v", err) + } + + // Check the response body + fingerWant := webfingers[resource] + + if !reflect.DeepEqual(fingerGot, fingerWant) { + t.Errorf("expected %v, got %v", fingerWant, fingerGot) + } + }) + + t.Run("serves healthcheck", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) + defer cancel() + + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + // Use a new port + cfg.Port = fmt.Sprint(portGenerator()) + + go func() { + // Start the server + err := server.StartServer(ctx, cfg, nil) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + }() + + // Wait for the server to start + time.Sleep(time.Millisecond * 50) + + // Create a new client + c := http.Client{} + + // Create a new request + r, _ := http.NewRequestWithContext(ctx, + http.MethodGet, + "http://"+cfg.GetAddr()+"/healthz", + http.NoBody, + ) + + // Send the request + resp, err := c.Do(r) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + + defer resp.Body.Close() + + // Check the status code + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode) + } + }) +} diff --git a/internal/server/webfinger.go b/internal/server/webfinger.go new file mode 100644 index 0000000..56e7cf8 --- /dev/null +++ b/internal/server/webfinger.go @@ -0,0 +1,59 @@ +package server + +import ( + "encoding/json" + "net/http" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/webfinger" +) + +func WebfingerHandler(_ *config.Config, webfingers webfinger.WebFingers) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + l := log.FromContext(ctx) + + // Only handle GET requests + if r.Method != http.MethodGet { + l.Debug("Method not allowed") + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + + return + } + + // Get the query params + q := r.URL.Query() + + // Get the resource + resource := q.Get("resource") + if resource == "" { + l.Debug("No resource provided") + http.Error(w, "No resource provided", http.StatusBadRequest) + + return + } + + // Get and validate resource + finger, ok := webfingers[resource] + if !ok { + l.Debug("Resource not found") + http.Error(w, "Resource not found", http.StatusNotFound) + + return + } + + // Set the content type + w.Header().Set("Content-Type", "application/jrd+json") + + // Write the response + if err := json.NewEncoder(w).Encode(finger); err != nil { + l.Debug("Error encoding json") + http.Error(w, "Error encoding json", http.StatusInternalServerError) + + return + } + + l.Debug("Webfinger request successful") + }) +} diff --git a/internal/server/webfinger_test.go b/internal/server/webfinger_test.go new file mode 100644 index 0000000..c887a18 --- /dev/null +++ b/internal/server/webfinger_test.go @@ -0,0 +1,149 @@ +package server_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "sort" + "strings" + "testing" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/server" + "git.maronato.dev/maronato/finger/internal/webfinger" +) + +func TestWebfingerHandler(t *testing.T) { + t.Parallel() + + webfingers := webfinger.WebFingers{ + "acct:user@example.com": { + Subject: "acct:user@example.com", + Links: []webfinger.Link{ + { + Rel: "http://webfinger.net/rel/profile-page", + Href: "https://example.com/user", + }, + }, + Properties: map[string]string{ + "http://webfinger.net/rel/name": "John Doe", + }, + }, + "acct:other@example.com": { + Subject: "acct:other@example.com", + Properties: map[string]string{ + "http://webfinger.net/rel/name": "Jane Doe", + }, + }, + "https://example.com/user": { + Subject: "https://example.com/user", + Properties: map[string]string{ + "http://webfinger.net/rel/name": "John Baz", + }, + }, + } + + tests := []struct { + name string + resource string + wantCode int + alternateMethod string + }{ + { + name: "valid resource", + resource: "acct:user@example.com", + wantCode: http.StatusOK, + }, + { + name: "other valid resource", + resource: "acct:other@example.com", + wantCode: http.StatusOK, + }, + { + name: "url resource", + resource: "https://example.com/user", + wantCode: http.StatusOK, + }, + { + name: "resource missing acct:", + resource: "user@example.com", + wantCode: http.StatusNotFound, + }, + { + name: "resource missing", + resource: "", + wantCode: http.StatusBadRequest, + }, + { + name: "invalid method", + resource: "acct:user@example.com", + wantCode: http.StatusMethodNotAllowed, + alternateMethod: http.MethodPost, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + // Create a new request + r, _ := http.NewRequestWithContext(ctx, tc.alternateMethod, "/.well-known/webfinger?resource="+tc.resource, http.NoBody) + + // Create a new response + w := httptest.NewRecorder() + + // Create a new handler + h := server.WebfingerHandler(cfg, webfingers) + + // Serve the request + h.ServeHTTP(w, r) + + // Check the status code + if w.Code != tc.wantCode { + t.Errorf("expected status code %d, got %d", tc.wantCode, w.Code) + } + + // If the status code is 200, check the response body + if tc.wantCode == http.StatusOK { + // Check the content type + if w.Header().Get("Content-Type") != "application/jrd+json" { + t.Errorf("expected content type %s, got %s", "application/jrd+json", w.Header().Get("Content-Type")) + } + + fingerWant := webfingers[tc.resource] + fingerGot := &webfinger.WebFinger{} + + // Decode the response body + if err := json.NewDecoder(w.Body).Decode(fingerGot); err != nil { + t.Errorf("error decoding json: %v", err) + } + + // Sort links + + sort.Slice(fingerGot.Links, func(i, j int) bool { + return fingerGot.Links[i].Rel < fingerGot.Links[j].Rel + }) + + sort.Slice(fingerWant.Links, func(i, j int) bool { + return fingerWant.Links[i].Rel < fingerWant.Links[j].Rel + }) + + // Check the response body + if !reflect.DeepEqual(fingerGot, fingerWant) { + t.Errorf("expected body %v, got %v", fingerWant, fingerGot) + } + } + }) + } +} diff --git a/internal/webfinger/webfinger.go b/internal/webfinger/webfinger.go new file mode 100644 index 0000000..784aa5b --- /dev/null +++ b/internal/webfinger/webfinger.go @@ -0,0 +1,160 @@ +package webfinger + +import ( + "context" + "fmt" + "log/slog" + "net/mail" + "net/url" + "os" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "gopkg.in/yaml.v3" +) + +type Link struct { + Rel string `json:"rel"` + Href string `json:"href,omitempty"` +} + +type WebFinger struct { + Subject string `json:"subject"` + Links []Link `json:"links,omitempty"` + Properties map[string]string `json:"properties,omitempty"` +} + +type WebFingers map[string]*WebFinger + +type ( + URNMap = map[string]string + RawFingersMap = map[string]map[string]string +) + +type FingerReader struct { + URNSFile []byte + FingersFile []byte +} + +func NewFingerReader() *FingerReader { + return &FingerReader{} +} + +func (f *FingerReader) ReadFiles(cfg *config.Config) error { + // Read URNs file + file, err := os.ReadFile(cfg.URNPath) + if err != nil { + return fmt.Errorf("error opening URNs file: %w", err) + } + + f.URNSFile = file + + // Read fingers file + file, err = os.ReadFile(cfg.FingerPath) + if err != nil { + return fmt.Errorf("error opening fingers file: %w", err) + } + + f.FingersFile = file + + return nil +} + +func (f *FingerReader) ParseFingers(ctx context.Context, urns URNMap, rawFingers RawFingersMap) (WebFingers, error) { + l := log.FromContext(ctx) + + webfingers := make(WebFingers) + + // Parse the webfinger file + for k, v := range rawFingers { + resource := k + + // Remove leading acct: if present + if len(k) > 5 && resource[:5] == "acct:" { + resource = resource[5:] + } + + // The key must be a URL or email address + if _, err := mail.ParseAddress(resource); err != nil { + if _, err := url.ParseRequestURI(resource); err != nil { + return nil, fmt.Errorf("error parsing webfinger key (%s): %w", k, err) + } + } else { + // Add acct: back to the key if it is an email address + resource = fmt.Sprintf("acct:%s", resource) + } + + // Create a new webfinger + webfinger := &WebFinger{ + Subject: resource, + } + + // Parse the fields + for field, value := range v { + fieldUrn := field + + // If the key is present in the URNs file, use the value + if _, ok := urns[field]; ok { + fieldUrn = urns[field] + } + + // If the value is a valid URI, add it to the links + if _, err := url.ParseRequestURI(value); err == nil { + webfinger.Links = append(webfinger.Links, Link{ + Rel: fieldUrn, + Href: value, + }) + } else { + // Otherwise add it to the properties + if webfinger.Properties == nil { + webfinger.Properties = make(map[string]string) + } + + webfinger.Properties[fieldUrn] = value + } + } + + // Add the webfinger to the map + webfingers[resource] = webfinger + } + + l.Debug("Webfinger map built successfully", slog.Int("number", len(webfingers)), slog.Any("data", webfingers)) + + return webfingers, nil +} + +func (f *FingerReader) ReadFingerFile(ctx context.Context) (WebFingers, error) { + l := log.FromContext(ctx) + + urnMap := make(URNMap) + fingerData := make(RawFingersMap) + + // Parse the URNs file + if err := yaml.Unmarshal(f.URNSFile, &urnMap); err != nil { + return nil, fmt.Errorf("error unmarshalling URNs file: %w", err) + } + + // The URNs file must be a map of strings to valid URLs + for _, v := range urnMap { + if _, err := url.ParseRequestURI(v); err != nil { + return nil, fmt.Errorf("error parsing URN URIs: %w", err) + } + } + + l.Debug("URNs file parsed successfully", slog.Int("number", len(urnMap)), slog.Any("data", urnMap)) + + // Parse the fingers file + if err := yaml.Unmarshal(f.FingersFile, &fingerData); err != nil { + return nil, fmt.Errorf("error unmarshalling fingers file: %w", err) + } + + l.Debug("Fingers file parsed successfully", slog.Int("number", len(fingerData)), slog.Any("data", fingerData)) + + // Parse raw data + webfingers, err := f.ParseFingers(ctx, urnMap, fingerData) + if err != nil { + return nil, fmt.Errorf("error parsing raw fingers: %w", err) + } + + return webfingers, nil +} diff --git a/internal/webfinger/webfinger_test.go b/internal/webfinger/webfinger_test.go new file mode 100644 index 0000000..e61b76b --- /dev/null +++ b/internal/webfinger/webfinger_test.go @@ -0,0 +1,444 @@ +package webfinger_test + +import ( + "context" + "encoding/json" + "os" + "reflect" + "sort" + "strings" + "testing" + + "git.maronato.dev/maronato/finger/internal/config" + "git.maronato.dev/maronato/finger/internal/log" + "git.maronato.dev/maronato/finger/internal/webfinger" +) + +func newTempFile(t *testing.T, content string) (name string, remove func()) { + t.Helper() + + f, err := os.CreateTemp("", "finger-test") + if err != nil { + t.Fatalf("error creating temp file: %v", err) + } + + _, err = f.WriteString(content) + if err != nil { + t.Fatalf("error writing to temp file: %v", err) + } + + return f.Name(), func() { + err = os.Remove(f.Name()) + if err != nil { + t.Fatalf("error removing temp file: %v", err) + } + } +} + +func TestNewFingerReader(t *testing.T) { + t.Parallel() + + f := webfinger.NewFingerReader() + + if f == nil { + t.Errorf("NewFingerReader() = %v, want: %v", f, nil) + } +} + +func TestFingerReader_ReadFiles(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + urnsContent string + fingersContent string + useURNFile bool + useFingerFile bool + wantErr bool + }{ + { + name: "reads files", + urnsContent: "name: https://schema/name\nprofile: https://schema/profile", + fingersContent: "user@example.com:\n name: John Doe", + useURNFile: true, + useFingerFile: true, + wantErr: false, + }, + { + name: "errors on missing URNs file", + urnsContent: "invalid", + fingersContent: "user@example.com:\n name: John Doe", + useURNFile: false, + useFingerFile: true, + wantErr: true, + }, + { + name: "errors on missing fingers file", + urnsContent: "name: https://schema/name\nprofile: https://schema/profile", + fingersContent: "invalid", + useFingerFile: false, + useURNFile: true, + wantErr: true, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + cfg := config.NewConfig() + + urnsFileName, urnsCleanup := newTempFile(t, tc.urnsContent) + defer urnsCleanup() + + fingersFileName, fingersCleanup := newTempFile(t, tc.fingersContent) + defer fingersCleanup() + + if !tc.useURNFile { + cfg.URNPath = "invalid" + } else { + cfg.URNPath = urnsFileName + } + + if !tc.useFingerFile { + cfg.FingerPath = "invalid" + } else { + cfg.FingerPath = fingersFileName + } + + f := webfinger.NewFingerReader() + + err := f.ReadFiles(cfg) + if err != nil { + if !tc.wantErr { + t.Errorf("ReadFiles() error = %v", err) + } + + return + } else if tc.wantErr { + t.Errorf("ReadFiles() error = %v, wantErr %v", err, tc.wantErr) + } + + if !reflect.DeepEqual(f.URNSFile, []byte(tc.urnsContent)) { + t.Errorf("ReadFiles() URNsFile = %v, want: %v", f.URNSFile, tc.urnsContent) + } + + if !reflect.DeepEqual(f.FingersFile, []byte(tc.fingersContent)) { + t.Errorf("ReadFiles() FingersFile = %v, want: %v", f.FingersFile, tc.fingersContent) + } + }) + } +} + +func TestParseFingers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + rawFingers webfinger.RawFingersMap + want webfinger.WebFingers + wantErr bool + }{ + { + name: "parses links", + rawFingers: webfinger.RawFingersMap{ + "user@example.com": { + "profile": "https://example.com/profile", + "invalidalias": "https://example.com/invalidalias", + "https://something": "https://somethingelse", + }, + }, + want: webfinger.WebFingers{ + "acct:user@example.com": { + Subject: "acct:user@example.com", + Links: []webfinger.Link{ + { + Rel: "https://schema/profile", + Href: "https://example.com/profile", + }, + { + Rel: "invalidalias", + Href: "https://example.com/invalidalias", + }, + { + Rel: "https://something", + Href: "https://somethingelse", + }, + }, + }, + }, + wantErr: false, + }, + { + name: "parses properties", + rawFingers: webfinger.RawFingersMap{ + "user@example.com": { + "name": "John Doe", + "invalidalias": "value1", + "https://mylink": "value2", + }, + }, + want: webfinger.WebFingers{ + "acct:user@example.com": { + Subject: "acct:user@example.com", + Properties: map[string]string{ + "https://schema/name": "John Doe", + "invalidalias": "value1", + "https://mylink": "value2", + }, + }, + }, + wantErr: false, + }, + { + name: "accepts acct: prefix", + rawFingers: webfinger.RawFingersMap{ + "acct:user@example.com": { + "name": "John Doe", + }, + }, + want: webfinger.WebFingers{ + "acct:user@example.com": { + Subject: "acct:user@example.com", + Properties: map[string]string{ + "https://schema/name": "John Doe", + }, + }, + }, + wantErr: false, + }, + { + name: "accepts urls as resource", + rawFingers: webfinger.RawFingersMap{ + "https://example.com": { + "name": "John Doe", + }, + }, + want: webfinger.WebFingers{ + "https://example.com": { + Subject: "https://example.com", + Properties: map[string]string{ + "https://schema/name": "John Doe", + }, + }, + }, + wantErr: false, + }, + { + name: "accepts multiple resources", + rawFingers: webfinger.RawFingersMap{ + "user@example.com": { + "name": "John Doe", + }, + "other@example.com": { + "name": "Jane Doe", + }, + }, + want: webfinger.WebFingers{ + "acct:user@example.com": { + Subject: "acct:user@example.com", + Properties: map[string]string{ + "https://schema/name": "John Doe", + }, + }, + "acct:other@example.com": { + Subject: "acct:other@example.com", + Properties: map[string]string{ + "https://schema/name": "Jane Doe", + }, + }, + }, + wantErr: false, + }, + { + name: "errors on invalid resource", + rawFingers: webfinger.RawFingersMap{ + "invalid": { + "name": "John Doe", + }, + }, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create a urn map + urns := webfinger.URNMap{ + "name": "https://schema/name", + "profile": "https://schema/profile", + } + + ctx := context.Background() + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + f := webfinger.NewFingerReader() + + got, err := f.ParseFingers(ctx, urns, tc.rawFingers) + if (err != nil) != tc.wantErr { + t.Errorf("ParseFingers() error = %v, wantErr %v", err, tc.wantErr) + + return + } + + // Sort links to make it easier to compare + for _, v := range got { + for range v.Links { + sort.Slice(v.Links, func(i, j int) bool { + return v.Links[i].Rel < v.Links[j].Rel + }) + } + } + + for _, v := range tc.want { + for range v.Links { + sort.Slice(v.Links, func(i, j int) bool { + return v.Links[i].Rel < v.Links[j].Rel + }) + } + } + + if !reflect.DeepEqual(got, tc.want) { + // Unmarshal the structs to JSON to make it easier to print + gotstr := &strings.Builder{} + gotenc := json.NewEncoder(gotstr) + + wantstr := &strings.Builder{} + wantenc := json.NewEncoder(wantstr) + + _ = gotenc.Encode(got) + _ = wantenc.Encode(tc.want) + + t.Errorf("ParseFingers() got = \n%s want: \n%s", gotstr.String(), wantstr.String()) + } + }) + } +} + +func TestReadFingerFile(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + urnsContent string + fingersContent string + wantURN webfinger.URNMap + wantFinger webfinger.RawFingersMap + returns *webfinger.WebFingers + wantErr bool + }{ + { + name: "reads files", + urnsContent: "name: https://schema/name\nprofile: https://schema/profile", + fingersContent: "user@example.com:\n name: John Doe", + wantURN: webfinger.URNMap{ + "name": "https://schema/name", + "profile": "https://schema/profile", + }, + wantFinger: webfinger.RawFingersMap{ + "user@example.com": { + "name": "John Doe", + }, + }, + returns: &webfinger.WebFingers{ + "acct:user@example.com": { + Subject: "acct:user@example.com", + Properties: map[string]string{ + "https://schema/name": "John Doe", + }, + }, + }, + wantErr: false, + }, + { + name: "uses custom URNs", + urnsContent: "favorite_food: https://schema/favorite_food", + fingersContent: "user@example.com:\n favorite_food: Apple", + wantURN: webfinger.URNMap{ + "favorite_food": "https://schema/favorite_food", + }, + wantFinger: webfinger.RawFingersMap{ + "user@example.com": { + "https://schema/favorite_food": "Apple", + }, + }, + wantErr: false, + }, + { + name: "errors on invalid URNs file", + urnsContent: "invalid", + fingersContent: "user@example.com:\n name: John Doe", + wantURN: webfinger.URNMap{}, + wantFinger: webfinger.RawFingersMap{}, + wantErr: true, + }, + { + name: "errors on invalid fingers file", + urnsContent: "name: https://schema/name\nprofile: https://schema/profile", + fingersContent: "invalid", + wantURN: webfinger.URNMap{}, + wantFinger: webfinger.RawFingersMap{}, + wantErr: true, + }, + { + name: "errors on invalid URNs values", + urnsContent: "name: invalid", + fingersContent: "user@example.com:\n name: John Doe", + wantURN: webfinger.URNMap{}, + wantFinger: webfinger.RawFingersMap{}, + wantErr: true, + }, + { + name: "errors on invalid fingers values", + urnsContent: "name: https://schema/name\nprofile: https://schema/profile", + fingersContent: "invalid:\n name: John Doe", + wantURN: webfinger.URNMap{}, + wantFinger: webfinger.RawFingersMap{}, + wantErr: true, + }, + } + + for _, tt := range tests { + tc := tt + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx := context.Background() + cfg := config.NewConfig() + l := log.NewLogger(&strings.Builder{}, cfg) + + ctx = log.WithLogger(ctx, l) + + f := webfinger.NewFingerReader() + + f.FingersFile = []byte(tc.fingersContent) + f.URNSFile = []byte(tc.urnsContent) + + got, err := f.ReadFingerFile(ctx) + if err != nil { + if !tc.wantErr { + t.Errorf("ReadFingerFile() error = %v", err) + } + + return + } else if tc.wantErr { + t.Errorf("ReadFingerFile() error = %v, wantErr %v", err, tc.wantErr) + } + + if tc.returns != nil && !reflect.DeepEqual(got, *tc.returns) { + t.Errorf("ReadFingerFile() got = %v, want: %v", got, *tc.returns) + } + }) + } +} diff --git a/main.go b/main.go index 860b03b..724839e 100644 --- a/main.go +++ b/main.go @@ -1,566 +1,19 @@ package main import ( - "context" - "encoding/json" - "errors" "fmt" - "net" - "net/http" - "net/mail" - "net/url" "os" - "os/signal" - "syscall" - "time" - "github.com/peterbourgon/ff/v4" - "github.com/peterbourgon/ff/v4/ffhelp" - "golang.org/x/exp/slog" - "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v3" + "git.maronato.dev/maronato/finger/cmd" ) -const appName = "finger" - -// Version of the application. +// Version of the app. var version = "dev" func main() { // Run the server - if err := Run(); err != nil { + if err := cmd.Run(version); err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) os.Exit(1) } } - -func Run() error { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Allow graceful shutdown - trapSignalsCrossPlatform(cancel) - - cfg := &Config{} - - // Create a new root command - subcommands := []*ff.Command{ - NewServerCmd(cfg), - NewHealthcheckCmd(cfg), - } - cmd := NewRootCmd(cfg, subcommands) - - // Parse and run - if err := cmd.ParseAndRun(ctx, os.Args[1:], ff.WithEnvVarPrefix("WF")); err != nil { - if errors.Is(err, ff.ErrHelp) || errors.Is(err, ff.ErrNoExec) { - fmt.Fprintf(os.Stderr, "\n%s\n", ffhelp.Command(cmd)) - - return nil - } - - return fmt.Errorf("error running command: %w", err) - } - - return nil -} - -func NewServerCmd(cfg *Config) *ff.Command { - return &ff.Command{ - Name: "serve", - Usage: "serve [flags]", - ShortHelp: "Start the webfinger server", - Exec: func(ctx context.Context, args []string) error { - // Create a logger and add it to the context - l := NewLogger(cfg) - ctx = WithLogger(ctx, l) - - // Parse the webfinger files - fingermap, err := ParseFingerFile(ctx, cfg) - if err != nil { - return fmt.Errorf("error parsing finger files: %w", err) - } - - l.Info(fmt.Sprintf("Loaded %d webfingers", len(fingermap))) - - // Start the server - if err := StartServer(ctx, cfg, fingermap); err != nil { - return fmt.Errorf("error running server: %w", err) - } - - return nil - }, - } -} - -func NewHealthcheckCmd(cfg *Config) *ff.Command { - return &ff.Command{ - Name: "healthcheck", - Usage: "healthcheck [flags]", - ShortHelp: "Check if the server is running", - Exec: func(ctx context.Context, args []string) error { - // Create a new client - client := &http.Client{ - Timeout: 5 * time.Second, //nolint:gomnd // We want to use a constant - } - - // Create a new request - reqURL := url.URL{ - Scheme: "http", - Host: net.JoinHostPort(cfg.Host, cfg.Port), - Path: "/healthz", - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody) - if err != nil { - return fmt.Errorf("error creating request: %w", err) - } - - // Send the request - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("error sending request: %w", err) - } - - defer resp.Body.Close() - - // Check the response - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("server returned status %d", resp.StatusCode) //nolint:goerr113 // We want to return an error - } - - return nil - }, - } -} - -type loggerCtxKey struct{} - -// NewLogger creates a new logger with the given debug level. -func NewLogger(cfg *Config) *slog.Logger { - level := slog.LevelInfo - addSource := false - - if cfg.Debug { - level = slog.LevelDebug - addSource = true - } - - return slog.New( - slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - AddSource: addSource, - }), - ) -} - -func LoggerFromContext(ctx context.Context) *slog.Logger { - l, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger) - if !ok { - panic("logger not found in context") - } - - return l -} - -func WithLogger(ctx context.Context, l *slog.Logger) context.Context { - return context.WithValue(ctx, loggerCtxKey{}, l) -} - -// https://github.com/caddyserver/caddy/blob/fbb0ecfa322aa7710a3448453fd3ae40f037b8d1/sigtrap.go#L37 -// trapSignalsCrossPlatform captures SIGINT or interrupt (depending -// on the OS), which initiates a graceful shutdown. A second SIGINT -// or interrupt will forcefully exit the process immediately. -func trapSignalsCrossPlatform(cancel context.CancelFunc) { - go func() { - shutdown := make(chan os.Signal, 1) - signal.Notify(shutdown, os.Interrupt, syscall.SIGINT) - - for i := 0; true; i++ { - <-shutdown - - if i > 0 { - fmt.Printf("\nForce quit\n") //nolint:forbidigo // We want to print to stdout - os.Exit(1) - } - - fmt.Printf("\nGracefully shutting down. Press Ctrl+C again to force quit\n") //nolint:forbidigo // We want to print to stdout - cancel() - } - }() -} - -type Config struct { - Debug bool - Host string - Port string - urnPath string - fingerPath string -} - -// NewRootCmd parses the command line flags and returns a Config struct. -func NewRootCmd(cfg *Config, subcommands []*ff.Command) *ff.Command { - fs := ff.NewFlagSet(appName) - - for _, cmd := range subcommands { - cmd.Flags = ff.NewFlagSet(cmd.Name).SetParent(fs) - } - - cmd := &ff.Command{ - Name: appName, - Usage: fmt.Sprintf("%s [flags]", appName), - ShortHelp: fmt.Sprintf("(%s) A webfinger server", version), - Flags: fs, - Subcommands: subcommands, - } - - // Use 0.0.0.0 as the default host if on docker - defaultHost := "localhost" - if os.Getenv("ENV_DOCKER") == "true" { - defaultHost = "0.0.0.0" - } - - fs.BoolVar(&cfg.Debug, 'd', "debug", "Enable debug logging") - fs.StringVar(&cfg.Host, 'h', "host", defaultHost, "Host to listen on") - fs.StringVar(&cfg.Port, 'p', "port", "8080", "Port to listen on") - fs.StringVar(&cfg.urnPath, 'u', "urn-file", "urns.yml", "Path to the URNs file") - fs.StringVar(&cfg.fingerPath, 'f', "finger-file", "fingers.yml", "Path to the fingers file") - - return cmd -} - -type Link struct { - Rel string `json:"rel"` - Href string `json:"href,omitempty"` -} - -type WebFinger struct { - Subject string `json:"subject"` - Links []Link `json:"links,omitempty"` - Properties map[string]string `json:"properties,omitempty"` -} - -type WebFingerMap map[string]*WebFinger - -func ParseFingerFile(ctx context.Context, cfg *Config) (WebFingerMap, error) { - l := LoggerFromContext(ctx) - - urnMap := make(map[string]string) - fingerData := make(map[string]map[string]string) - - fingermap := make(WebFingerMap) - - // Read URNs file - file, err := os.ReadFile(cfg.urnPath) - if err != nil { - return nil, fmt.Errorf("error opening URNs file: %w", err) - } - - if err := yaml.Unmarshal(file, &urnMap); err != nil { - return nil, fmt.Errorf("error unmarshalling URNs file: %w", err) - } - - // The URNs file must be a map of strings to valid URLs - for _, v := range urnMap { - if _, err := url.Parse(v); err != nil { - return nil, fmt.Errorf("error parsing URN URIs: %w", err) - } - } - - l.Debug("URNs file parsed successfully", slog.Int("number", len(urnMap)), slog.Any("data", urnMap)) - - // Read webfingers file - file, err = os.ReadFile(cfg.fingerPath) - if err != nil { - return nil, fmt.Errorf("error opening fingers file: %w", err) - } - - if err := yaml.Unmarshal(file, &fingerData); err != nil { - return nil, fmt.Errorf("error unmarshalling fingers file: %w", err) - } - - l.Debug("Fingers file parsed successfully", slog.Int("number", len(fingerData)), slog.Any("data", fingerData)) - - // Parse the webfinger file - for k, v := range fingerData { - resource := k - - // Remove leading acct: if present - if len(k) > 5 && resource[:5] == "acct:" { - resource = resource[5:] - } - - // The key must be a URL or email address - if _, err := mail.ParseAddress(resource); err != nil { - if _, err := url.Parse(resource); err != nil { - return nil, fmt.Errorf("error parsing webfinger key (%s): %w", k, err) - } - } else { - // Add acct: back to the key if it is an email address - resource = fmt.Sprintf("acct:%s", resource) - } - - // Create a new webfinger - webfinger := &WebFinger{ - Subject: resource, - } - - // Parse the fields - for field, value := range v { - fieldUrn := field - - // If the key is present in the URNs file, use the value - if _, ok := urnMap[field]; ok { - fieldUrn = urnMap[field] - } - - // If the value is a valid URI, add it to the links - if _, err := url.Parse(value); err == nil { - webfinger.Links = append(webfinger.Links, Link{ - Rel: fieldUrn, - Href: value, - }) - } else { - // Otherwise add it to the properties - if webfinger.Properties == nil { - webfinger.Properties = make(map[string]string) - } - - webfinger.Properties[fieldUrn] = value - } - } - - // Add the webfinger to the map - fingermap[resource] = webfinger - } - - l.Debug("Webfinger map built successfully", slog.Int("number", len(fingermap)), slog.Any("data", fingermap)) - - return fingermap, nil -} - -func WebfingerHandler(_ *Config, webmap WebFingerMap) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - l := LoggerFromContext(ctx) - - // Only handle GET requests - if r.Method != http.MethodGet { - l.Debug("Method not allowed") - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - - return - } - - // Get the query params - q := r.URL.Query() - - // Get the resource - resource := q.Get("resource") - if resource == "" { - l.Debug("No resource provided") - http.Error(w, "No resource provided", http.StatusBadRequest) - - return - } - - // Get and validate resource - webfinger, ok := webmap[resource] - if !ok { - l.Debug("Resource not found") - http.Error(w, "Resource not found", http.StatusNotFound) - - return - } - - // Set the content type - w.Header().Set("Content-Type", "application/jrd+json") - - // Write the response - if err := json.NewEncoder(w).Encode(webfinger); err != nil { - l.Debug("Error encoding json") - http.Error(w, "Error encoding json", http.StatusInternalServerError) - - return - } - - l.Debug("Webfinger request successful") - }) -} - -func HealthCheckHandler(_ *Config) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - }) -} - -type ResponseWrapper struct { - http.ResponseWriter - - status int -} - -func WrapResponseWriter(w http.ResponseWriter) *ResponseWrapper { - return &ResponseWrapper{w, 0} -} - -func (w *ResponseWrapper) WriteHeader(code int) { - w.status = code - w.ResponseWriter.WriteHeader(code) -} - -func (w *ResponseWrapper) Status() int { - return w.status -} - -func (w *ResponseWrapper) Write(b []byte) (int, error) { - if w.status == 0 { - w.status = http.StatusOK - } - - size, err := w.ResponseWriter.Write(b) - if err != nil { - return 0, fmt.Errorf("error writing response: %w", err) - } - - return size, nil -} - -func (w *ResponseWrapper) Unwrap() http.ResponseWriter { - return w.ResponseWriter -} - -func LoggingMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - l := LoggerFromContext(ctx) - - start := time.Now() - - // Wrap the response writer - wrapped := WrapResponseWriter(w) - - // Call the next handler - next.ServeHTTP(wrapped, r) - - status := wrapped.Status() - - // Log the request - lg := l.With( - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - slog.Int("status", status), - slog.String("remote", r.RemoteAddr), - slog.Duration("duration", time.Since(start)), - ) - - switch { - case status >= http.StatusInternalServerError: - lg.Error("Server error") - case status >= http.StatusBadRequest: - lg.Info("Client error") - default: - lg.Info("Request completed") - } - }) -} - -const ( - // ReadTimeout is the maximum duration for reading the entire - // request, including the body. - ReadTimeout = 5 * time.Second - // WriteTimeout is the maximum duration before timing out - // writes of the response. - WriteTimeout = 10 * time.Second - // IdleTimeout is the maximum amount of time to wait for the - // next request when keep-alives are enabled. - IdleTimeout = 30 * time.Second - // ReadHeaderTimeout is the amount of time allowed to read - // request headers. - ReadHeaderTimeout = 2 * time.Second - // RequestTimeout is the maximum duration for the entire - // request. - RequestTimeout = 7 * 24 * time.Hour -) - -func StartServer(ctx context.Context, cfg *Config, webmap WebFingerMap) error { - l := LoggerFromContext(ctx) - - // Create the server mux - mux := http.NewServeMux() - mux.Handle("/.well-known/webfinger", WebfingerHandler(cfg, webmap)) - mux.Handle("/healthz", HealthCheckHandler(cfg)) - - // Create a new server - srv := &http.Server{ - Addr: net.JoinHostPort(cfg.Host, cfg.Port), - BaseContext: func(_ net.Listener) context.Context { - return ctx - }, - Handler: LoggingMiddleware( - RecoveryHandler( - http.TimeoutHandler(mux, RequestTimeout, "request timed out"), - ), - ), - ReadHeaderTimeout: ReadHeaderTimeout, - ReadTimeout: ReadTimeout, - WriteTimeout: WriteTimeout, - IdleTimeout: IdleTimeout, - } - - // Create the errorgroup that will manage the server execution - eg, egCtx := errgroup.WithContext(ctx) - - // Start the server - eg.Go(func() error { - l.Info("Starting server", slog.String("addr", srv.Addr)) - - // Use the global context for the server - srv.BaseContext = func(_ net.Listener) context.Context { - return egCtx - } - - return srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup - }) - // Gracefully shutdown the server when the context is done - eg.Go(func() error { - // Wait for the context to be done - <-egCtx.Done() - - l.Info("Shutting down server") - // Disable the cancel since we don't wan't to force - // the server to shutdown if the context is canceled. - noCancelCtx := context.WithoutCancel(egCtx) - - return srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup - }) - - srv.RegisterOnShutdown(func() { - l.Info("Server shutdown complete") - }) - - // Ignore the error if the context was canceled - if err := eg.Wait(); err != nil && ctx.Err() == nil { - return fmt.Errorf("server exited with error: %w", err) - } - - return nil -} - -func RecoveryHandler(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - l := LoggerFromContext(ctx) - - defer func() { - err := recover() - if err != nil { - l.Error("Panic", slog.Any("error", err)) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - - return - } - }() - - next.ServeHTTP(w, r) - }) -} diff --git a/main_test.go b/main_test.go deleted file mode 100644 index 12c02e6..0000000 --- a/main_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package main_test - -import ( - "context" - "fmt" - "net/http" - "net/http/httptest" - "testing" - - finger "git.maronato.dev/maronato/finger" -) - -func BenchmarkGetWebfinger(b *testing.B) { - ctx := context.Background() - cfg := &finger.Config{} - l := finger.NewLogger(cfg) - - ctx = finger.WithLogger(ctx, l) - resource := "acct:user@example.com" - webmap := finger.WebFingerMap{ - resource: { - Subject: resource, - Links: []finger.Link{ - { - Rel: "http://webfinger.net/rel/avatar", - Href: "https://example.com/avatar.png", - }, - }, - Properties: map[string]string{ - "example": "value", - }, - }, - "acct:other": { - Subject: "acct:other", - Links: []finger.Link{ - { - Rel: "http://webfinger.net/rel/avatar", - Href: "https://example.com/avatar.png", - }, - }, - Properties: map[string]string{ - "example": "value", - }, - }, - } - - handler := finger.WebfingerHandler(&finger.Config{}, webmap) - - r, _ := http.NewRequestWithContext( - ctx, - http.MethodGet, - fmt.Sprintf("/.well-known/webfinger?resource=%s", resource), - http.NoBody, - ) - - for i := 0; i < b.N; i++ { - w := httptest.NewRecorder() - handler.ServeHTTP(w, r) - } -}