diff --git a/application_defined.go b/application_defined.go index 840d0d6..86ae1f6 100644 --- a/application_defined.go +++ b/application_defined.go @@ -88,7 +88,7 @@ func (a *ApplicationDefined) Unmarshal(rawPacket []byte) error { return errPacketTooShort } - if int(header.Length+1)*4 != len(rawPacket) { + if (int(header.Length)+1)*4 != len(rawPacket) { return errAppDefinedInvalidLength } diff --git a/application_defined_test.go b/application_defined_test.go index 2b013de..06cc81b 100644 --- a/application_defined_test.go +++ b/application_defined_test.go @@ -246,3 +246,18 @@ func TestTApplicationPacketMarshal(t *testing.T) { assert.Equalf(t, marshalSize, len(rawPacket), "MarshalSize %q", test.Name) } } + +func TestTApplicationPacketUnmarshalMaxLength(t *testing.T) { + rawPacket := make([]byte, 4*(0xFFFF+1)) + rawPacket[0] = 0x80 + rawPacket[1] = 0xcc + rawPacket[2] = 0xff + rawPacket[3] = 0xff + copy(rawPacket[8:12], []byte("NAME")) + + var packet ApplicationDefined + err := packet.Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Equal(t, "NAME", packet.Name) + assert.Len(t, packet.Data, len(rawPacket)-12) +} diff --git a/full_intra_request.go b/full_intra_request.go index 1e6333f..2dd6655 100644 --- a/full_intra_request.go +++ b/full_intra_request.go @@ -60,7 +60,7 @@ func (p *FullIntraRequest) Unmarshal(rawPacket []byte) error { return err } - if len(rawPacket) < (headerLength + int(4*header.Length)) { + if len(rawPacket) < (headerLength + 4*int(header.Length)) { return errPacketTooShort } @@ -69,16 +69,17 @@ func (p *FullIntraRequest) Unmarshal(rawPacket []byte) error { } // The FCI field MUST contain one or more FIR entries - if 4*header.Length-firOffset <= 0 || (4*header.Length)%8 != 0 { + if 4*int(header.Length)-firOffset <= 0 || (4*int(header.Length))%8 != 0 { return errBadLength } p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) - for i := headerLength + firOffset; i < (headerLength + int(header.Length*4)); i += 8 { + for i := headerLength + firOffset; i < (headerLength + 4*int(header.Length)); i += 8 { + entry := rawPacket[i : i+8] p.FIR = append(p.FIR, FIREntry{ - binary.BigEndian.Uint32(rawPacket[i:]), - rawPacket[i+4], + binary.BigEndian.Uint32(entry), + entry[4], }) } diff --git a/full_intra_request_test.go b/full_intra_request_test.go index 0abf3bc..02b8615 100644 --- a/full_intra_request_test.go +++ b/full_intra_request_test.go @@ -229,3 +229,16 @@ func TestFullIntraRequestUnmarshalHeader(t *testing.T) { assert.Equalf(t, test.Want, fir.Header(), "Unmarshal header %q rr mismatch", test.Name) } } + +func TestFullIntraRequestUnmarshalMaxLength(t *testing.T) { + rawPacket := make([]byte, headerLength+4*0xFFFE) + rawPacket[0] = 0x84 + rawPacket[1] = 0xce + rawPacket[2] = 0xff + rawPacket[3] = 0xfe + + var fir FullIntraRequest + err := fir.Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Len(t, fir.FIR, (4*0xFFFE-firOffset)/8) +} diff --git a/goodbye_test.go b/goodbye_test.go index e7c438a..369fc63 100644 --- a/goodbye_test.go +++ b/goodbye_test.go @@ -204,3 +204,23 @@ func TestGoodbyeRoundTrip(t *testing.T) { assert.Equalf(t, test.Bye, bye, "%q bye round trip mismatch", test.Name) } } + +func TestGoodbyeRoundTripMaxFieldSizes(t *testing.T) { + sources := make([]uint32, countMax) + for i := range sources { + sources[i] = uint32(i + 1) + } + + bye := Goodbye{ + Sources: sources, + Reason: strings.Repeat("x", sdesMaxOctetCount), + } + + data, err := bye.Marshal() + assert.NoError(t, err) + assert.Len(t, data, bye.MarshalSize()) + + var decoded Goodbye + assert.NoError(t, decoded.Unmarshal(data)) + assert.Equal(t, bye, decoded) +} diff --git a/packet.go b/packet.go index ca43964..447120b 100644 --- a/packet.go +++ b/packet.go @@ -68,7 +68,7 @@ func unmarshal(rawData []byte) (packet Packet, bytesprocessed int, err error) { return nil, 0, err } - bytesprocessed = int(header.Length+1) * 4 + bytesprocessed = (int(header.Length) + 1) * 4 if bytesprocessed > len(rawData) { return nil, 0, errPacketTooShort } diff --git a/packet_test.go b/packet_test.go index 94db90f..ef7a8b8 100644 --- a/packet_test.go +++ b/packet_test.go @@ -137,3 +137,19 @@ func TestInvalidHeaderLength(t *testing.T) { _, err := Unmarshal(invalidPacket) assert.ErrorIs(t, err, errPacketTooShort) } + +func TestUnmarshalMaxLengthRawPacket(t *testing.T) { + rawPacket := make([]byte, 4*(0xFFFF+1)) + rawPacket[0] = 0x80 + rawPacket[1] = 0xfa + rawPacket[2] = 0xff + rawPacket[3] = 0xff + + packets, err := Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Len(t, packets, 1) + + parsed, ok := packets[0].(*RawPacket) + assert.True(t, ok) + assert.Equal(t, rawPacket, []byte(*parsed)) +} diff --git a/receiver_report_test.go b/receiver_report_test.go index e65c69a..25d25a5 100644 --- a/receiver_report_test.go +++ b/receiver_report_test.go @@ -259,3 +259,17 @@ func TestReceiverReportRoundTrip(t *testing.T) { assert.Equalf(t, test.Report, decoded, "%s rr round trip mismatch", test.Name) } } + +func TestReceiverReportUnmarshalMaxLength(t *testing.T) { + rawPacket := make([]byte, 4*(0xFFFF+1)) + rawPacket[0] = 0x80 | countMax + rawPacket[1] = 0xc9 + rawPacket[2] = 0xff + rawPacket[3] = 0xff + + var rr ReceiverReport + err := rr.Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Len(t, rr.Reports, countMax) + assert.Len(t, rr.ProfileExtensions, len(rawPacket)-rrReportOffset-countMax*receptionReportLength) +} diff --git a/rfc8888.go b/rfc8888.go index dad6b7f..2719551 100644 --- a/rfc8888.go +++ b/rfc8888.go @@ -145,7 +145,7 @@ func (b CCFeedbackReport) Marshal() ([]byte, error) { if err != nil { return nil, err } - length := 4 * (header.Length + 1) + length := 4 * (int(header.Length) + 1) buf := make([]byte, length) copy(buf[:headerLength], headerBuf) binary.BigEndian.PutUint32(buf[headerLength:], b.SenderSSRC) diff --git a/rfc8888_test.go b/rfc8888_test.go index 9324db0..08f6a7c 100644 --- a/rfc8888_test.go +++ b/rfc8888_test.go @@ -440,3 +440,28 @@ func TestCCFeedbackOverflow(t *testing.T) { }, bytes.Repeat([]byte{0, 0}, 0x7FFF)...)) assert.ErrorIs(t, err, errReportBlockLength) } + +func TestCCFeedbackReportMarshalMaxLength(t *testing.T) { + report := CCFeedbackReport{ + SenderSSRC: 1, + ReportTimestamp: 2, + ReportBlocks: make([]CCFeedbackReportBlock, 8), + } + + for i := range 7 { + report.ReportBlocks[i] = CCFeedbackReportBlock{ + MediaSSRC: uint32(i + 1), + MetricBlocks: make([]CCFeedbackMetricBlock, maxMetricBlocks), + } + } + report.ReportBlocks[7] = CCFeedbackReportBlock{ + MediaSSRC: 8, + MetricBlocks: make([]CCFeedbackMetricBlock, 16346), + } + + buf, err := report.Marshal() + assert.NoError(t, err) + assert.Len(t, buf, 4*(0xFFFF+1)) + assert.Equal(t, []byte{0x8b, 0xcd, 0xff, 0xff}, buf[:4]) + assert.Equal(t, []byte{0x00, 0x00, 0x00, 0x02}, buf[len(buf)-4:]) +} diff --git a/sender_report_test.go b/sender_report_test.go index 43ad144..be93a9a 100644 --- a/sender_report_test.go +++ b/sender_report_test.go @@ -4,6 +4,7 @@ package rtcp import ( + "bytes" "slices" "testing" @@ -279,3 +280,29 @@ func TestSenderReportRoundTrip(t *testing.T) { assert.Equalf(t, test.Report, decoded, "%q sr round trip", test.Name) } } + +func TestSenderReportRoundTripMaxLength(t *testing.T) { + report := SenderReport{ + SSRC: 1, + NTPTime: 2, + RTPTime: 3, + PacketCount: 4, + OctetCount: 5, + Reports: make([]ReceptionReport, countMax), + } + + for i := range report.Reports { + report.Reports[i] = ReceptionReport{SSRC: uint32(i + 1)} + } + + baseSize := report.MarshalSize() + report.ProfileExtensions = bytes.Repeat([]byte{0xab}, 4*(0xFFFF+1)-baseSize) + + data, err := report.Marshal() + assert.NoError(t, err) + assert.Len(t, data, 4*(0xFFFF+1)) + + var decoded SenderReport + assert.NoError(t, decoded.Unmarshal(data)) + assert.Equal(t, report, decoded) +} diff --git a/slice_loss_indication.go b/slice_loss_indication.go index 176b020..66eda54 100644 --- a/slice_loss_indication.go +++ b/slice_loss_indication.go @@ -72,7 +72,7 @@ func (p *SliceLossIndication) Unmarshal(rawPacket []byte) error { return err } - if len(rawPacket) < (headerLength + int(4*header.Length)) { + if len(rawPacket) < (headerLength + 4*int(header.Length)) { return errPacketTooShort } @@ -82,7 +82,7 @@ func (p *SliceLossIndication) Unmarshal(rawPacket []byte) error { p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) - for i := headerLength + sliOffset; i < (headerLength + int(header.Length*4)); i += 4 { + for i := headerLength + sliOffset; i < (headerLength + 4*int(header.Length)); i += 4 { sli := binary.BigEndian.Uint32(rawPacket[i:]) p.SLI = append(p.SLI, SLIEntry{ First: uint16((sli >> 19) & 0x1FFF), //nolint:gosec // G115 diff --git a/slice_loss_indication_test.go b/slice_loss_indication_test.go index c14ef86..817fe33 100644 --- a/slice_loss_indication_test.go +++ b/slice_loss_indication_test.go @@ -111,3 +111,16 @@ func TestSliceLossIndicationRoundTrip(t *testing.T) { assert.Equalf(t, test.Report, decoded, "%q sli round trip mismatch", test.Name) } } + +func TestSliceLossIndicationUnmarshalMaxLength(t *testing.T) { + rawPacket := make([]byte, 4*(0xFFFF+1)) + rawPacket[0] = 0x82 + rawPacket[1] = 0xcd + rawPacket[2] = 0xff + rawPacket[3] = 0xff + + var sli SliceLossIndication + err := sli.Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Len(t, sli.SLI, 0xFFFF-sliLength) +} diff --git a/source_description_test.go b/source_description_test.go index e95a57b..eccf75b 100644 --- a/source_description_test.go +++ b/source_description_test.go @@ -340,3 +340,38 @@ func TestSourceDescriptionRoundTrip(t *testing.T) { assert.Equalf(t, test.Desc, decoded, "%s sdes round trip mismatch", test.Name) } } + +func TestSourceDescriptionRoundTripMaxLength(t *testing.T) { + const maxPacketLength = 4 * (0xFFFF + 1) + remainingItemBytes := maxPacketLength - headerLength - sdesSourceLen - sdesTypeLen + fullItemLength := sdesTypeLen + sdesOctetCountLen + sdesMaxOctetCount + fullItemCount := remainingItemBytes / fullItemLength + lastItemLength := remainingItemBytes % fullItemLength + + items := make([]SourceDescriptionItem, 0, fullItemCount+1) + maxText := strings.Repeat("x", sdesMaxOctetCount) + for range fullItemCount { + items = append(items, SourceDescriptionItem{Type: SDESNote, Text: maxText}) + } + if lastItemLength > 0 { + items = append(items, SourceDescriptionItem{ + Type: SDESNote, + Text: strings.Repeat("y", lastItemLength-sdesTypeLen-sdesOctetCountLen), + }) + } + + desc := SourceDescription{ + Chunks: []SourceDescriptionChunk{{ + Source: 1, + Items: items, + }}, + } + + data, err := desc.Marshal() + assert.NoError(t, err) + assert.Len(t, data, maxPacketLength) + + var decoded SourceDescription + assert.NoError(t, decoded.Unmarshal(data)) + assert.Equal(t, desc, decoded) +} diff --git a/transport_layer_cc.go b/transport_layer_cc.go index f5a3c03..79176c3 100644 --- a/transport_layer_cc.go +++ b/transport_layer_cc.go @@ -472,13 +472,13 @@ func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { // https://tools.ietf.org/html/rfc4585#page-33 // header's length + payload's length - totalLength := 4 * (t.Header.Length + 1) + totalLength := 4 * (int(t.Header.Length) + 1) if totalLength < headerLength+packetChunkOffset { return errPacketTooShort } - if len(rawPacket) < int(totalLength) { + if len(rawPacket) < totalLength { return errPacketTooShort } @@ -493,7 +493,7 @@ func (t *TransportLayerCC) Unmarshal(rawPacket []byte) error { t.ReferenceTime = get24BitsFromBytes(rawPacket[headerLength+referenceTimeOffset : headerLength+referenceTimeOffset+3]) t.FbPktCount = rawPacket[headerLength+fbPktCountOffset] - packetStatusPos := uint16(headerLength + packetChunkOffset) + packetStatusPos := int(headerLength + packetChunkOffset) var processedPacketNum uint16 for processedPacketNum < t.PacketStatusCount { if packetStatusPos+packetStatusChunkLength >= totalLength { diff --git a/transport_layer_cc_test.go b/transport_layer_cc_test.go index e366f7f..eda1df5 100644 --- a/transport_layer_cc_test.go +++ b/transport_layer_cc_test.go @@ -47,6 +47,20 @@ func TestTransportLayerCC_RunLengthChunkUnmarshal(t *testing.T) { } } +func TestTransportLayerCCUnmarshalMaxLength(t *testing.T) { + rawPacket := make([]byte, 4*(0xFFFF+1)) + rawPacket[0] = 0x8f + rawPacket[1] = 0xcd + rawPacket[2] = 0xff + rawPacket[3] = 0xff + + var packet TransportLayerCC + err := packet.Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Equal(t, uint16(0xFFFF), packet.Header.Length) + assert.Zero(t, packet.PacketStatusCount) +} + func TestTransportLayerCC_RunLengthChunkMarshal(t *testing.T) { for _, test := range []struct { Name string diff --git a/transport_layer_nack.go b/transport_layer_nack.go index f728b2d..1918e05 100644 --- a/transport_layer_nack.go +++ b/transport_layer_nack.go @@ -132,7 +132,7 @@ func (p *TransportLayerNack) Unmarshal(rawPacket []byte) error { return err } - if len(rawPacket) < (headerLength + int(4*header.Length)) { + if len(rawPacket) < (headerLength + 4*int(header.Length)) { return errPacketTooShort } @@ -141,13 +141,13 @@ func (p *TransportLayerNack) Unmarshal(rawPacket []byte) error { } // The FCI field MUST contain at least one and MAY contain more than one Generic NACK - if 4*header.Length <= nackOffset { + if 4*int(header.Length) <= nackOffset { return errBadLength } p.SenderSSRC = binary.BigEndian.Uint32(rawPacket[headerLength:]) p.MediaSSRC = binary.BigEndian.Uint32(rawPacket[headerLength+ssrcLength:]) - for i := headerLength + nackOffset; i < (headerLength + int(header.Length*4)); i += 4 { + for i := headerLength + nackOffset; i < (headerLength + 4*int(header.Length)); i += 4 { p.Nacks = append(p.Nacks, NackPair{ binary.BigEndian.Uint16(rawPacket[i:]), PacketBitmap(binary.BigEndian.Uint16(rawPacket[i+2:])), diff --git a/transport_layer_nack_test.go b/transport_layer_nack_test.go index 1fd5751..3d1d826 100644 --- a/transport_layer_nack_test.go +++ b/transport_layer_nack_test.go @@ -117,6 +117,19 @@ func TestTransportLayerNackRoundTrip(t *testing.T) { } } +func TestTransportLayerNackUnmarshalMaxLength(t *testing.T) { + rawPacket := make([]byte, 4*(0xFFFF+1)) + rawPacket[0] = 0x81 + rawPacket[1] = 0xcd + rawPacket[2] = 0xff + rawPacket[3] = 0xff + + var nack TransportLayerNack + err := nack.Unmarshal(rawPacket) + assert.NoError(t, err) + assert.Len(t, nack.Nacks, 0xFFFF-tlnLength) +} + func testNackPair(t *testing.T, s []uint16, n NackPair) { t.Helper()