// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package test import ( "bytes" "context" "errors" "fmt" "io" "sync" "github.com/pion/transport/v3/netctx" ) var errByteSequenceChanged = errors.New("byte sequence changed") // Options represents the configuration of the stress test type Options struct { MsgSize int MsgCount int } // Stress enables stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func Stress(ca io.Writer, cb io.Reader, opt Options) error { return StressContext(context.Background(), &wrappedWriter{ca}, &wrappedReader{cb}, opt) } // StressContext enables stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func StressContext(ctx context.Context, ca netctx.Writer, cb netctx.Reader, opt Options) error { bufs := make(chan []byte, opt.MsgCount) errCh := make(chan error) // Write go func() { err := write(ctx, ca, bufs, opt) errCh <- err close(bufs) }() // Read go func() { result := make([]byte, opt.MsgSize) for original := range bufs { err := read(ctx, cb, original, result) if err != nil { errCh <- err } } close(errCh) }() return FlattenErrs(GatherErrs(errCh)) } func read(ctx context.Context, r netctx.Reader, original, result []byte) error { n, err := r.ReadContext(ctx, result) if err != nil { return err } if !bytes.Equal(original, result[:n]) { return fmt.Errorf("%w %#v != %#v", errByteSequenceChanged, original, result) } return nil } // StressDuplex enables duplex stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func StressDuplex(ca io.ReadWriter, cb io.ReadWriter, opt Options) error { return StressDuplexContext(context.Background(), &wrappedReadWriter{ca}, &wrappedReadWriter{cb}, opt) } // StressDuplexContext enables duplex stress testing of a io.ReadWriter. // It checks that packets are received correctly and in order. func StressDuplexContext(ctx context.Context, ca netctx.ReadWriter, cb netctx.ReadWriter, opt Options) error { errCh := make(chan error) var wg sync.WaitGroup wg.Add(2) go func() { defer wg.Done() errCh <- StressContext(ctx, ca, cb, opt) }() go func() { defer wg.Done() errCh <- StressContext(ctx, cb, ca, opt) }() go func() { wg.Wait() close(errCh) }() return FlattenErrs(GatherErrs(errCh)) } func write(ctx context.Context, c netctx.Writer, bufs chan []byte, opt Options) error { randomizer := initRand() for i := 0; i < opt.MsgCount; i++ { buf, err := randomizer.randBuf(opt.MsgSize) if err != nil { return err } bufs <- buf if _, err = c.WriteContext(ctx, buf); err != nil { return err } } return nil }