linter & use mutex
This commit is contained in:
parent
affa5919f7
commit
c86cc3da61
1 changed files with 19 additions and 19 deletions
|
@ -21,7 +21,6 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -35,21 +34,6 @@ import (
|
|||
"google.golang.org/grpc/reflection"
|
||||
)
|
||||
|
||||
// nopWriteCloser is a no-op [io.WriteCloser]
|
||||
type nopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
// Close implements [io.Closer]
|
||||
func (nwc nopWriteCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// newNopWriteCloser creates a new nopWriteCloser
|
||||
func newNopWriteCloser(w io.Writer) *nopWriteCloser {
|
||||
return &nopWriteCloser{w}
|
||||
}
|
||||
|
||||
// Session represents a session where a machine is attempting to receive
|
||||
// a private key from SubmitKey.
|
||||
type Session struct {
|
||||
|
@ -69,7 +53,7 @@ type Server struct {
|
|||
|
||||
// ses is a machine_id -> Session map
|
||||
ses map[string]*Session
|
||||
sesMu sync.Mutex
|
||||
sesMu sync.RWMutex
|
||||
|
||||
pbgrpcv1.UnimplementedKlefkiServiceServer
|
||||
}
|
||||
|
@ -100,18 +84,24 @@ func (s *Server) Run(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// GetTime implements the GetTime RPC
|
||||
func (s *Server) GetTime(_ context.Context, req *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) {
|
||||
func (s *Server) GetTime(_ context.Context, _ *pbgrpcv1.GetTimeRequest) (*pbgrpcv1.GetTimeResponse, error) {
|
||||
resp := &pbgrpcv1.GetTimeResponse{}
|
||||
resp.SetTime(time.Now().Format(time.RFC3339Nano))
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// SubmitKey implements the SubmitKey RPC
|
||||
func (s *Server) SubmitKey(ctx context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) {
|
||||
func (s *Server) SubmitKey(_ context.Context, req *pbgrpcv1.SubmitKeyRequest) (*pbgrpcv1.SubmitKeyResponse, error) {
|
||||
machineID := req.GetMachineId()
|
||||
s.sesMu.RLock()
|
||||
if _, ok := s.ses[machineID]; !ok {
|
||||
s.sesMu.RUnlock()
|
||||
return nil, fmt.Errorf("failed to find machine ID %q", machineID)
|
||||
}
|
||||
s.sesMu.RUnlock()
|
||||
|
||||
s.sesMu.Lock()
|
||||
defer s.sesMu.Unlock()
|
||||
|
||||
s.ses[machineID].EncKey = req.GetEncKey()
|
||||
return &pbgrpcv1.SubmitKeyResponse{}, nil
|
||||
|
@ -127,6 +117,10 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
|
|||
return nil, fmt.Errorf("failed to parsed signed at %q: %w", req.GetSignedAt(), err)
|
||||
}
|
||||
ts = ts.UTC() // Always operate with UTC time.
|
||||
if time.Since(ts) < 5*time.Minute {
|
||||
return nil, fmt.Errorf("signature has expired")
|
||||
}
|
||||
|
||||
sig := req.GetSignature()
|
||||
|
||||
machine, err := s.db.Machine.Get(ctx, req.GetMachineId())
|
||||
|
@ -152,6 +146,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
|
|||
return nil, fmt.Errorf("failed to add instance public key to encryptor: %w", err)
|
||||
}
|
||||
|
||||
s.sesMu.Lock()
|
||||
defer s.sesMu.Unlock()
|
||||
|
||||
// Track the last time the machine asked for a key. This is what backs
|
||||
// the sessions api
|
||||
if _, ok := s.ses[machine.ID]; !ok {
|
||||
|
@ -172,6 +169,9 @@ func (s *Server) GetKey(ctx context.Context, req *pbgrpcv1.GetKeyRequest) (*pbgr
|
|||
|
||||
// ListSessions implements the ListSessions RPC.
|
||||
func (s *Server) ListSessions(ctx context.Context, _ *pbgrpcv1.ListSessionsRequest) (*pbgrpcv1.ListSessionsResponse, error) {
|
||||
s.sesMu.RLock()
|
||||
defer s.sesMu.RUnlock()
|
||||
|
||||
resp := &pbgrpcv1.ListSessionsResponse{}
|
||||
|
||||
grpcMachines := make([]*pbgrpcv1.Machine, 0, len(s.ses))
|
||||
|
|
Loading…
Add table
Reference in a new issue