mirror of
https://github.com/Maronato/go-finger.git
synced 2025-03-15 00:34:47 +00:00
refactor and add tests
This commit is contained in:
parent
6bbfbad1d0
commit
f96dda4af2
26 changed files with 2180 additions and 613 deletions
148
README.md
148
README.md
|
@ -0,0 +1,148 @@
|
||||||
|
# Finger
|
||||||
|
|
||||||
|
Webfinger server written in Go.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
- 🍰 Easy YAML configuration
|
||||||
|
- 🪶 Single 8MB binary / 0% idle CPU / 4MB idle RAM
|
||||||
|
- ⚡️ Sub millisecond responses at 10,000 request per second
|
||||||
|
- 🐳 10MB Docker image
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
Via `go install`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go install git.maronato.dev/maronato/finger@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
Via Docker:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --name finger /
|
||||||
|
-p 8080:8080 /
|
||||||
|
git.maronato.dev/maronato/finger
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
If you installed it using `go install`, run
|
||||||
|
```bash
|
||||||
|
finger serve
|
||||||
|
```
|
||||||
|
To start the server on port `8080`. Your resources will be queryable via `locahost:8080/.well-known/webfinger?resource=<your-resource>`
|
||||||
|
|
||||||
|
If you're using Docker, the use the same command in the install section.
|
||||||
|
|
||||||
|
By default, no resources will be exposed. You can create resources via a `fingers.yml` file. It should contain a collection of resources as keys and their attributes as their objects.
|
||||||
|
|
||||||
|
Some default URN aliases are provided via the built-in mapping ([`urns.yml`](./urns.yml)). You can replace that with your own or use URNs directly in the `fingers.yml` file.
|
||||||
|
|
||||||
|
Here's an example:
|
||||||
|
```yaml
|
||||||
|
# fingers.yml
|
||||||
|
|
||||||
|
# Resources go in the root of the file. Email address will have the acct:
|
||||||
|
# prefix added automatically.
|
||||||
|
alice@example.com:
|
||||||
|
# "avatar" is an alias of "http://webfinger.net/rel/avatar"
|
||||||
|
# (see urns.yml for more)
|
||||||
|
avatar: "https://example.com/alice-pic"
|
||||||
|
|
||||||
|
# If the value is a URI, it'll be exposed as a webfinger link
|
||||||
|
openid: "https://sso.example.com/"
|
||||||
|
|
||||||
|
# If the value of the attribute is not a URI, it will be exposed as a
|
||||||
|
# webfinger property
|
||||||
|
name: "Alice Doe"
|
||||||
|
|
||||||
|
# You can also specify URN's directly instead of the aliases
|
||||||
|
http://webfinger.net/rel/profile-page: "https://example.com/user/alice"
|
||||||
|
|
||||||
|
bob@example.com:
|
||||||
|
name: Bob Foo
|
||||||
|
openid: "https://sso.example.com/"
|
||||||
|
|
||||||
|
# Resources can also be URIs
|
||||||
|
https://example.com/user/charlie:
|
||||||
|
name: Charlie Baz
|
||||||
|
profile: https://example.com/user/charlie
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example queries
|
||||||
|
<details>
|
||||||
|
<summary><b>Query Alice</b><pre>GET http://localhost:8080/.well-known/webfinger?resource=acct:alice@example.com</pre></summary>
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"subject": "acct:alice@example.com",
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"rel": "avatar",
|
||||||
|
"href": "https://example.com/alice-pic"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"rel": "openid",
|
||||||
|
"href": "https://sso.example.com/"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"rel": "name",
|
||||||
|
"href": "Alice Doe"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"rel": "http://webfinger.net/rel/profile-page",
|
||||||
|
"href": "https://example.com/user/alice"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Query Bob</b><pre>GET http://localhost:8080/.well-known/webfinger?resource=acct:bob@example.com</pre></summary>
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"subject": "acct:bob@example.com",
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"rel": "name",
|
||||||
|
"href": "Bob Foo"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"rel": "openid",
|
||||||
|
"href": "https://sso.example.com/"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Query Charlie</b><pre>GET http://localhost:8080/.well-known/webfinger?resource=https://example.com/user/charlie</pre></summary>
|
||||||
|
|
||||||
|
```JSON
|
||||||
|
{
|
||||||
|
"subject": "https://example.com/user/charlie",
|
||||||
|
"links": [
|
||||||
|
{
|
||||||
|
"rel": "name",
|
||||||
|
"href": "Charlie Baz"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"rel": "profile",
|
||||||
|
"href": "https://example.com/user/charlie"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## Configs
|
||||||
|
Here are the config options available. You can change them via command line flags or environment variables:
|
||||||
|
|
||||||
|
| CLI flag | Env variable | Default | Description |
|
||||||
|
| -------- | ------------ | ------- | ----------- |
|
||||||
|
| fdsfds | gsfgfs | fgfsdgf | gdfsgdf |
|
98
cmd/cmd.go
Normal file
98
cmd/cmd.go
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"github.com/peterbourgon/ff/v4"
|
||||||
|
"github.com/peterbourgon/ff/v4/ffhelp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Run(version string) error {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Allow graceful shutdown
|
||||||
|
trapSignalsCrossPlatform(cancel)
|
||||||
|
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
// Create a new root command
|
||||||
|
subcommands := []*ff.Command{
|
||||||
|
newServerCmd(cfg),
|
||||||
|
newHealthcheckCmd(cfg),
|
||||||
|
}
|
||||||
|
cmd := newRootCmd(version, cfg, subcommands)
|
||||||
|
|
||||||
|
// Parse and run
|
||||||
|
if err := cmd.ParseAndRun(ctx, os.Args[1:], ff.WithEnvVarPrefix("WF")); err != nil {
|
||||||
|
if errors.Is(err, ff.ErrHelp) || errors.Is(err, ff.ErrNoExec) {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n%s\n", ffhelp.Command(cmd))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("error running command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/caddyserver/caddy/blob/fbb0ecfa322aa7710a3448453fd3ae40f037b8d1/sigtrap.go#L37
|
||||||
|
// trapSignalsCrossPlatform captures SIGINT or interrupt (depending
|
||||||
|
// on the OS), which initiates a graceful shutdown. A second SIGINT
|
||||||
|
// or interrupt will forcefully exit the process immediately.
|
||||||
|
func trapSignalsCrossPlatform(cancel context.CancelFunc) {
|
||||||
|
go func() {
|
||||||
|
shutdown := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(shutdown, os.Interrupt, syscall.SIGINT)
|
||||||
|
|
||||||
|
for i := 0; true; i++ {
|
||||||
|
<-shutdown
|
||||||
|
|
||||||
|
if i > 0 {
|
||||||
|
fmt.Printf("\nForce quit\n") //nolint:forbidigo // We want to print to stdout
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nGracefully shutting down. Press Ctrl+C again to force quit\n") //nolint:forbidigo // We want to print to stdout
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRootCmd parses the command line flags and returns a config.Config struct.
|
||||||
|
func newRootCmd(version string, cfg *config.Config, subcommands []*ff.Command) *ff.Command {
|
||||||
|
fs := ff.NewFlagSet(appName)
|
||||||
|
|
||||||
|
for _, cmd := range subcommands {
|
||||||
|
cmd.Flags = ff.NewFlagSet(cmd.Name).SetParent(fs)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &ff.Command{
|
||||||
|
Name: appName,
|
||||||
|
Usage: fmt.Sprintf("%s <command> [flags]", appName),
|
||||||
|
ShortHelp: fmt.Sprintf("(%s) A webfinger server", version),
|
||||||
|
Flags: fs,
|
||||||
|
Subcommands: subcommands,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use 0.0.0.0 as the default host if on docker
|
||||||
|
defaultHost := "localhost"
|
||||||
|
if os.Getenv("ENV_DOCKER") == "true" {
|
||||||
|
defaultHost = "0.0.0.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.BoolVar(&cfg.Debug, 'd', "debug", "Enable debug logging")
|
||||||
|
fs.StringVar(&cfg.Host, 'h', "host", defaultHost, "Host to listen on")
|
||||||
|
fs.StringVar(&cfg.Port, 'p', "port", "8080", "Port to listen on")
|
||||||
|
fs.StringVar(&cfg.URNPath, 'u', "urn-file", "urns.yml", "Path to the URNs file")
|
||||||
|
fs.StringVar(&cfg.FingerPath, 'f', "finger-file", "fingers.yml", "Path to the fingers file")
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
53
cmd/healthcheck.go
Normal file
53
cmd/healthcheck.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"github.com/peterbourgon/ff/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newHealthcheckCmd(cfg *config.Config) *ff.Command {
|
||||||
|
return &ff.Command{
|
||||||
|
Name: "healthcheck",
|
||||||
|
Usage: "healthcheck [flags]",
|
||||||
|
ShortHelp: "Check if the server is running",
|
||||||
|
Exec: func(ctx context.Context, args []string) error {
|
||||||
|
// Create a new client
|
||||||
|
client := &http.Client{
|
||||||
|
Timeout: 5 * time.Second, //nolint:gomnd // We want to use a constant
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new request
|
||||||
|
reqURL := url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: cfg.GetAddr(),
|
||||||
|
Path: "/healthz",
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error creating request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error sending request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check the response
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("server returned status %d", resp.StatusCode) //nolint:goerr113 // We want to return an error
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
49
cmd/serve.go
Normal file
49
cmd/serve.go
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/server"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/webfinger"
|
||||||
|
"github.com/peterbourgon/ff/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
const appName = "finger"
|
||||||
|
|
||||||
|
func newServerCmd(cfg *config.Config) *ff.Command {
|
||||||
|
return &ff.Command{
|
||||||
|
Name: "serve",
|
||||||
|
Usage: "serve [flags]",
|
||||||
|
ShortHelp: "Start the webfinger server",
|
||||||
|
Exec: func(ctx context.Context, args []string) error {
|
||||||
|
// Create a logger and add it to the context
|
||||||
|
l := log.NewLogger(os.Stderr, cfg)
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Read the webfinger files
|
||||||
|
r := webfinger.NewFingerReader()
|
||||||
|
err := r.ReadFiles(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error reading finger files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
webfingers, err := r.ReadFingerFile(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error parsing finger files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Info(fmt.Sprintf("Loaded %d webfingers", len(webfingers)))
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
if err := server.StartServer(ctx, cfg, webfingers); err != nil {
|
||||||
|
return fmt.Errorf("error running server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
1
go.mod
1
go.mod
|
@ -4,7 +4,6 @@ go 1.21.0
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3
|
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9
|
|
||||||
golang.org/x/sync v0.3.0
|
golang.org/x/sync v0.3.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
)
|
)
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -2,8 +2,6 @@ github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N
|
||||||
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
|
||||||
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3 h1:fpyiFVEJvxIFljxM4l5ANSk/UGlM1gyU+hPAr9jhB7M=
|
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3 h1:fpyiFVEJvxIFljxM4l5ANSk/UGlM1gyU+hPAr9jhB7M=
|
||||||
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3/go.mod h1:H/13DK46DKXy7EaIxPhk2Y0EC8aubKm35nBjBe8AAGc=
|
github.com/peterbourgon/ff/v4 v4.0.0-alpha.3/go.mod h1:H/13DK46DKXy7EaIxPhk2Y0EC8aubKm35nBjBe8AAGc=
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
|
|
||||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
|
|
||||||
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E=
|
||||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||||
|
|
67
internal/config/config.go
Normal file
67
internal/config/config.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/url"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultHost is the default host to listen on.
|
||||||
|
DefaultHost = "localhost"
|
||||||
|
// DefaultPort is the default port to listen on.
|
||||||
|
DefaultPort = "8080"
|
||||||
|
// DefaultURNPath is the default file path to the URN alias file.
|
||||||
|
DefaultURNPath = "urns.yml"
|
||||||
|
// DefaultFingerPath is the default file path to the webfinger definition file.
|
||||||
|
DefaultFingerPath = "finger.yml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidConfig is returned when the config is invalid.
|
||||||
|
var ErrInvalidConfig = errors.New("invalid config")
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Debug bool
|
||||||
|
Host string
|
||||||
|
Port string
|
||||||
|
URNPath string
|
||||||
|
FingerPath string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
Host: DefaultHost,
|
||||||
|
Port: DefaultPort,
|
||||||
|
URNPath: DefaultURNPath,
|
||||||
|
FingerPath: DefaultFingerPath,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) GetAddr() string {
|
||||||
|
return net.JoinHostPort(c.Host, c.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) Validate() error {
|
||||||
|
if c.Host == "" {
|
||||||
|
return fmt.Errorf("%w: host is empty", ErrInvalidConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Port == "" {
|
||||||
|
return fmt.Errorf("%w: port is empty", ErrInvalidConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := url.Parse(c.GetAddr()); err != nil {
|
||||||
|
return fmt.Errorf("%w: %w", ErrInvalidConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.URNPath == "" {
|
||||||
|
return fmt.Errorf("%w: urn path is empty", ErrInvalidConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.FingerPath == "" {
|
||||||
|
return fmt.Errorf("%w: finger path is empty", ErrInvalidConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
124
internal/config/config_test.go
Normal file
124
internal/config/config_test.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConfig_GetAddr(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg *config.Config
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
cfg: config.NewConfig(),
|
||||||
|
want: "localhost:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "example.com",
|
||||||
|
Port: "1234",
|
||||||
|
},
|
||||||
|
want: "example.com:1234",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tc := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := tc.cfg.GetAddr()
|
||||||
|
if got != tc.want {
|
||||||
|
t.Errorf("Config.GetAddr() = %v, want %v", got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
cfg *config.Config
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
cfg: config.NewConfig(),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty host",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "",
|
||||||
|
Port: "1234",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty port",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "example.com",
|
||||||
|
Port: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid addr",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "example.com",
|
||||||
|
Port: "invalid",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty urn path",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "example.com",
|
||||||
|
Port: "1234",
|
||||||
|
URNPath: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty finger path",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "example.com",
|
||||||
|
Port: "1234",
|
||||||
|
URNPath: "urns.yml",
|
||||||
|
FingerPath: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid",
|
||||||
|
cfg: &config.Config{
|
||||||
|
Host: "example.com",
|
||||||
|
Port: "1234",
|
||||||
|
URNPath: "urns.yml",
|
||||||
|
FingerPath: "finger.yml",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tc := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
err := tc.cfg.Validate()
|
||||||
|
if (err != nil) != tc.wantErr {
|
||||||
|
t.Errorf("Config.Validate() error = %v, wantErr %v", err, tc.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
42
internal/log/log.go
Normal file
42
internal/log/log.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type loggerCtxKey struct{}
|
||||||
|
|
||||||
|
// NewLogger creates a new logger with the given debug level.
|
||||||
|
func NewLogger(w io.Writer, cfg *config.Config) *slog.Logger {
|
||||||
|
level := slog.LevelInfo
|
||||||
|
addSource := false
|
||||||
|
|
||||||
|
if cfg.Debug {
|
||||||
|
level = slog.LevelDebug
|
||||||
|
addSource = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return slog.New(
|
||||||
|
slog.NewJSONHandler(w, &slog.HandlerOptions{
|
||||||
|
Level: level,
|
||||||
|
AddSource: addSource,
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FromContext(ctx context.Context) *slog.Logger {
|
||||||
|
l, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger)
|
||||||
|
if !ok {
|
||||||
|
panic("logger not found in context")
|
||||||
|
}
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithLogger(ctx context.Context, l *slog.Logger) context.Context {
|
||||||
|
return context.WithValue(ctx, loggerCtxKey{}, l)
|
||||||
|
}
|
95
internal/log/log_test.go
Normal file
95
internal/log/log_test.go
Normal file
|
@ -0,0 +1,95 @@
|
||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertPanic(t *testing.T, f func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Errorf("The code did not panic")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Call the function
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("defaults to info level", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
|
||||||
|
w := &strings.Builder{}
|
||||||
|
l := log.NewLogger(w, cfg)
|
||||||
|
|
||||||
|
// It shouldn't log debug messages
|
||||||
|
l.Debug("test")
|
||||||
|
|
||||||
|
if w.String() != "" {
|
||||||
|
t.Error("logger logged debug message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// It should log info messages
|
||||||
|
l.Info("test")
|
||||||
|
|
||||||
|
if w.String() == "" {
|
||||||
|
t.Error("logger did not log info message")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("logs debug messages if debug is enabled", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
cfg.Debug = true
|
||||||
|
|
||||||
|
w := &strings.Builder{}
|
||||||
|
l := log.NewLogger(w, cfg)
|
||||||
|
|
||||||
|
// It should log debug messages
|
||||||
|
l.Debug("test")
|
||||||
|
|
||||||
|
if w.String() == "" {
|
||||||
|
t.Error("logger did not log debug message")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(nil, cfg)
|
||||||
|
|
||||||
|
t.Run("panics if no logger in context", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
assertPanic(t, func() {
|
||||||
|
log.FromContext(ctx)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns logger from context", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
l2 := log.FromContext(ctx)
|
||||||
|
|
||||||
|
if l2 == nil {
|
||||||
|
t.Error("logger is nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
44
internal/middleware/log.go
Normal file
44
internal/middleware/log.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RequestLogger(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
l := log.FromContext(ctx)
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
// Wrap the response writer
|
||||||
|
wrapped := WrapResponseWriter(w)
|
||||||
|
|
||||||
|
// Call the next handler
|
||||||
|
next.ServeHTTP(wrapped, r)
|
||||||
|
|
||||||
|
status := wrapped.Status()
|
||||||
|
|
||||||
|
// Log the request
|
||||||
|
lg := l.With(
|
||||||
|
slog.String("method", r.Method),
|
||||||
|
slog.String("path", r.URL.Path),
|
||||||
|
slog.Int("status", status),
|
||||||
|
slog.String("remote", r.RemoteAddr),
|
||||||
|
slog.Duration("duration", time.Since(start)),
|
||||||
|
)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case status >= http.StatusInternalServerError:
|
||||||
|
lg.Error("Server error")
|
||||||
|
case status >= http.StatusBadRequest:
|
||||||
|
lg.Info("Client error")
|
||||||
|
default:
|
||||||
|
lg.Info("Request completed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
44
internal/middleware/log_test.go
Normal file
44
internal/middleware/log_test.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRequestLogger(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
|
||||||
|
stdout := &strings.Builder{}
|
||||||
|
|
||||||
|
l := log.NewLogger(stdout, cfg)
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody)
|
||||||
|
|
||||||
|
if stdout.String() != "" {
|
||||||
|
t.Error("logger logged before request")
|
||||||
|
}
|
||||||
|
|
||||||
|
middleware.RequestLogger(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})).ServeHTTP(w, r)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Error("status is not 200")
|
||||||
|
}
|
||||||
|
|
||||||
|
if stdout.String() == "" {
|
||||||
|
t.Error("logger did not log request")
|
||||||
|
}
|
||||||
|
}
|
27
internal/middleware/recover.go
Normal file
27
internal/middleware/recover.go
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Recoverer(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
l := log.FromContext(ctx)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
err := recover()
|
||||||
|
if err != nil {
|
||||||
|
l.Error("Panic", slog.Any("error", err))
|
||||||
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
76
internal/middleware/recover_test.go
Normal file
76
internal/middleware/recover_test.go
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func assertNoPanic(t *testing.T, f func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Error("function panicked")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
f()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRecoverer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
t.Run("handles panics", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody)
|
||||||
|
|
||||||
|
h := middleware.Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
panic("test")
|
||||||
|
}))
|
||||||
|
|
||||||
|
assertNoPanic(t, func() {
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Error("status is not 500")
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.Body.String() != "Internal Server Error\n" {
|
||||||
|
t.Error("response body is not 'Internal Server Error'")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles successful requests", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/", http.NoBody)
|
||||||
|
|
||||||
|
h := middleware.Recoverer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
assertNoPanic(t, func() {
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Error("status is not 200")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
42
internal/middleware/wrapper.go
Normal file
42
internal/middleware/wrapper.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResponseWrapper struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func WrapResponseWriter(w http.ResponseWriter) *ResponseWrapper {
|
||||||
|
return &ResponseWrapper{w, 0}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWrapper) WriteHeader(code int) {
|
||||||
|
w.status = code
|
||||||
|
w.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWrapper) Status() int {
|
||||||
|
return w.status
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWrapper) Write(b []byte) (int, error) {
|
||||||
|
if w.status == 0 {
|
||||||
|
w.status = http.StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
size, err := w.ResponseWriter.Write(b)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("error writing response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return size, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *ResponseWrapper) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
97
internal/middleware/wrapper_test.go
Normal file
97
internal/middleware/wrapper_test.go
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
package middleware_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWrapResponseWriter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
wrapped := middleware.WrapResponseWriter(w)
|
||||||
|
|
||||||
|
if wrapped == nil {
|
||||||
|
t.Error("wrapper is nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWrapper_Status(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
wrapped := middleware.WrapResponseWriter(w)
|
||||||
|
|
||||||
|
if wrapped.Status() != 0 {
|
||||||
|
t.Error("status is not 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
wrapped.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
if wrapped.Status() != http.StatusOK {
|
||||||
|
t.Error("status is not 200")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type FailWriter struct{}
|
||||||
|
|
||||||
|
func (w *FailWriter) Write(b []byte) (int, error) {
|
||||||
|
return 0, fmt.Errorf("error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *FailWriter) Header() http.Header {
|
||||||
|
return http.Header{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *FailWriter) WriteHeader(_ int) {}
|
||||||
|
|
||||||
|
func TestResponseWrapper_Write(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("writes success messages", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
wrapped := middleware.WrapResponseWriter(w)
|
||||||
|
|
||||||
|
size, err := wrapped.Write([]byte("test"))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("error writing response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if size != 4 {
|
||||||
|
t.Error("size is not 4")
|
||||||
|
}
|
||||||
|
|
||||||
|
if wrapped.Status() != http.StatusOK {
|
||||||
|
t.Error("status is not 200")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns error on fail write", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := &FailWriter{}
|
||||||
|
wrapped := middleware.WrapResponseWriter(w)
|
||||||
|
|
||||||
|
_, err := wrapped.Write([]byte("test"))
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error is nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseWrapper_Unwrap(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
wrapped := middleware.WrapResponseWriter(w)
|
||||||
|
|
||||||
|
if wrapped.Unwrap() != w {
|
||||||
|
t.Error("unwrapped response is not the same")
|
||||||
|
}
|
||||||
|
}
|
13
internal/server/healthcheck.go
Normal file
13
internal/server/healthcheck.go
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HealthCheckHandler(_ *config.Config) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
}
|
40
internal/server/healthcheck_test.go
Normal file
40
internal/server/healthcheck_test.go
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHealthcheckHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Create a new request
|
||||||
|
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "/healthz", http.NoBody)
|
||||||
|
|
||||||
|
// Create a new recorder
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create a new handler
|
||||||
|
h := server.HealthCheckHandler(cfg)
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
// Check the status code
|
||||||
|
if rec.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status code %d, got %d", http.StatusOK, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
100
internal/server/server.go
Normal file
100
internal/server/server.go
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/middleware"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/webfinger"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ReadTimeout is the maximum duration for reading the entire
|
||||||
|
// request, including the body.
|
||||||
|
ReadTimeout = 5 * time.Second
|
||||||
|
// WriteTimeout is the maximum duration before timing out
|
||||||
|
// writes of the response.
|
||||||
|
WriteTimeout = 10 * time.Second
|
||||||
|
// IdleTimeout is the maximum amount of time to wait for the
|
||||||
|
// next request when keep-alives are enabled.
|
||||||
|
IdleTimeout = 30 * time.Second
|
||||||
|
// ReadHeaderTimeout is the amount of time allowed to read
|
||||||
|
// request headers.
|
||||||
|
ReadHeaderTimeout = 2 * time.Second
|
||||||
|
// RequestTimeout is the maximum duration for the entire
|
||||||
|
// request.
|
||||||
|
RequestTimeout = 7 * 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
func StartServer(ctx context.Context, cfg *config.Config, webfingers webfinger.WebFingers) error {
|
||||||
|
l := log.FromContext(ctx)
|
||||||
|
|
||||||
|
// Create the server mux
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle("/.well-known/webfinger", WebfingerHandler(cfg, webfingers))
|
||||||
|
mux.Handle("/healthz", HealthCheckHandler(cfg))
|
||||||
|
|
||||||
|
// Create a new server
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: cfg.GetAddr(),
|
||||||
|
BaseContext: func(_ net.Listener) context.Context {
|
||||||
|
return ctx
|
||||||
|
},
|
||||||
|
Handler: middleware.RequestLogger(
|
||||||
|
middleware.Recoverer(
|
||||||
|
http.TimeoutHandler(mux, RequestTimeout, "request timed out"),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
ReadHeaderTimeout: ReadHeaderTimeout,
|
||||||
|
ReadTimeout: ReadTimeout,
|
||||||
|
WriteTimeout: WriteTimeout,
|
||||||
|
IdleTimeout: IdleTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the errorgroup that will manage the server execution
|
||||||
|
eg, egCtx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
eg.Go(func() error {
|
||||||
|
l.Info("Starting server", slog.String("addr", srv.Addr))
|
||||||
|
|
||||||
|
// Use the global context for the server
|
||||||
|
srv.BaseContext = func(_ net.Listener) context.Context {
|
||||||
|
return egCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
return srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup
|
||||||
|
})
|
||||||
|
// Gracefully shutdown the server when the context is done
|
||||||
|
eg.Go(func() error {
|
||||||
|
// Wait for the context to be done
|
||||||
|
<-egCtx.Done()
|
||||||
|
|
||||||
|
l.Info("Shutting down server")
|
||||||
|
// Disable the cancel since we don't wan't to force
|
||||||
|
// the server to shutdown if the context is canceled.
|
||||||
|
noCancelCtx := context.WithoutCancel(egCtx)
|
||||||
|
|
||||||
|
return srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup
|
||||||
|
})
|
||||||
|
|
||||||
|
// Log when the server is fully shutdown
|
||||||
|
srv.RegisterOnShutdown(func() {
|
||||||
|
l.Info("Server shutdown complete")
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wait for the server to exit and check for errors that
|
||||||
|
// are not caused by the context being canceled.
|
||||||
|
if err := eg.Wait(); err != nil && ctx.Err() == nil {
|
||||||
|
return fmt.Errorf("server exited with error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
206
internal/server/server_test.go
Normal file
206
internal/server/server_test.go
Normal file
|
@ -0,0 +1,206 @@
|
||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/server"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/webfinger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getPortGenerator() func() int {
|
||||||
|
lock := &sync.Mutex{}
|
||||||
|
port := 8080
|
||||||
|
|
||||||
|
return func() int {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
port++
|
||||||
|
|
||||||
|
return port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartServer(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
portGenerator := getPortGenerator()
|
||||||
|
|
||||||
|
t.Run("starts and shuts down", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Use a new port
|
||||||
|
cfg.Port = fmt.Sprint(portGenerator())
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
err := server.StartServer(ctx, cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("fails to start", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Use a new port
|
||||||
|
cfg.Port = fmt.Sprint(portGenerator())
|
||||||
|
|
||||||
|
// Use invalid host
|
||||||
|
cfg.Host = "google.com"
|
||||||
|
|
||||||
|
// Start the server
|
||||||
|
err := server.StartServer(ctx, cfg, nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("expected error, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("serves webfinger", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Use a new port
|
||||||
|
cfg.Port = fmt.Sprint(portGenerator())
|
||||||
|
|
||||||
|
resource := "acct:user@example.com"
|
||||||
|
webfingers := webfinger.WebFingers{
|
||||||
|
resource: &webfinger.WebFinger{
|
||||||
|
Subject: resource,
|
||||||
|
Properties: map[string]string{
|
||||||
|
"http://webfinger.net/rel/name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Start the server
|
||||||
|
err := server.StartServer(ctx, cfg, webfingers)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for the server to start
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
// Create a new client
|
||||||
|
c := http.Client{}
|
||||||
|
|
||||||
|
// Create a new request
|
||||||
|
r, _ := http.NewRequestWithContext(ctx,
|
||||||
|
http.MethodGet,
|
||||||
|
"http://"+cfg.GetAddr()+"/.well-known/webfinger?resource=acct:user@example.com",
|
||||||
|
http.NoBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
resp, err := c.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check the status code
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the response body
|
||||||
|
fingerGot := &webfinger.WebFinger{}
|
||||||
|
|
||||||
|
// Decode the response body
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(fingerGot); err != nil {
|
||||||
|
t.Errorf("error decoding json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the response body
|
||||||
|
fingerWant := webfingers[resource]
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(fingerGot, fingerWant) {
|
||||||
|
t.Errorf("expected %v, got %v", fingerWant, fingerGot)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("serves healthcheck", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Use a new port
|
||||||
|
cfg.Port = fmt.Sprint(portGenerator())
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
// Start the server
|
||||||
|
err := server.StartServer(ctx, cfg, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for the server to start
|
||||||
|
time.Sleep(time.Millisecond * 50)
|
||||||
|
|
||||||
|
// Create a new client
|
||||||
|
c := http.Client{}
|
||||||
|
|
||||||
|
// Create a new request
|
||||||
|
r, _ := http.NewRequestWithContext(ctx,
|
||||||
|
http.MethodGet,
|
||||||
|
"http://"+cfg.GetAddr()+"/healthz",
|
||||||
|
http.NoBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send the request
|
||||||
|
resp, err := c.Do(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// Check the status code
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Errorf("expected status code %d, got %d", http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
59
internal/server/webfinger.go
Normal file
59
internal/server/webfinger.go
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/webfinger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WebfingerHandler(_ *config.Config, webfingers webfinger.WebFingers) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
l := log.FromContext(ctx)
|
||||||
|
|
||||||
|
// Only handle GET requests
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
l.Debug("Method not allowed")
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the query params
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
// Get the resource
|
||||||
|
resource := q.Get("resource")
|
||||||
|
if resource == "" {
|
||||||
|
l.Debug("No resource provided")
|
||||||
|
http.Error(w, "No resource provided", http.StatusBadRequest)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get and validate resource
|
||||||
|
finger, ok := webfingers[resource]
|
||||||
|
if !ok {
|
||||||
|
l.Debug("Resource not found")
|
||||||
|
http.Error(w, "Resource not found", http.StatusNotFound)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the content type
|
||||||
|
w.Header().Set("Content-Type", "application/jrd+json")
|
||||||
|
|
||||||
|
// Write the response
|
||||||
|
if err := json.NewEncoder(w).Encode(finger); err != nil {
|
||||||
|
l.Debug("Error encoding json")
|
||||||
|
http.Error(w, "Error encoding json", http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Debug("Webfinger request successful")
|
||||||
|
})
|
||||||
|
}
|
149
internal/server/webfinger_test.go
Normal file
149
internal/server/webfinger_test.go
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
package server_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/server"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/webfinger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWebfingerHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
webfingers := webfinger.WebFingers{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
Subject: "acct:user@example.com",
|
||||||
|
Links: []webfinger.Link{
|
||||||
|
{
|
||||||
|
Rel: "http://webfinger.net/rel/profile-page",
|
||||||
|
Href: "https://example.com/user",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Properties: map[string]string{
|
||||||
|
"http://webfinger.net/rel/name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"acct:other@example.com": {
|
||||||
|
Subject: "acct:other@example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"http://webfinger.net/rel/name": "Jane Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"https://example.com/user": {
|
||||||
|
Subject: "https://example.com/user",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"http://webfinger.net/rel/name": "John Baz",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
resource string
|
||||||
|
wantCode int
|
||||||
|
alternateMethod string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid resource",
|
||||||
|
resource: "acct:user@example.com",
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other valid resource",
|
||||||
|
resource: "acct:other@example.com",
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "url resource",
|
||||||
|
resource: "https://example.com/user",
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "resource missing acct:",
|
||||||
|
resource: "user@example.com",
|
||||||
|
wantCode: http.StatusNotFound,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "resource missing",
|
||||||
|
resource: "",
|
||||||
|
wantCode: http.StatusBadRequest,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid method",
|
||||||
|
resource: "acct:user@example.com",
|
||||||
|
wantCode: http.StatusMethodNotAllowed,
|
||||||
|
alternateMethod: http.MethodPost,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tc := tt
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
// Create a new request
|
||||||
|
r, _ := http.NewRequestWithContext(ctx, tc.alternateMethod, "/.well-known/webfinger?resource="+tc.resource, http.NoBody)
|
||||||
|
|
||||||
|
// Create a new response
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Create a new handler
|
||||||
|
h := server.WebfingerHandler(cfg, webfingers)
|
||||||
|
|
||||||
|
// Serve the request
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
|
||||||
|
// Check the status code
|
||||||
|
if w.Code != tc.wantCode {
|
||||||
|
t.Errorf("expected status code %d, got %d", tc.wantCode, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the status code is 200, check the response body
|
||||||
|
if tc.wantCode == http.StatusOK {
|
||||||
|
// Check the content type
|
||||||
|
if w.Header().Get("Content-Type") != "application/jrd+json" {
|
||||||
|
t.Errorf("expected content type %s, got %s", "application/jrd+json", w.Header().Get("Content-Type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fingerWant := webfingers[tc.resource]
|
||||||
|
fingerGot := &webfinger.WebFinger{}
|
||||||
|
|
||||||
|
// Decode the response body
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(fingerGot); err != nil {
|
||||||
|
t.Errorf("error decoding json: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort links
|
||||||
|
|
||||||
|
sort.Slice(fingerGot.Links, func(i, j int) bool {
|
||||||
|
return fingerGot.Links[i].Rel < fingerGot.Links[j].Rel
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(fingerWant.Links, func(i, j int) bool {
|
||||||
|
return fingerWant.Links[i].Rel < fingerWant.Links[j].Rel
|
||||||
|
})
|
||||||
|
|
||||||
|
// Check the response body
|
||||||
|
if !reflect.DeepEqual(fingerGot, fingerWant) {
|
||||||
|
t.Errorf("expected body %v, got %v", fingerWant, fingerGot)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
160
internal/webfinger/webfinger.go
Normal file
160
internal/webfinger/webfinger.go
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
package webfinger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net/mail"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Link struct {
|
||||||
|
Rel string `json:"rel"`
|
||||||
|
Href string `json:"href,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebFinger struct {
|
||||||
|
Subject string `json:"subject"`
|
||||||
|
Links []Link `json:"links,omitempty"`
|
||||||
|
Properties map[string]string `json:"properties,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type WebFingers map[string]*WebFinger
|
||||||
|
|
||||||
|
type (
|
||||||
|
URNMap = map[string]string
|
||||||
|
RawFingersMap = map[string]map[string]string
|
||||||
|
)
|
||||||
|
|
||||||
|
type FingerReader struct {
|
||||||
|
URNSFile []byte
|
||||||
|
FingersFile []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFingerReader() *FingerReader {
|
||||||
|
return &FingerReader{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FingerReader) ReadFiles(cfg *config.Config) error {
|
||||||
|
// Read URNs file
|
||||||
|
file, err := os.ReadFile(cfg.URNPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error opening URNs file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.URNSFile = file
|
||||||
|
|
||||||
|
// Read fingers file
|
||||||
|
file, err = os.ReadFile(cfg.FingerPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error opening fingers file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.FingersFile = file
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FingerReader) ParseFingers(ctx context.Context, urns URNMap, rawFingers RawFingersMap) (WebFingers, error) {
|
||||||
|
l := log.FromContext(ctx)
|
||||||
|
|
||||||
|
webfingers := make(WebFingers)
|
||||||
|
|
||||||
|
// Parse the webfinger file
|
||||||
|
for k, v := range rawFingers {
|
||||||
|
resource := k
|
||||||
|
|
||||||
|
// Remove leading acct: if present
|
||||||
|
if len(k) > 5 && resource[:5] == "acct:" {
|
||||||
|
resource = resource[5:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// The key must be a URL or email address
|
||||||
|
if _, err := mail.ParseAddress(resource); err != nil {
|
||||||
|
if _, err := url.ParseRequestURI(resource); err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing webfinger key (%s): %w", k, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Add acct: back to the key if it is an email address
|
||||||
|
resource = fmt.Sprintf("acct:%s", resource)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new webfinger
|
||||||
|
webfinger := &WebFinger{
|
||||||
|
Subject: resource,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the fields
|
||||||
|
for field, value := range v {
|
||||||
|
fieldUrn := field
|
||||||
|
|
||||||
|
// If the key is present in the URNs file, use the value
|
||||||
|
if _, ok := urns[field]; ok {
|
||||||
|
fieldUrn = urns[field]
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value is a valid URI, add it to the links
|
||||||
|
if _, err := url.ParseRequestURI(value); err == nil {
|
||||||
|
webfinger.Links = append(webfinger.Links, Link{
|
||||||
|
Rel: fieldUrn,
|
||||||
|
Href: value,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Otherwise add it to the properties
|
||||||
|
if webfinger.Properties == nil {
|
||||||
|
webfinger.Properties = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
webfinger.Properties[fieldUrn] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the webfinger to the map
|
||||||
|
webfingers[resource] = webfinger
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Debug("Webfinger map built successfully", slog.Int("number", len(webfingers)), slog.Any("data", webfingers))
|
||||||
|
|
||||||
|
return webfingers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *FingerReader) ReadFingerFile(ctx context.Context) (WebFingers, error) {
|
||||||
|
l := log.FromContext(ctx)
|
||||||
|
|
||||||
|
urnMap := make(URNMap)
|
||||||
|
fingerData := make(RawFingersMap)
|
||||||
|
|
||||||
|
// Parse the URNs file
|
||||||
|
if err := yaml.Unmarshal(f.URNSFile, &urnMap); err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshalling URNs file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The URNs file must be a map of strings to valid URLs
|
||||||
|
for _, v := range urnMap {
|
||||||
|
if _, err := url.ParseRequestURI(v); err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing URN URIs: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Debug("URNs file parsed successfully", slog.Int("number", len(urnMap)), slog.Any("data", urnMap))
|
||||||
|
|
||||||
|
// Parse the fingers file
|
||||||
|
if err := yaml.Unmarshal(f.FingersFile, &fingerData); err != nil {
|
||||||
|
return nil, fmt.Errorf("error unmarshalling fingers file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Debug("Fingers file parsed successfully", slog.Int("number", len(fingerData)), slog.Any("data", fingerData))
|
||||||
|
|
||||||
|
// Parse raw data
|
||||||
|
webfingers, err := f.ParseFingers(ctx, urnMap, fingerData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error parsing raw fingers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return webfingers, nil
|
||||||
|
}
|
444
internal/webfinger/webfinger_test.go
Normal file
444
internal/webfinger/webfinger_test.go
Normal file
|
@ -0,0 +1,444 @@
|
||||||
|
package webfinger_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.maronato.dev/maronato/finger/internal/config"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/log"
|
||||||
|
"git.maronato.dev/maronato/finger/internal/webfinger"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTempFile(t *testing.T, content string) (name string, remove func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
f, err := os.CreateTemp("", "finger-test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error creating temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = f.WriteString(content)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error writing to temp file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name(), func() {
|
||||||
|
err = os.Remove(f.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("error removing temp file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFingerReader(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
f := webfinger.NewFingerReader()
|
||||||
|
|
||||||
|
if f == nil {
|
||||||
|
t.Errorf("NewFingerReader() = %v, want: %v", f, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFingerReader_ReadFiles(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urnsContent string
|
||||||
|
fingersContent string
|
||||||
|
useURNFile bool
|
||||||
|
useFingerFile bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "reads files",
|
||||||
|
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
|
||||||
|
fingersContent: "user@example.com:\n name: John Doe",
|
||||||
|
useURNFile: true,
|
||||||
|
useFingerFile: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on missing URNs file",
|
||||||
|
urnsContent: "invalid",
|
||||||
|
fingersContent: "user@example.com:\n name: John Doe",
|
||||||
|
useURNFile: false,
|
||||||
|
useFingerFile: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on missing fingers file",
|
||||||
|
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
|
||||||
|
fingersContent: "invalid",
|
||||||
|
useFingerFile: false,
|
||||||
|
useURNFile: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tc := tt
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
|
||||||
|
urnsFileName, urnsCleanup := newTempFile(t, tc.urnsContent)
|
||||||
|
defer urnsCleanup()
|
||||||
|
|
||||||
|
fingersFileName, fingersCleanup := newTempFile(t, tc.fingersContent)
|
||||||
|
defer fingersCleanup()
|
||||||
|
|
||||||
|
if !tc.useURNFile {
|
||||||
|
cfg.URNPath = "invalid"
|
||||||
|
} else {
|
||||||
|
cfg.URNPath = urnsFileName
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.useFingerFile {
|
||||||
|
cfg.FingerPath = "invalid"
|
||||||
|
} else {
|
||||||
|
cfg.FingerPath = fingersFileName
|
||||||
|
}
|
||||||
|
|
||||||
|
f := webfinger.NewFingerReader()
|
||||||
|
|
||||||
|
err := f.ReadFiles(cfg)
|
||||||
|
if err != nil {
|
||||||
|
if !tc.wantErr {
|
||||||
|
t.Errorf("ReadFiles() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
} else if tc.wantErr {
|
||||||
|
t.Errorf("ReadFiles() error = %v, wantErr %v", err, tc.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(f.URNSFile, []byte(tc.urnsContent)) {
|
||||||
|
t.Errorf("ReadFiles() URNsFile = %v, want: %v", f.URNSFile, tc.urnsContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(f.FingersFile, []byte(tc.fingersContent)) {
|
||||||
|
t.Errorf("ReadFiles() FingersFile = %v, want: %v", f.FingersFile, tc.fingersContent)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFingers(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rawFingers webfinger.RawFingersMap
|
||||||
|
want webfinger.WebFingers
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "parses links",
|
||||||
|
rawFingers: webfinger.RawFingersMap{
|
||||||
|
"user@example.com": {
|
||||||
|
"profile": "https://example.com/profile",
|
||||||
|
"invalidalias": "https://example.com/invalidalias",
|
||||||
|
"https://something": "https://somethingelse",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: webfinger.WebFingers{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
Subject: "acct:user@example.com",
|
||||||
|
Links: []webfinger.Link{
|
||||||
|
{
|
||||||
|
Rel: "https://schema/profile",
|
||||||
|
Href: "https://example.com/profile",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Rel: "invalidalias",
|
||||||
|
Href: "https://example.com/invalidalias",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Rel: "https://something",
|
||||||
|
Href: "https://somethingelse",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parses properties",
|
||||||
|
rawFingers: webfinger.RawFingersMap{
|
||||||
|
"user@example.com": {
|
||||||
|
"name": "John Doe",
|
||||||
|
"invalidalias": "value1",
|
||||||
|
"https://mylink": "value2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: webfinger.WebFingers{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
Subject: "acct:user@example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"https://schema/name": "John Doe",
|
||||||
|
"invalidalias": "value1",
|
||||||
|
"https://mylink": "value2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "accepts acct: prefix",
|
||||||
|
rawFingers: webfinger.RawFingersMap{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: webfinger.WebFingers{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
Subject: "acct:user@example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"https://schema/name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "accepts urls as resource",
|
||||||
|
rawFingers: webfinger.RawFingersMap{
|
||||||
|
"https://example.com": {
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: webfinger.WebFingers{
|
||||||
|
"https://example.com": {
|
||||||
|
Subject: "https://example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"https://schema/name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "accepts multiple resources",
|
||||||
|
rawFingers: webfinger.RawFingersMap{
|
||||||
|
"user@example.com": {
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
"other@example.com": {
|
||||||
|
"name": "Jane Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: webfinger.WebFingers{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
Subject: "acct:user@example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"https://schema/name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"acct:other@example.com": {
|
||||||
|
Subject: "acct:other@example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"https://schema/name": "Jane Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on invalid resource",
|
||||||
|
rawFingers: webfinger.RawFingersMap{
|
||||||
|
"invalid": {
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tc := tt
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Create a urn map
|
||||||
|
urns := webfinger.URNMap{
|
||||||
|
"name": "https://schema/name",
|
||||||
|
"profile": "https://schema/profile",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
f := webfinger.NewFingerReader()
|
||||||
|
|
||||||
|
got, err := f.ParseFingers(ctx, urns, tc.rawFingers)
|
||||||
|
if (err != nil) != tc.wantErr {
|
||||||
|
t.Errorf("ParseFingers() error = %v, wantErr %v", err, tc.wantErr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort links to make it easier to compare
|
||||||
|
for _, v := range got {
|
||||||
|
for range v.Links {
|
||||||
|
sort.Slice(v.Links, func(i, j int) bool {
|
||||||
|
return v.Links[i].Rel < v.Links[j].Rel
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range tc.want {
|
||||||
|
for range v.Links {
|
||||||
|
sort.Slice(v.Links, func(i, j int) bool {
|
||||||
|
return v.Links[i].Rel < v.Links[j].Rel
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(got, tc.want) {
|
||||||
|
// Unmarshal the structs to JSON to make it easier to print
|
||||||
|
gotstr := &strings.Builder{}
|
||||||
|
gotenc := json.NewEncoder(gotstr)
|
||||||
|
|
||||||
|
wantstr := &strings.Builder{}
|
||||||
|
wantenc := json.NewEncoder(wantstr)
|
||||||
|
|
||||||
|
_ = gotenc.Encode(got)
|
||||||
|
_ = wantenc.Encode(tc.want)
|
||||||
|
|
||||||
|
t.Errorf("ParseFingers() got = \n%s want: \n%s", gotstr.String(), wantstr.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadFingerFile(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
urnsContent string
|
||||||
|
fingersContent string
|
||||||
|
wantURN webfinger.URNMap
|
||||||
|
wantFinger webfinger.RawFingersMap
|
||||||
|
returns *webfinger.WebFingers
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "reads files",
|
||||||
|
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
|
||||||
|
fingersContent: "user@example.com:\n name: John Doe",
|
||||||
|
wantURN: webfinger.URNMap{
|
||||||
|
"name": "https://schema/name",
|
||||||
|
"profile": "https://schema/profile",
|
||||||
|
},
|
||||||
|
wantFinger: webfinger.RawFingersMap{
|
||||||
|
"user@example.com": {
|
||||||
|
"name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
returns: &webfinger.WebFingers{
|
||||||
|
"acct:user@example.com": {
|
||||||
|
Subject: "acct:user@example.com",
|
||||||
|
Properties: map[string]string{
|
||||||
|
"https://schema/name": "John Doe",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "uses custom URNs",
|
||||||
|
urnsContent: "favorite_food: https://schema/favorite_food",
|
||||||
|
fingersContent: "user@example.com:\n favorite_food: Apple",
|
||||||
|
wantURN: webfinger.URNMap{
|
||||||
|
"favorite_food": "https://schema/favorite_food",
|
||||||
|
},
|
||||||
|
wantFinger: webfinger.RawFingersMap{
|
||||||
|
"user@example.com": {
|
||||||
|
"https://schema/favorite_food": "Apple",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on invalid URNs file",
|
||||||
|
urnsContent: "invalid",
|
||||||
|
fingersContent: "user@example.com:\n name: John Doe",
|
||||||
|
wantURN: webfinger.URNMap{},
|
||||||
|
wantFinger: webfinger.RawFingersMap{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on invalid fingers file",
|
||||||
|
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
|
||||||
|
fingersContent: "invalid",
|
||||||
|
wantURN: webfinger.URNMap{},
|
||||||
|
wantFinger: webfinger.RawFingersMap{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on invalid URNs values",
|
||||||
|
urnsContent: "name: invalid",
|
||||||
|
fingersContent: "user@example.com:\n name: John Doe",
|
||||||
|
wantURN: webfinger.URNMap{},
|
||||||
|
wantFinger: webfinger.RawFingersMap{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errors on invalid fingers values",
|
||||||
|
urnsContent: "name: https://schema/name\nprofile: https://schema/profile",
|
||||||
|
fingersContent: "invalid:\n name: John Doe",
|
||||||
|
wantURN: webfinger.URNMap{},
|
||||||
|
wantFinger: webfinger.RawFingersMap{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tc := tt
|
||||||
|
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := config.NewConfig()
|
||||||
|
l := log.NewLogger(&strings.Builder{}, cfg)
|
||||||
|
|
||||||
|
ctx = log.WithLogger(ctx, l)
|
||||||
|
|
||||||
|
f := webfinger.NewFingerReader()
|
||||||
|
|
||||||
|
f.FingersFile = []byte(tc.fingersContent)
|
||||||
|
f.URNSFile = []byte(tc.urnsContent)
|
||||||
|
|
||||||
|
got, err := f.ReadFingerFile(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if !tc.wantErr {
|
||||||
|
t.Errorf("ReadFingerFile() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
} else if tc.wantErr {
|
||||||
|
t.Errorf("ReadFingerFile() error = %v, wantErr %v", err, tc.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.returns != nil && !reflect.DeepEqual(got, *tc.returns) {
|
||||||
|
t.Errorf("ReadFingerFile() got = %v, want: %v", got, *tc.returns)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
553
main.go
553
main.go
|
@ -1,566 +1,19 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/mail"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/peterbourgon/ff/v4"
|
"git.maronato.dev/maronato/finger/cmd"
|
||||||
"github.com/peterbourgon/ff/v4/ffhelp"
|
|
||||||
"golang.org/x/exp/slog"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const appName = "finger"
|
// Version of the app.
|
||||||
|
|
||||||
// Version of the application.
|
|
||||||
var version = "dev"
|
var version = "dev"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Run the server
|
// Run the server
|
||||||
if err := Run(); err != nil {
|
if err := cmd.Run(version); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Run() error {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Allow graceful shutdown
|
|
||||||
trapSignalsCrossPlatform(cancel)
|
|
||||||
|
|
||||||
cfg := &Config{}
|
|
||||||
|
|
||||||
// Create a new root command
|
|
||||||
subcommands := []*ff.Command{
|
|
||||||
NewServerCmd(cfg),
|
|
||||||
NewHealthcheckCmd(cfg),
|
|
||||||
}
|
|
||||||
cmd := NewRootCmd(cfg, subcommands)
|
|
||||||
|
|
||||||
// Parse and run
|
|
||||||
if err := cmd.ParseAndRun(ctx, os.Args[1:], ff.WithEnvVarPrefix("WF")); err != nil {
|
|
||||||
if errors.Is(err, ff.ErrHelp) || errors.Is(err, ff.ErrNoExec) {
|
|
||||||
fmt.Fprintf(os.Stderr, "\n%s\n", ffhelp.Command(cmd))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("error running command: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServerCmd(cfg *Config) *ff.Command {
|
|
||||||
return &ff.Command{
|
|
||||||
Name: "serve",
|
|
||||||
Usage: "serve [flags]",
|
|
||||||
ShortHelp: "Start the webfinger server",
|
|
||||||
Exec: func(ctx context.Context, args []string) error {
|
|
||||||
// Create a logger and add it to the context
|
|
||||||
l := NewLogger(cfg)
|
|
||||||
ctx = WithLogger(ctx, l)
|
|
||||||
|
|
||||||
// Parse the webfinger files
|
|
||||||
fingermap, err := ParseFingerFile(ctx, cfg)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error parsing finger files: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Info(fmt.Sprintf("Loaded %d webfingers", len(fingermap)))
|
|
||||||
|
|
||||||
// Start the server
|
|
||||||
if err := StartServer(ctx, cfg, fingermap); err != nil {
|
|
||||||
return fmt.Errorf("error running server: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewHealthcheckCmd(cfg *Config) *ff.Command {
|
|
||||||
return &ff.Command{
|
|
||||||
Name: "healthcheck",
|
|
||||||
Usage: "healthcheck [flags]",
|
|
||||||
ShortHelp: "Check if the server is running",
|
|
||||||
Exec: func(ctx context.Context, args []string) error {
|
|
||||||
// Create a new client
|
|
||||||
client := &http.Client{
|
|
||||||
Timeout: 5 * time.Second, //nolint:gomnd // We want to use a constant
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new request
|
|
||||||
reqURL := url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: net.JoinHostPort(cfg.Host, cfg.Port),
|
|
||||||
Path: "/healthz",
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL.String(), http.NoBody)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error creating request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the request
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error sending request: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
// Check the response
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("server returned status %d", resp.StatusCode) //nolint:goerr113 // We want to return an error
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type loggerCtxKey struct{}
|
|
||||||
|
|
||||||
// NewLogger creates a new logger with the given debug level.
|
|
||||||
func NewLogger(cfg *Config) *slog.Logger {
|
|
||||||
level := slog.LevelInfo
|
|
||||||
addSource := false
|
|
||||||
|
|
||||||
if cfg.Debug {
|
|
||||||
level = slog.LevelDebug
|
|
||||||
addSource = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return slog.New(
|
|
||||||
slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
|
|
||||||
Level: level,
|
|
||||||
AddSource: addSource,
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoggerFromContext(ctx context.Context) *slog.Logger {
|
|
||||||
l, ok := ctx.Value(loggerCtxKey{}).(*slog.Logger)
|
|
||||||
if !ok {
|
|
||||||
panic("logger not found in context")
|
|
||||||
}
|
|
||||||
|
|
||||||
return l
|
|
||||||
}
|
|
||||||
|
|
||||||
func WithLogger(ctx context.Context, l *slog.Logger) context.Context {
|
|
||||||
return context.WithValue(ctx, loggerCtxKey{}, l)
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/caddyserver/caddy/blob/fbb0ecfa322aa7710a3448453fd3ae40f037b8d1/sigtrap.go#L37
|
|
||||||
// trapSignalsCrossPlatform captures SIGINT or interrupt (depending
|
|
||||||
// on the OS), which initiates a graceful shutdown. A second SIGINT
|
|
||||||
// or interrupt will forcefully exit the process immediately.
|
|
||||||
func trapSignalsCrossPlatform(cancel context.CancelFunc) {
|
|
||||||
go func() {
|
|
||||||
shutdown := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(shutdown, os.Interrupt, syscall.SIGINT)
|
|
||||||
|
|
||||||
for i := 0; true; i++ {
|
|
||||||
<-shutdown
|
|
||||||
|
|
||||||
if i > 0 {
|
|
||||||
fmt.Printf("\nForce quit\n") //nolint:forbidigo // We want to print to stdout
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("\nGracefully shutting down. Press Ctrl+C again to force quit\n") //nolint:forbidigo // We want to print to stdout
|
|
||||||
cancel()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Config struct {
|
|
||||||
Debug bool
|
|
||||||
Host string
|
|
||||||
Port string
|
|
||||||
urnPath string
|
|
||||||
fingerPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRootCmd parses the command line flags and returns a Config struct.
|
|
||||||
func NewRootCmd(cfg *Config, subcommands []*ff.Command) *ff.Command {
|
|
||||||
fs := ff.NewFlagSet(appName)
|
|
||||||
|
|
||||||
for _, cmd := range subcommands {
|
|
||||||
cmd.Flags = ff.NewFlagSet(cmd.Name).SetParent(fs)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := &ff.Command{
|
|
||||||
Name: appName,
|
|
||||||
Usage: fmt.Sprintf("%s <command> [flags]", appName),
|
|
||||||
ShortHelp: fmt.Sprintf("(%s) A webfinger server", version),
|
|
||||||
Flags: fs,
|
|
||||||
Subcommands: subcommands,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use 0.0.0.0 as the default host if on docker
|
|
||||||
defaultHost := "localhost"
|
|
||||||
if os.Getenv("ENV_DOCKER") == "true" {
|
|
||||||
defaultHost = "0.0.0.0"
|
|
||||||
}
|
|
||||||
|
|
||||||
fs.BoolVar(&cfg.Debug, 'd', "debug", "Enable debug logging")
|
|
||||||
fs.StringVar(&cfg.Host, 'h', "host", defaultHost, "Host to listen on")
|
|
||||||
fs.StringVar(&cfg.Port, 'p', "port", "8080", "Port to listen on")
|
|
||||||
fs.StringVar(&cfg.urnPath, 'u', "urn-file", "urns.yml", "Path to the URNs file")
|
|
||||||
fs.StringVar(&cfg.fingerPath, 'f', "finger-file", "fingers.yml", "Path to the fingers file")
|
|
||||||
|
|
||||||
return cmd
|
|
||||||
}
|
|
||||||
|
|
||||||
type Link struct {
|
|
||||||
Rel string `json:"rel"`
|
|
||||||
Href string `json:"href,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WebFinger struct {
|
|
||||||
Subject string `json:"subject"`
|
|
||||||
Links []Link `json:"links,omitempty"`
|
|
||||||
Properties map[string]string `json:"properties,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type WebFingerMap map[string]*WebFinger
|
|
||||||
|
|
||||||
func ParseFingerFile(ctx context.Context, cfg *Config) (WebFingerMap, error) {
|
|
||||||
l := LoggerFromContext(ctx)
|
|
||||||
|
|
||||||
urnMap := make(map[string]string)
|
|
||||||
fingerData := make(map[string]map[string]string)
|
|
||||||
|
|
||||||
fingermap := make(WebFingerMap)
|
|
||||||
|
|
||||||
// Read URNs file
|
|
||||||
file, err := os.ReadFile(cfg.urnPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error opening URNs file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := yaml.Unmarshal(file, &urnMap); err != nil {
|
|
||||||
return nil, fmt.Errorf("error unmarshalling URNs file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The URNs file must be a map of strings to valid URLs
|
|
||||||
for _, v := range urnMap {
|
|
||||||
if _, err := url.Parse(v); err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing URN URIs: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("URNs file parsed successfully", slog.Int("number", len(urnMap)), slog.Any("data", urnMap))
|
|
||||||
|
|
||||||
// Read webfingers file
|
|
||||||
file, err = os.ReadFile(cfg.fingerPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error opening fingers file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := yaml.Unmarshal(file, &fingerData); err != nil {
|
|
||||||
return nil, fmt.Errorf("error unmarshalling fingers file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("Fingers file parsed successfully", slog.Int("number", len(fingerData)), slog.Any("data", fingerData))
|
|
||||||
|
|
||||||
// Parse the webfinger file
|
|
||||||
for k, v := range fingerData {
|
|
||||||
resource := k
|
|
||||||
|
|
||||||
// Remove leading acct: if present
|
|
||||||
if len(k) > 5 && resource[:5] == "acct:" {
|
|
||||||
resource = resource[5:]
|
|
||||||
}
|
|
||||||
|
|
||||||
// The key must be a URL or email address
|
|
||||||
if _, err := mail.ParseAddress(resource); err != nil {
|
|
||||||
if _, err := url.Parse(resource); err != nil {
|
|
||||||
return nil, fmt.Errorf("error parsing webfinger key (%s): %w", k, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Add acct: back to the key if it is an email address
|
|
||||||
resource = fmt.Sprintf("acct:%s", resource)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new webfinger
|
|
||||||
webfinger := &WebFinger{
|
|
||||||
Subject: resource,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the fields
|
|
||||||
for field, value := range v {
|
|
||||||
fieldUrn := field
|
|
||||||
|
|
||||||
// If the key is present in the URNs file, use the value
|
|
||||||
if _, ok := urnMap[field]; ok {
|
|
||||||
fieldUrn = urnMap[field]
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the value is a valid URI, add it to the links
|
|
||||||
if _, err := url.Parse(value); err == nil {
|
|
||||||
webfinger.Links = append(webfinger.Links, Link{
|
|
||||||
Rel: fieldUrn,
|
|
||||||
Href: value,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
// Otherwise add it to the properties
|
|
||||||
if webfinger.Properties == nil {
|
|
||||||
webfinger.Properties = make(map[string]string)
|
|
||||||
}
|
|
||||||
|
|
||||||
webfinger.Properties[fieldUrn] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the webfinger to the map
|
|
||||||
fingermap[resource] = webfinger
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("Webfinger map built successfully", slog.Int("number", len(fingermap)), slog.Any("data", fingermap))
|
|
||||||
|
|
||||||
return fingermap, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WebfingerHandler(_ *Config, webmap WebFingerMap) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
l := LoggerFromContext(ctx)
|
|
||||||
|
|
||||||
// Only handle GET requests
|
|
||||||
if r.Method != http.MethodGet {
|
|
||||||
l.Debug("Method not allowed")
|
|
||||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the query params
|
|
||||||
q := r.URL.Query()
|
|
||||||
|
|
||||||
// Get the resource
|
|
||||||
resource := q.Get("resource")
|
|
||||||
if resource == "" {
|
|
||||||
l.Debug("No resource provided")
|
|
||||||
http.Error(w, "No resource provided", http.StatusBadRequest)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get and validate resource
|
|
||||||
webfinger, ok := webmap[resource]
|
|
||||||
if !ok {
|
|
||||||
l.Debug("Resource not found")
|
|
||||||
http.Error(w, "Resource not found", http.StatusNotFound)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the content type
|
|
||||||
w.Header().Set("Content-Type", "application/jrd+json")
|
|
||||||
|
|
||||||
// Write the response
|
|
||||||
if err := json.NewEncoder(w).Encode(webfinger); err != nil {
|
|
||||||
l.Debug("Error encoding json")
|
|
||||||
http.Error(w, "Error encoding json", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Debug("Webfinger request successful")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func HealthCheckHandler(_ *Config) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type ResponseWrapper struct {
|
|
||||||
http.ResponseWriter
|
|
||||||
|
|
||||||
status int
|
|
||||||
}
|
|
||||||
|
|
||||||
func WrapResponseWriter(w http.ResponseWriter) *ResponseWrapper {
|
|
||||||
return &ResponseWrapper{w, 0}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ResponseWrapper) WriteHeader(code int) {
|
|
||||||
w.status = code
|
|
||||||
w.ResponseWriter.WriteHeader(code)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ResponseWrapper) Status() int {
|
|
||||||
return w.status
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ResponseWrapper) Write(b []byte) (int, error) {
|
|
||||||
if w.status == 0 {
|
|
||||||
w.status = http.StatusOK
|
|
||||||
}
|
|
||||||
|
|
||||||
size, err := w.ResponseWriter.Write(b)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error writing response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return size, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *ResponseWrapper) Unwrap() http.ResponseWriter {
|
|
||||||
return w.ResponseWriter
|
|
||||||
}
|
|
||||||
|
|
||||||
func LoggingMiddleware(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
l := LoggerFromContext(ctx)
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
// Wrap the response writer
|
|
||||||
wrapped := WrapResponseWriter(w)
|
|
||||||
|
|
||||||
// Call the next handler
|
|
||||||
next.ServeHTTP(wrapped, r)
|
|
||||||
|
|
||||||
status := wrapped.Status()
|
|
||||||
|
|
||||||
// Log the request
|
|
||||||
lg := l.With(
|
|
||||||
slog.String("method", r.Method),
|
|
||||||
slog.String("path", r.URL.Path),
|
|
||||||
slog.Int("status", status),
|
|
||||||
slog.String("remote", r.RemoteAddr),
|
|
||||||
slog.Duration("duration", time.Since(start)),
|
|
||||||
)
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case status >= http.StatusInternalServerError:
|
|
||||||
lg.Error("Server error")
|
|
||||||
case status >= http.StatusBadRequest:
|
|
||||||
lg.Info("Client error")
|
|
||||||
default:
|
|
||||||
lg.Info("Request completed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ReadTimeout is the maximum duration for reading the entire
|
|
||||||
// request, including the body.
|
|
||||||
ReadTimeout = 5 * time.Second
|
|
||||||
// WriteTimeout is the maximum duration before timing out
|
|
||||||
// writes of the response.
|
|
||||||
WriteTimeout = 10 * time.Second
|
|
||||||
// IdleTimeout is the maximum amount of time to wait for the
|
|
||||||
// next request when keep-alives are enabled.
|
|
||||||
IdleTimeout = 30 * time.Second
|
|
||||||
// ReadHeaderTimeout is the amount of time allowed to read
|
|
||||||
// request headers.
|
|
||||||
ReadHeaderTimeout = 2 * time.Second
|
|
||||||
// RequestTimeout is the maximum duration for the entire
|
|
||||||
// request.
|
|
||||||
RequestTimeout = 7 * 24 * time.Hour
|
|
||||||
)
|
|
||||||
|
|
||||||
func StartServer(ctx context.Context, cfg *Config, webmap WebFingerMap) error {
|
|
||||||
l := LoggerFromContext(ctx)
|
|
||||||
|
|
||||||
// Create the server mux
|
|
||||||
mux := http.NewServeMux()
|
|
||||||
mux.Handle("/.well-known/webfinger", WebfingerHandler(cfg, webmap))
|
|
||||||
mux.Handle("/healthz", HealthCheckHandler(cfg))
|
|
||||||
|
|
||||||
// Create a new server
|
|
||||||
srv := &http.Server{
|
|
||||||
Addr: net.JoinHostPort(cfg.Host, cfg.Port),
|
|
||||||
BaseContext: func(_ net.Listener) context.Context {
|
|
||||||
return ctx
|
|
||||||
},
|
|
||||||
Handler: LoggingMiddleware(
|
|
||||||
RecoveryHandler(
|
|
||||||
http.TimeoutHandler(mux, RequestTimeout, "request timed out"),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
ReadHeaderTimeout: ReadHeaderTimeout,
|
|
||||||
ReadTimeout: ReadTimeout,
|
|
||||||
WriteTimeout: WriteTimeout,
|
|
||||||
IdleTimeout: IdleTimeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the errorgroup that will manage the server execution
|
|
||||||
eg, egCtx := errgroup.WithContext(ctx)
|
|
||||||
|
|
||||||
// Start the server
|
|
||||||
eg.Go(func() error {
|
|
||||||
l.Info("Starting server", slog.String("addr", srv.Addr))
|
|
||||||
|
|
||||||
// Use the global context for the server
|
|
||||||
srv.BaseContext = func(_ net.Listener) context.Context {
|
|
||||||
return egCtx
|
|
||||||
}
|
|
||||||
|
|
||||||
return srv.ListenAndServe() //nolint:wrapcheck // We wrap the error in the errgroup
|
|
||||||
})
|
|
||||||
// Gracefully shutdown the server when the context is done
|
|
||||||
eg.Go(func() error {
|
|
||||||
// Wait for the context to be done
|
|
||||||
<-egCtx.Done()
|
|
||||||
|
|
||||||
l.Info("Shutting down server")
|
|
||||||
// Disable the cancel since we don't wan't to force
|
|
||||||
// the server to shutdown if the context is canceled.
|
|
||||||
noCancelCtx := context.WithoutCancel(egCtx)
|
|
||||||
|
|
||||||
return srv.Shutdown(noCancelCtx) //nolint:wrapcheck // We wrap the error in the errgroup
|
|
||||||
})
|
|
||||||
|
|
||||||
srv.RegisterOnShutdown(func() {
|
|
||||||
l.Info("Server shutdown complete")
|
|
||||||
})
|
|
||||||
|
|
||||||
// Ignore the error if the context was canceled
|
|
||||||
if err := eg.Wait(); err != nil && ctx.Err() == nil {
|
|
||||||
return fmt.Errorf("server exited with error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func RecoveryHandler(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
l := LoggerFromContext(ctx)
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
err := recover()
|
|
||||||
if err != nil {
|
|
||||||
l.Error("Panic", slog.Any("error", err))
|
|
||||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
60
main_test.go
60
main_test.go
|
@ -1,60 +0,0 @@
|
||||||
package main_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
finger "git.maronato.dev/maronato/finger"
|
|
||||||
)
|
|
||||||
|
|
||||||
func BenchmarkGetWebfinger(b *testing.B) {
|
|
||||||
ctx := context.Background()
|
|
||||||
cfg := &finger.Config{}
|
|
||||||
l := finger.NewLogger(cfg)
|
|
||||||
|
|
||||||
ctx = finger.WithLogger(ctx, l)
|
|
||||||
resource := "acct:user@example.com"
|
|
||||||
webmap := finger.WebFingerMap{
|
|
||||||
resource: {
|
|
||||||
Subject: resource,
|
|
||||||
Links: []finger.Link{
|
|
||||||
{
|
|
||||||
Rel: "http://webfinger.net/rel/avatar",
|
|
||||||
Href: "https://example.com/avatar.png",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Properties: map[string]string{
|
|
||||||
"example": "value",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"acct:other": {
|
|
||||||
Subject: "acct:other",
|
|
||||||
Links: []finger.Link{
|
|
||||||
{
|
|
||||||
Rel: "http://webfinger.net/rel/avatar",
|
|
||||||
Href: "https://example.com/avatar.png",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Properties: map[string]string{
|
|
||||||
"example": "value",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := finger.WebfingerHandler(&finger.Config{}, webmap)
|
|
||||||
|
|
||||||
r, _ := http.NewRequestWithContext(
|
|
||||||
ctx,
|
|
||||||
http.MethodGet,
|
|
||||||
fmt.Sprintf("/.well-known/webfinger?resource=%s", resource),
|
|
||||||
http.NoBody,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Reference in a new issue