diff --git a/internal/server/server.go b/internal/server/server.go index 987b034..1defdfc 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -21,7 +21,6 @@ package server import ( "context" "fmt" - "io" "net" "sync" "time" @@ -35,21 +34,6 @@ import ( "google.golang.org/grpc/reflection" ) -// nopWriteCloser is a no-op [io.WriteCloser] -type nopWriteCloser struct { - io.Writer -} - -// Close implements [io.Closer] -func (nwc nopWriteCloser) Close() error { - return nil -} - -// newNopWriteCloser creates a new nopWriteCloser -func newNopWriteCloser(w io.Writer) *nopWriteCloser { - return &nopWriteCloser{w} -} - // Session represents a session where a machine is attempting to receive // a private key from SubmitKey. type Session struct { @@ -69,7 +53,7 @@ type Server struct { // ses is a machine_id -> Session map ses map[string]*Session - sesMu sync.Mutex + sesMu sync.RWMutex pbgrpcv1.UnimplementedKlefkiServiceServer } @@ -100,18 +84,24 @@ func (s *Server) Run(ctx context.Context) error { } // GetTime implements the GetTime RPC -func (s *Server) GetTime(_ context.Context, req *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) { +func (s *Server) GetTime(_ context.Context, _ *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) { resp := &pbgrpcv1.GetTimeResponse{} resp.SetTime(time.Now().Format(time.RFC3339Nano)) return resp, nil } // SubmitKey implements the SubmitKey RPC -func (s *Server) SubmitKey(ctx context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) { +func (s *Server) SubmitKey(_ context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) { machineID := req.GetMachineId() + s.sesMu.RLock() if _, ok := s.ses[machineID]; !ok { + s.sesMu.RUnlock() return nil, fmt.Errorf("failed to find machine ID %q", machineID) } + s.sesMu.RUnlock() + + s.sesMu.Lock() + defer s.sesMu.Unlock() s.ses[machineID].EncKey = req.GetEncKey() return &pbgrpcv1.SubmitKeyResponse{}, nil @@ -127,6 +117,10 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr return nil, fmt.Errorf("failed to parsed signed at %q: %w", req.GetSignedAt(), err) } ts = ts.UTC() // Always operate with UTC time. + if time.Since(ts) < 5*time.Minute { + return nil, fmt.Errorf("signature has expired") + } + sig := req.GetSignature() machine, err := s.db.Machine.Get(ctx, req.GetMachineId()) @@ -152,6 +146,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr return nil, fmt.Errorf("failed to add instance public key to encryptor: %w", err) } + s.sesMu.Lock() + defer s.sesMu.Unlock() + // Track the last time the machine asked for a key. This is what backs // the sessions api if _, ok := s.ses[machine.ID]; !ok { @@ -172,6 +169,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr // ListSessions implements the ListSessions RPC. func (s *Server) ListSessions(ctx context.Context, _ *pbgrpcv1.ListSessionsRequest) (*pbgrpcv1.ListSessionsResponse, error) { + s.sesMu.RLock() + defer s.sesMu.RUnlock() + resp := &pbgrpcv1.ListSessionsResponse{} grpcMachines := make([]*pbgrpcv1.Machine, 0, len(s.ses))