linter & use mutex
All checks were successful
tests / go test (push) Successful in 8s
tests / golangci-lint (push) Successful in 14s

This commit is contained in:
Jared Allard 2025-03-06 21:39:37 -08:00
parent affa5919f7
commit c86cc3da61
Signed by: jaredallard
SSH key fingerprint: SHA256:wyRyyv28jBYw8Yp/oABNPUYvbGd6hyZj23XVXEm5G/U

View file

@ -21,7 +21,6 @@ package server
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"time" "time"
@ -35,21 +34,6 @@ import (
"google.golang.org/grpc/reflection" "google.golang.org/grpc/reflection"
) )
// nopWriteCloser is a no-op [io.WriteCloser]
type nopWriteCloser struct {
io.Writer
}
// Close implements [io.Closer]
func (nwc nopWriteCloser) Close() error {
return nil
}
// newNopWriteCloser creates a new nopWriteCloser
func newNopWriteCloser(w io.Writer) *nopWriteCloser {
return &nopWriteCloser{w}
}
// Session represents a session where a machine is attempting to receive // Session represents a session where a machine is attempting to receive
// a private key from SubmitKey. // a private key from SubmitKey.
type Session struct { type Session struct {
@ -69,7 +53,7 @@ type Server struct {
// ses is a machine_id -> Session map // ses is a machine_id -> Session map
ses map[string]*Session ses map[string]*Session
sesMu sync.Mutex sesMu sync.RWMutex
pbgrpcv1.UnimplementedKlefkiServiceServer pbgrpcv1.UnimplementedKlefkiServiceServer
} }
@ -100,18 +84,24 @@ func (s *Server) Run(ctx context.Context) error {
} }
// GetTime implements the GetTime RPC // GetTime implements the GetTime RPC
func (s *Server) GetTime(_ context.Context, req *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) { func (s *Server) GetTime(_ context.Context, _ *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) {
resp := &pbgrpcv1.GetTimeResponse{} resp := &pbgrpcv1.GetTimeResponse{}
resp.SetTime(time.Now().Format(time.RFC3339Nano)) resp.SetTime(time.Now().Format(time.RFC3339Nano))
return resp, nil return resp, nil
} }
// SubmitKey implements the SubmitKey RPC // SubmitKey implements the SubmitKey RPC
func (s *Server) SubmitKey(ctx context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) { func (s *Server) SubmitKey(_ context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) {
machineID := req.GetMachineId() machineID := req.GetMachineId()
s.sesMu.RLock()
if _, ok := s.ses[machineID]; !ok { if _, ok := s.ses[machineID]; !ok {
s.sesMu.RUnlock()
return nil, fmt.Errorf("failed to find machine ID %q", machineID) return nil, fmt.Errorf("failed to find machine ID %q", machineID)
} }
s.sesMu.RUnlock()
s.sesMu.Lock()
defer s.sesMu.Unlock()
s.ses[machineID].EncKey = req.GetEncKey() s.ses[machineID].EncKey = req.GetEncKey()
return &pbgrpcv1.SubmitKeyResponse{}, nil return &pbgrpcv1.SubmitKeyResponse{}, nil
@ -127,6 +117,10 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
return nil, fmt.Errorf("failed to parsed signed at %q: %w", req.GetSignedAt(), err) return nil, fmt.Errorf("failed to parsed signed at %q: %w", req.GetSignedAt(), err)
} }
ts = ts.UTC() // Always operate with UTC time. ts = ts.UTC() // Always operate with UTC time.
if time.Since(ts) < 5*time.Minute {
return nil, fmt.Errorf("signature has expired")
}
sig := req.GetSignature() sig := req.GetSignature()
machine, err := s.db.Machine.Get(ctx, req.GetMachineId()) machine, err := s.db.Machine.Get(ctx, req.GetMachineId())
@ -152,6 +146,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
return nil, fmt.Errorf("failed to add instance public key to encryptor: %w", err) return nil, fmt.Errorf("failed to add instance public key to encryptor: %w", err)
} }
s.sesMu.Lock()
defer s.sesMu.Unlock()
// Track the last time the machine asked for a key. This is what backs // Track the last time the machine asked for a key. This is what backs
// the sessions api // the sessions api
if _, ok := s.ses[machine.ID]; !ok { if _, ok := s.ses[machine.ID]; !ok {
@ -172,6 +169,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
// ListSessions implements the ListSessions RPC. // ListSessions implements the ListSessions RPC.
func (s *Server) ListSessions(ctx context.Context, _ *pbgrpcv1.ListSessionsRequest) (*pbgrpcv1.ListSessionsResponse, error) { func (s *Server) ListSessions(ctx context.Context, _ *pbgrpcv1.ListSessionsRequest) (*pbgrpcv1.ListSessionsResponse, error) {
s.sesMu.RLock()
defer s.sesMu.RUnlock()
resp := &pbgrpcv1.ListSessionsResponse{} resp := &pbgrpcv1.ListSessionsResponse{}
grpcMachines := make([]*pbgrpcv1.Machine, 0, len(s.ses)) grpcMachines := make([]*pbgrpcv1.Machine, 0, len(s.ses))