|
| 1 | +package scip |
| 2 | + |
| 3 | +import ( |
| 4 | + "io" |
| 5 | + |
| 6 | + "google.golang.org/protobuf/encoding/protowire" |
| 7 | + "google.golang.org/protobuf/proto" |
| 8 | + |
| 9 | + "github.com/sourcegraph/sourcegraph/lib/errors" |
| 10 | +) |
| 11 | + |
// IndexVisitor is a struct of functions rather than an interface since Go
// doesn't support adding new functions to interfaces with default
// implementations, so adding new functions here for new fields in
// the SCIP schema would break clients. Individual functions may be nil.
//
// The functions are invoked by ParseStreaming, once per decoded top-level
// field of the corresponding kind.
type IndexVisitor struct {
	// VisitMetadata receives each Metadata message decoded from the index.
	VisitMetadata func(*Metadata)

	// VisitDocument receives each Document message decoded from the index.
	VisitDocument func(*Document)

	// VisitExternalSymbol receives each external SymbolInformation message
	// decoded from the index.
	VisitExternalSymbol func(*SymbolInformation)
}
| 21 | + |
// maxVarintBytes bounds the number of bytes consumed when reading a single
// varint; it is used as the capacity of the scratch buffer handed to
// readVarint.
//
// See https://protobuf.dev/programming-guides/encoding/#varints
const maxVarintBytes = 10
| 24 | + |
| 25 | +// ParseStreaming processes an index by incrementally reading input from the io.Reader. |
| 26 | +// |
| 27 | +// Parsing takes place at Document granularity for ease of use. |
| 28 | +func (pi *IndexVisitor) ParseStreaming(r io.Reader) error { |
| 29 | + // The tag is encoded as a varint with value: (field_number << 3) | wire_type |
| 30 | + // Varints < 128 fit in 1 byte, which means 4 bits are available for field |
| 31 | + // numbers. The Index type has less than 15 fields, so the tag will fit in 1 byte. |
| 32 | + tagBuf := make([]byte, 1) |
| 33 | + lenBuf := make([]byte, 0, maxVarintBytes) |
| 34 | + dataBuf := make([]byte, 0, 1024) |
| 35 | + |
| 36 | + for { |
| 37 | + numRead, err := r.Read(tagBuf) |
| 38 | + if err == io.EOF { |
| 39 | + return nil |
| 40 | + } |
| 41 | + if err != nil { |
| 42 | + return errors.Wrapf(err, "failed to read from index reader: %w") |
| 43 | + } |
| 44 | + if numRead == 0 { |
| 45 | + return errors.New("read 0 bytes from index") |
| 46 | + } |
| 47 | + fieldNumber, fieldType, errCode := protowire.ConsumeTag(tagBuf) |
| 48 | + if errCode < 0 { |
| 49 | + return errors.Wrap(protowire.ParseError(errCode), "failed to consume tag") |
| 50 | + } |
| 51 | + switch fieldNumber { |
| 52 | + // As per scip.proto, all of Metadata, Document and SymbolInformation are sub-messages |
| 53 | + case metadataFieldNumber, documentsFieldNumber, externalSymbolsFieldNumber: |
| 54 | + if fieldType != protowire.BytesType { |
| 55 | + return errors.Newf("expected LEN type tag for %s", indexFieldName(fieldNumber)) |
| 56 | + } |
| 57 | + lenBuf = lenBuf[:0] |
| 58 | + dataLen, err := readVarint(r, lenBuf) |
| 59 | + if err != nil { |
| 60 | + return errors.Wrapf(err, "failed to read length for %s", indexFieldName(fieldNumber)) |
| 61 | + } |
| 62 | + if dataLen > uint64(cap(dataBuf)) { |
| 63 | + dataBuf = make([]byte, dataLen) |
| 64 | + } else { |
| 65 | + dataBuf = dataBuf[:0] |
| 66 | + for i := uint64(0); i < dataLen; i++ { |
| 67 | + dataBuf = append(dataBuf, 0) |
| 68 | + } |
| 69 | + } |
| 70 | + // Keep going when len == 0 instead of short-circuiting to preserve empty sub-messages |
| 71 | + if dataLen > 0 { |
| 72 | + numRead, err := r.Read(dataBuf) |
| 73 | + if err != nil { |
| 74 | + return errors.Wrapf(err, "failed to read data for %s", indexFieldName(fieldNumber)) |
| 75 | + } |
| 76 | + if uint64(numRead) != dataLen { |
| 77 | + return errors.Newf( |
| 78 | + "expected to read %d bytes based on LEN but read %d bytes", dataLen, numRead) |
| 79 | + } |
| 80 | + } |
| 81 | + if fieldNumber == metadataFieldNumber && pi.VisitMetadata != nil { |
| 82 | + m := Metadata{} |
| 83 | + if err := proto.Unmarshal(dataBuf, &m); err != nil { |
| 84 | + return errors.Wrapf(err, "failed to read %s", indexFieldName(fieldNumber)) |
| 85 | + } |
| 86 | + pi.VisitMetadata(&m) |
| 87 | + } else if fieldNumber == documentsFieldNumber && pi.VisitDocument != nil { |
| 88 | + d := Document{} |
| 89 | + if err := proto.Unmarshal(dataBuf, &d); err != nil { |
| 90 | + return errors.Wrapf(err, "failed to read %s", indexFieldName(fieldNumber)) |
| 91 | + } |
| 92 | + pi.VisitDocument(&d) |
| 93 | + } else if fieldNumber == externalSymbolsFieldNumber && pi.VisitExternalSymbol != nil { |
| 94 | + s := SymbolInformation{} |
| 95 | + if err := proto.Unmarshal(dataBuf, &s); err != nil { |
| 96 | + return errors.Wrapf(err, "failed to read %s", indexFieldName(fieldNumber)) |
| 97 | + } |
| 98 | + pi.VisitExternalSymbol(&s) |
| 99 | + } else { |
| 100 | + return errors.Newf("added new field in scip.Index but forgot to add unmarshaling code") |
| 101 | + } |
| 102 | + default: |
| 103 | + return errors.Newf("added new field in scip.Index but forgot to update streaming parser") |
| 104 | + } |
| 105 | + } |
| 106 | +} |
| 107 | + |
// Field numbers of the top-level fields recognized by ParseStreaming. They are
// compared against the field number decoded from each tag.
// NOTE(review): presumably these mirror the field numbers of the Index message
// in scip.proto — confirm they stay in sync with the schema.
const (
	metadataFieldNumber        = 1
	documentsFieldNumber       = 2
	externalSymbolsFieldNumber = 3
)
| 113 | + |
| 114 | +// readVarint attempts to read a varint, using scratchBuf for temporary storage |
| 115 | +// |
| 116 | +// scratchBuf should be able to accommodate any varint size |
| 117 | +// based on its capacity, and be cleared before readVarint is called |
| 118 | +func readVarint(r io.Reader, scratchBuf []byte) (uint64, error) { |
| 119 | + nextByteBuf := make([]byte, 1, 1) |
| 120 | + for i := 0; i < cap(scratchBuf); i++ { |
| 121 | + numRead, err := r.Read(nextByteBuf) |
| 122 | + if err != nil { |
| 123 | + return 0, errors.Wrapf(err, "failed to read %d-th byte of Varint. soFar: %v", i, scratchBuf) |
| 124 | + } |
| 125 | + if numRead == 0 { |
| 126 | + return 0, errors.Newf("failed to read %d-th byte of Varint. soFar: %v", scratchBuf) |
| 127 | + } |
| 128 | + nextByte := nextByteBuf[0] |
| 129 | + scratchBuf = append(scratchBuf, nextByte) |
| 130 | + if nextByte <= 127 { // https://protobuf.dev/programming-guides/encoding/#varints |
| 131 | + // Continuation bit is not set, so Varint must've ended |
| 132 | + break |
| 133 | + } |
| 134 | + } |
| 135 | + value, errCode := protowire.ConsumeVarint(scratchBuf) |
| 136 | + if errCode < 0 { |
| 137 | + return value, protowire.ParseError(errCode) |
| 138 | + } |
| 139 | + return value, nil |
| 140 | +} |
| 141 | + |
| 142 | +func indexFieldName(i protowire.Number) string { |
| 143 | + if i == metadataFieldNumber { |
| 144 | + return "metadata" |
| 145 | + } else if i == documentsFieldNumber { |
| 146 | + return "documents" |
| 147 | + } else if i == externalSymbolsFieldNumber { |
| 148 | + return "external_symbols" |
| 149 | + } |
| 150 | + return "<unknown>" |
| 151 | +} |
0 commit comments