package vpnruntime

import (
	"context"
	"errors"
	"io"
	"sync"
	"testing"
	"time"
)

// recordingGatewayTransport is a test double that stores a deep copy of every
// packet batch handed to SendGatewayPacketBatch so tests can inspect upload
// order and batching afterwards.
type recordingGatewayTransport struct {
	mu      sync.Mutex
	batches [][][]byte
}

// SendGatewayPacketBatch records a copy of packets and always succeeds.
func (t *recordingGatewayTransport) SendGatewayPacketBatch(ctx context.Context, packets [][]byte) error {
	// Copy each packet so later reuse of the caller's buffers cannot alter
	// what was recorded.
	copied := make([][]byte, len(packets))
	for i, packet := range packets {
		copied[i] = append([]byte(nil), packet...)
	}

	t.mu.Lock()
	t.batches = append(t.batches, copied)
	t.mu.Unlock()
	return nil
}

// ReceiveGatewayPacketBatch never produces packets; it returns immediately
// with ctx.Err() (nil while ctx is still live) and ignores timeout.
func (t *recordingGatewayTransport) ReceiveGatewayPacketBatch(ctx context.Context, timeout time.Duration) ([][]byte, error) {
	return nil, ctx.Err()
}

// firstBatch returns the first recorded batch, or nil if nothing has been
// sent yet.
func (t *recordingGatewayTransport) firstBatch() [][]byte {
	t.mu.Lock()
	defer t.mu.Unlock()

	if len(t.batches) == 0 {
		return nil
	}
	return t.batches[0]
}

// closingGatewayTransport is a test double that always has a packet to
// deliver and signals on closed when Close is called.
type closingGatewayTransport struct {
	closed chan struct{}
}

// SendGatewayPacketBatch discards the batch and reports success.
func (t *closingGatewayTransport) SendGatewayPacketBatch(context.Context, [][]byte) error {
	return nil
}

// ReceiveGatewayPacketBatch immediately returns a single fixed packet.
func (t *closingGatewayTransport) ReceiveGatewayPacketBatch(context.Context, time.Duration) ([][]byte, error) {
	return [][]byte{[]byte("packet")}, nil
}

// Close closes the closed channel. It is not guarded, so a second call
// panics.
func (t *closingGatewayTransport) Close() error {
	close(t.closed)
	return nil
}

// failingWriteTun is a tun test double whose Write always fails and whose
// Read blocks until the device is closed.
type failingWriteTun struct {
	closed chan struct{}
}

// Read blocks until Close has been called, then reports io.EOF.
func (t failingWriteTun) Read([]byte) (int, error) {
	<-t.closed
	return 0, io.EOF
}

// Write always fails with the error "write failed".
func (t failingWriteTun) Write([]byte) (int, error) {
	return 0, errors.New("write failed")
}

// Close closes the closed channel unless it is already closed, so repeated
// sequential calls are safe.
func (t failingWriteTun) Close() error {
	select {
	case <-t.closed:
	default:
		close(t.closed)
	}
	return nil
}

// waitForFirstBatch polls transport until it has recorded at least one batch
// and returns that batch. It fails the test with timeoutMsg if nothing is
// recorded within one second.
//
// The first recorded batch never changes, so a batch whose size differs from
// wantLen is reported immediately instead of being polled until the timeout.
func waitForFirstBatch(t *testing.T, transport *recordingGatewayTransport, wantLen int, timeoutMsg string) [][]byte {
	t.Helper()

	deadline := time.After(time.Second)
	for {
		if batch := transport.firstBatch(); batch != nil {
			if len(batch) != wantLen {
				t.Fatalf("first uploaded batch has %d packets, want %d: %#v", len(batch), wantLen, batch)
			}
			return batch
		}

		select {
		case <-deadline:
			t.Fatal(timeoutMsg)
		default:
			time.Sleep(time.Millisecond)
		}
	}
}

// TestGatewayRunClosesPacketTransportOnRuntimeError runs the gateway with a
// transport that always delivers a packet and a tun whose Write fails. It
// expects run to return the write error and the transport to be closed.
func TestGatewayRunClosesPacketTransportOnRuntimeError(t *testing.T) {
	transport := &closingGatewayTransport{closed: make(chan struct{})}
	gateway := &Gateway{
		Transport:       transport,
		VPNConnectionID: "vpn-1",
		PollTimeout:     time.Millisecond,
	}

	err := gateway.run(context.Background(), failingWriteTun{closed: make(chan struct{})})
	if err == nil || err.Error() != "write failed" {
		t.Fatalf("run error = %v, want write failed", err)
	}

	select {
	case <-transport.closed:
	case <-time.After(time.Second):
		t.Fatal("packet transport was not closed")
	}
}

// TestGatewayUploadPrioritizesTCPControlPackets queues one normal and one
// control packet before the uploader starts and expects the first uploaded
// batch to contain only the control packet.
func TestGatewayUploadPrioritizesTCPControlPackets(t *testing.T) {
	transport := &recordingGatewayTransport{}
	gateway := &Gateway{Transport: transport, VPNConnectionID: "vpn-1"}
	priorityPackets := make(chan []byte, 1)
	packets := make(chan []byte, 1)

	normal := testIPv4TCPPacket([4]byte{101, 32, 118, 25}, [4]byte{10, 77, 0, 2}, 443, 37566)
	priority := testIPv4TCPPacket([4]byte{192, 168, 200, 95}, [4]byte{10, 77, 0, 2}, 3389, 51000)
	// 0x12 at offset 33 is what TestIsTCPControlPacket treats as a SYN-ACK;
	// presumably the TCP flags byte after a 20-byte IPv4 header.
	priority[33] = 0x12

	packets <- normal
	priorityPackets <- priority

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() {
		done <- gateway.uploadGatewayPackets(ctx, priorityPackets, packets)
	}()
	// Stop the uploader and wait for it so the goroutine does not outlive
	// the test.
	defer func() {
		cancel()
		<-done
	}()

	batch := waitForFirstBatch(t, transport, 1, "timed out waiting for first gateway upload batch")
	if string(batch[0]) != string(priority) {
		t.Fatalf("first uploaded packet = %#v, want priority packet", batch[0])
	}
}

// TestGatewayUploadPreemptsPendingNormalBatchForTCPControlPackets feeds a
// normal packet to a running uploader, then a control packet shortly after,
// and expects the control packet to be uploaded first, on its own.
func TestGatewayUploadPreemptsPendingNormalBatchForTCPControlPackets(t *testing.T) {
	transport := &recordingGatewayTransport{}
	gateway := &Gateway{Transport: transport, VPNConnectionID: "vpn-1"}
	priorityPackets := make(chan []byte, 1)
	packets := make(chan []byte, 1)

	normal := testIPv4TCPPacket([4]byte{101, 32, 118, 25}, [4]byte{10, 77, 0, 2}, 443, 37566)
	priority := testIPv4TCPPacket([4]byte{192, 168, 200, 95}, [4]byte{10, 77, 0, 2}, 3389, 51000)
	priority[33] = 0x12

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() {
		done <- gateway.uploadGatewayPackets(ctx, priorityPackets, packets)
	}()
	defer func() {
		cancel()
		<-done
	}()

	packets <- normal
	// NOTE(review): the short pause presumably lets the uploader pick up the
	// normal packet so it is pending when the control packet arrives; this
	// is timing-based rather than synchronized.
	time.Sleep(time.Millisecond)
	priorityPackets <- priority

	batch := waitForFirstBatch(t, transport, 1, "timed out waiting for preempted priority upload batch")
	if string(batch[0]) != string(priority) {
		t.Fatalf("first uploaded packet = %#v, want priority packet before pending normal batch", batch[0])
	}
}

// TestGatewayUploadMicroBatchesTCPControlPackets queues two control packets
// before the uploader starts and expects both to arrive together, in order,
// in the first uploaded batch.
func TestGatewayUploadMicroBatchesTCPControlPackets(t *testing.T) {
	transport := &recordingGatewayTransport{}
	gateway := &Gateway{Transport: transport, VPNConnectionID: "vpn-1"}
	priorityPackets := make(chan []byte, 2)
	packets := make(chan []byte, 1)

	first := testIPv4TCPPacket([4]byte{192, 168, 200, 95}, [4]byte{10, 77, 0, 2}, 3389, 51000)
	first[33] = 0x12
	second := testIPv4TCPPacket([4]byte{188, 40, 167, 82}, [4]byte{10, 77, 0, 2}, 80, 51002)
	second[33] = 0x12

	priorityPackets <- first
	priorityPackets <- second

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() {
		done <- gateway.uploadGatewayPackets(ctx, priorityPackets, packets)
	}()
	defer func() {
		cancel()
		<-done
	}()

	batch := waitForFirstBatch(t, transport, 2, "timed out waiting for priority microbatch")
	if string(batch[0]) != string(first) || string(batch[1]) != string(second) {
		t.Fatalf("priority batch = %#v, want both control packets in order", batch)
	}
}

// TestIsTCPControlPacket checks isTCPControlPacket against a packet with no
// control flags, the same packet marked as a SYN-ACK, and the same bytes
// relabelled as UDP.
func TestIsTCPControlPacket(t *testing.T) {
	packet := testIPv4TCPPacket([4]byte{192, 168, 200, 95}, [4]byte{10, 77, 0, 2}, 3389, 51000)
	if isTCPControlPacket(packet) {
		t.Fatal("packet without control flags was classified as control")
	}

	packet[33] = 0x12
	if !isTCPControlPacket(packet) {
		t.Fatal("tcp syn-ack was not classified as control")
	}

	// Byte 9 of an IPv4 header is the protocol field; 17 is UDP.
	packet[9] = 17
	if isTCPControlPacket(packet) {
		t.Fatal("udp packet was classified as tcp control")
	}
}