Skip to content
175 changes: 159 additions & 16 deletions vulnfeeds/conversion/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"fmt"
"io/fs"
"log/slog"
"net/http"
"os"
"path/filepath"
"slices"
Expand All @@ -21,6 +22,7 @@ import (
"github.com/google/osv/vulnfeeds/utility/logger"
"github.com/google/osv/vulnfeeds/vulns"
"github.com/ossf/osv-schema/bindings/go/osvschema"
"google.golang.org/protobuf/types/known/structpb"
)

// AddAffected adds an osvschema.Affected to a vulnerability, ensuring that no duplicate ranges are added.
Expand Down Expand Up @@ -64,6 +66,7 @@ func AddAffected(v *vulns.Vulnerability, aff *osvschema.Affected, metrics *model
}

func DeduplicateRefs(refs []models.Reference) []models.Reference {
refs = slices.Clone(refs)
// Deduplicate references by URL.
refs = slices.Clone(refs)
slices.SortStableFunc(refs, func(a, b models.Reference) int {
Expand Down Expand Up @@ -175,7 +178,7 @@ func WriteMetricsFile(metrics *models.ConversionMetrics, metricsFile *os.File) e

// GitVersionsToCommits examines repos and tries to convert versions to commits by treating them as Git tags.
// Returns the resolved ranges, unresolved ranges, and successful repos involved.
func GitVersionsToCommits(versionRanges []*osvschema.Range, repos []string, metrics *models.ConversionMetrics, cache *git.RepoTagsCache) ([]*osvschema.Range, []*osvschema.Range, []string) {
func GitVersionsToCommits(versionRanges []models.RangeWithMetadata, repos []string, metrics *models.ConversionMetrics, cache *git.RepoTagsCache) ([]*osvschema.Range, []models.RangeWithMetadata, []string) {
var newVersionRanges []*osvschema.Range
unresolvedRanges := versionRanges
var successfulRepos []string
Expand All @@ -187,6 +190,18 @@ func GitVersionsToCommits(versionRanges []*osvschema.Range, repos []string, metr
if cache.IsInvalid(repo) {
continue
}

repo, err := git.FindCanonicalLink(repo, http.DefaultClient, cache)
if err != nil {
metrics.AddNote("Failed to find canonical link - %s %v", repo, err)
if errors.Is(err, git.ErrRateLimit) || strings.Contains(err.Error(), "429") {
metrics.Outcome = models.Error
return nil, nil, nil
}

continue
}

normalizedTags, err := git.NormalizeRepoTags(repo, cache)
if err != nil {
if errors.Is(err, git.ErrRateLimit) || strings.Contains(err.Error(), "429") {
Expand All @@ -198,10 +213,10 @@ func GitVersionsToCommits(versionRanges []*osvschema.Range, repos []string, metr
continue
}

var stillUnresolvedRanges []*osvschema.Range
var stillUnresolvedRanges []models.RangeWithMetadata
for _, vr := range unresolvedRanges {
var introduced, fixed, lastAffected string
for _, e := range vr.GetEvents() {
for _, e := range vr.Range.GetEvents() {
if e.GetIntroduced() != "" {
introduced = e.GetIntroduced()
}
Expand Down Expand Up @@ -231,7 +246,7 @@ func GitVersionsToCommits(versionRanges []*osvschema.Range, repos []string, metr
metrics.AddNote("error resolving version to commit - %s - %s", lastAffected, err)
}

if introducedCommit != "" && (fixedCommit != "" || lastAffectedCommit != "") {
if fixedCommit != "" || lastAffectedCommit != "" {
var newVR *osvschema.Range

if fixedCommit != "" {
Expand All @@ -242,8 +257,14 @@ func GitVersionsToCommits(versionRanges []*osvschema.Range, repos []string, metr
successfulRepos = append(successfulRepos, repo)
newVR.Repo = repo
newVR.Type = osvschema.Range_GIT
if len(vr.GetEvents()) > 0 {
databaseSpecific, err := utility.NewStructpbFromMap(map[string]any{"versions": vr.GetEvents()})
if len(vr.Range.GetEvents()) > 0 {
dbSpecificMap := map[string]any{
"versions": vr.Range.GetEvents(),
}
if vr.Metadata.CPE != "" {
dbSpecificMap["cpe"] = vr.Metadata.CPE
}
databaseSpecific, err := utility.NewStructpbFromMap(dbSpecificMap)
if err != nil {
metrics.AddNote("failed to make database specific: %v", err)
} else {
Expand Down Expand Up @@ -324,7 +345,7 @@ func MergeTwoRanges(range1, range2 *osvschema.Range) (*osvschema.Range, error) {
for k, v := range db2.GetFields() {
val2 := v.AsInterface()
if existing, ok := mergedMap[k]; ok {
mergedVal, err := mergeDatabaseSpecificValues(existing, val2)
mergedVal, err := MergeDatabaseSpecificValues(existing, val2)
if err != nil {
logger.Info("Failed to merge database specific key", "key", k, "err", err)
}
Expand All @@ -346,18 +367,26 @@ func MergeTwoRanges(range1, range2 *osvschema.Range) (*osvschema.Range, error) {
return mergedRange, nil
}

// mergeDatabaseSpecificValues is a helper function that recursively merges two
// MergeDatabaseSpecificValues is a helper function that recursively merges two
// values from a DatabaseSpecific field. It handles lists (by appending and
// deduplicating, also accepting a single value of the matching type), maps
// (by recursively merging keys), and simple strings (by combining them into a
// deduplicated list). It returns an error if the types of the two values do not match.
func mergeDatabaseSpecificValues(val1, val2 any) (any, error) {
func MergeDatabaseSpecificValues(val1, val2 any) (any, error) {
switch v1 := val1.(type) {
case []any:
if v2, ok := val2.([]any); ok {
return append(v1, v2...), nil
return deduplicateList(append(v1, v2...)), nil
}

return nil, fmt.Errorf("mismatching types: %T and %T", val1, val2)
// Check if the list contains elements of the same type as val2
if len(v1) > 0 {
if fmt.Sprintf("%T", v1[0]) != fmt.Sprintf("%T", val2) {
return nil, fmt.Errorf("mismatching types: list of %T and %T", v1[0], val2)
}
}

// Append single value to list
return deduplicateList(append(v1, val2)), nil
case map[string]any:
if v2, ok := val2.(map[string]any); ok {
merged := make(map[string]any)
Expand All @@ -366,7 +395,7 @@ func mergeDatabaseSpecificValues(val1, val2 any) (any, error) {
}
for k, v := range v2 {
if existing, ok := merged[k]; ok {
mergedVal, err := mergeDatabaseSpecificValues(existing, v)
mergedVal, err := MergeDatabaseSpecificValues(existing, v)
if err != nil {
return nil, err
}
Expand All @@ -382,22 +411,136 @@ func mergeDatabaseSpecificValues(val1, val2 any) (any, error) {
return nil, fmt.Errorf("mismatching types: %T and %T", val1, val2)
case string:
if v2, ok := val2.(string); ok {
if v1 == v2 {
return v1, nil
return deduplicateList([]any{v1, v2}), nil
}
if v2, ok := val2.([]any); ok {
if len(v2) > 0 {
if _, isString := v2[0].(string); !isString {
return nil, fmt.Errorf("mismatching types: string and list of %T", v2[0])
}
}

return []any{v1, v2}, nil
return deduplicateList(append([]any{v1}, v2...)), nil
}

return nil, fmt.Errorf("mismatching types: %T and %T", val1, val2)
default:
if v2, ok := val2.([]any); ok {
if len(v2) > 0 {
if fmt.Sprintf("%T", val1) != fmt.Sprintf("%T", v2[0]) {
return nil, fmt.Errorf("mismatching types: %T and list of %T", val1, v2[0])
}
}

return deduplicateList(append([]any{val1}, v2...)), nil
}
if fmt.Sprintf("%T", val1) != fmt.Sprintf("%T", val2) {
return nil, fmt.Errorf("mismatching types: %T and %T", val1, val2)
}
if val1 == val2 {
return val1, nil
}

return []any{val1, val2}, nil
return deduplicateList([]any{val1, val2}), nil
}
}

// deduplicateList returns the elements of list in their original order with
// repeated scalar values dropped. Only elements whose dynamic type is string,
// int, int32, int64, float32, float64, or bool are checked for duplicates; every
// other element (for example a nested list or map) is passed through untouched.
// Values with different dynamic types are never treated as duplicates of each
// other. The result is nil when nothing is appended (e.g. for an empty input).
func deduplicateList(list []any) []any {
	var result []any
	visited := make(map[any]struct{})
	for _, elem := range list {
		switch elem.(type) {
		case string, int, int32, int64, float32, float64, bool:
			// Hashable scalar: skip it if an equal value was already emitted.
			if _, dup := visited[elem]; dup {
				continue
			}
			visited[elem] = struct{}{}
		}
		result = append(result, elem)
	}

	return result
}

// CreateUnresolvedRanges packages the given unresolved ranges as a structpb list.
// Each entry is a map holding the range under "range" and, when the range's
// metadata carries a non-empty CPE, a "metadata" map with a single "cpe" key.
// It returns nil when unresolvedRanges is empty, or when the conversion to
// structpb fails (in which case a warning is logged).
func CreateUnresolvedRanges(unresolvedRanges []models.RangeWithMetadata) *structpb.ListValue {
	if len(unresolvedRanges) == 0 {
		return nil
	}

	var entries []map[string]any
	for _, unresolved := range unresolvedRanges {
		entry := map[string]any{"range": unresolved.Range}
		if cpe := unresolved.Metadata.CPE; cpe != "" {
			entry["metadata"] = map[string]any{"cpe": cpe}
		}
		entries = append(entries, entry)
	}

	// The conversion helper takes a map, so the slice is wrapped under a
	// temporary "list" key and unwrapped again from the result.
	wrapped, err := utility.NewStructpbFromMap(map[string]any{"list": entries})
	if err != nil {
		logger.Warn("failed to convert unresolved ranges to structpb", "err", err)
		return nil
	}

	return wrapped.Fields["list"].GetListValue()
}

// AddFieldToDatabaseSpecific stores value under field in ds, refusing to
// overwrite an existing entry.
//
// It returns an error when ds is nil, when ds.Fields is nil, when field already
// maps to a non-nil value, or when value cannot be converted by
// structpb.NewValue. A *structpb.Value is stored as-is; a *structpb.Struct or
// *structpb.ListValue is wrapped in a Value; anything else goes through
// structpb.NewValue.
func AddFieldToDatabaseSpecific(ds *structpb.Struct, field string, value any) error {
	switch {
	case ds == nil:
		return errors.New("database specific is nil")
	case ds.Fields == nil:
		return errors.New("database specific fields is nil")
	case ds.GetFields()[field] != nil:
		return fmt.Errorf("field %s already exists", field)
	}

	var wrapped *structpb.Value
	switch typed := value.(type) {
	case *structpb.Value:
		wrapped = typed
	case *structpb.Struct:
		wrapped = structpb.NewStructValue(typed)
	case *structpb.ListValue:
		wrapped = structpb.NewListValue(typed)
	default:
		converted, err := structpb.NewValue(typed)
		if err != nil {
			return fmt.Errorf("failed to create structpb value: %w", err)
		}
		wrapped = converted
	}
	ds.Fields[field] = wrapped

	return nil
}

// ProcessRanges runs GitVersionsToCommits over ranges and records the result in
// metrics. The resolved and unresolved range counts are added to the metrics
// counters; the outcome is set to models.Successful when at least one range was
// resolved, or to models.NoCommitRanges when none were resolved but some remain
// unresolved. source is appended to metrics.VersionSources whenever ranges is
// non-empty. It returns the resolved ranges, the unresolved ranges, and the
// successful repos; all three are nil when ranges is empty.
func ProcessRanges(ranges []models.RangeWithMetadata, repos []string, metrics *models.ConversionMetrics, cache *git.RepoTagsCache, source models.VersionSource) ([]*osvschema.Range, []models.RangeWithMetadata, []string) {
	if len(ranges) == 0 {
		return nil, nil, nil
	}

	resolved, unresolved, successfulRepos := GitVersionsToCommits(ranges, repos, metrics, cache)

	// Adding a zero length leaves the counter unchanged, so no guard is needed.
	metrics.ResolvedRangesCount += len(resolved)
	metrics.UnresolvedRangesCount += len(unresolved)

	switch {
	case len(resolved) > 0:
		metrics.SetOutcome(models.Successful)
	case len(unresolved) > 0:
		metrics.SetOutcome(models.NoCommitRanges)
	}

	metrics.VersionSources = append(metrics.VersionSources, source)

	return resolved, unresolved, successfulRepos
}
22 changes: 14 additions & 8 deletions vulnfeeds/conversion/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,16 @@ func TestMergeDatabaseSpecificValues(t *testing.T) {
want: []any{"a", "b", "c", "d"},
},
{
name: "List and string mismatch",
val1: []any{"a", "b"},
val2: "c",
wantErr: true,
name: "List and string",
val1: []any{"a", "b"},
val2: "c",
want: []any{"a", "b", "c"},
},
{
name: "String and list",
val1: "a",
val2: []any{"b", "c"},
want: []any{"a", "b", "c"},
},
{
name: "Merge maps",
Expand Down Expand Up @@ -268,7 +274,7 @@ func TestMergeDatabaseSpecificValues(t *testing.T) {
name: "Merge same strings",
val1: "value1",
val2: "value1",
want: "value1",
want: []any{"value1"},
},
{
name: "Merge different strings",
Expand Down Expand Up @@ -304,13 +310,13 @@ func TestMergeDatabaseSpecificValues(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := mergeDatabaseSpecificValues(tt.val1, tt.val2)
got, err := MergeDatabaseSpecificValues(tt.val1, tt.val2)
if (err != nil) != tt.wantErr {
t.Errorf("mergeDatabaseSpecificValues() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("MergeDatabaseSpecificValues() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !cmp.Equal(got, tt.want) {
t.Errorf("mergeDatabaseSpecificValues() mismatch (-want +got):\n%s", cmp.Diff(tt.want, got))
t.Errorf("MergeDatabaseSpecificValues() mismatch (-want +got):\n%s", cmp.Diff(tt.want, got))
}
})
}
Expand Down
Loading