linter & use mutex
All checks were successful
tests / go test (push) Successful in 8s
tests / golangci-lint (push) Successful in 14s

This commit is contained in:
Jared Allard 2025-03-06 21:39:37 -08:00
parent affa5919f7
commit c86cc3da61
Signed by: jaredallard
SSH key fingerprint: SHA256:wyRyyv28jBYw8Yp/oABNPUYvbGd6hyZj23XVXEm5G/U

View file

@ -21,7 +21,6 @@ package server
import (
"context"
"fmt"
"io"
"net"
"sync"
"time"
@ -35,21 +34,6 @@ import (
"google.golang.org/grpc/reflection"
)
// nopWriteCloser is a no-op [io.WriteCloser]
type nopWriteCloser struct {
io.Writer
}
// Close implements [io.Closer]
func (nwc nopWriteCloser) Close() error {
return nil
}
// newNopWriteCloser creates a new nopWriteCloser
func newNopWriteCloser(w io.Writer) *nopWriteCloser {
return &nopWriteCloser{w}
}
// Session represents a session where a machine is attempting to receive
// a private key from SubmitKey.
type Session struct {
@ -69,7 +53,7 @@ type Server struct {
// ses is a machine_id -> Session map
ses map[string]*Session
sesMu sync.Mutex
sesMu sync.RWMutex
pbgrpcv1.UnimplementedKlefkiServiceServer
}
@ -100,18 +84,24 @@ func (s *Server) Run(ctx context.Context) error {
}
// GetTime implements the GetTime RPC
func (s *Server) GetTime(_ context.Context, req *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) {
func (s *Server) GetTime(_ context.Context, _ *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) {
resp := &pbgrpcv1.GetTimeResponse{}
resp.SetTime(time.Now().Format(time.RFC3339Nano))
return resp, nil
}
// SubmitKey implements the SubmitKey RPC
func (s *Server) SubmitKey(ctx context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) {
func (s *Server) SubmitKey(_ context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) {
machineID := req.GetMachineId()
s.sesMu.RLock()
if _, ok := s.ses[machineID]; !ok {
s.sesMu.RUnlock()
return nil, fmt.Errorf("failed to find machine ID %q", machineID)
}
s.sesMu.RUnlock()
s.sesMu.Lock()
defer s.sesMu.Unlock()
s.ses[machineID].EncKey = req.GetEncKey()
return &pbgrpcv1.SubmitKeyResponse{}, nil
@ -127,6 +117,10 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
return nil, fmt.Errorf("failed to parsed signed at %q: %w", req.GetSignedAt(), err)
}
ts = ts.UTC() // Always operate with UTC time.
if time.Since(ts) < 5*time.Minute {
return nil, fmt.Errorf("signature has expired")
}
sig := req.GetSignature()
machine, err := s.db.Machine.Get(ctx, req.GetMachineId())
@ -152,6 +146,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
return nil, fmt.Errorf("failed to add instance public key to encryptor: %w", err)
}
s.sesMu.Lock()
defer s.sesMu.Unlock()
// Track the last time the machine asked for a key. This is what backs
// the sessions api
if _, ok := s.ses[machine.ID]; !ok {
@ -172,6 +169,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
// ListSessions implements the ListSessions RPC.
func (s *Server) ListSessions(ctx context.Context, _ *pbgrpcv1.ListSessionsRequest) (*pbgrpcv1.ListSessionsResponse, error) {
s.sesMu.RLock()
defer s.sesMu.RUnlock()
resp := &pbgrpcv1.ListSessionsResponse{}
grpcMachines := make([]*pbgrpcv1.Machine, 0, len(s.ses))