Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -148,58 +148,20 @@ internal fun AggregationSpec.getFeatureId(): String {
}
}

internal fun List<AggregationSpec>.metrics(): List<MetricDefinition> = buildList {
for (aggregation in this@metrics) {
when (aggregation) {
// Count and PrivacyIdCount do not aggregate any specific value, therefore they are handled
// differently.
is PrivacyIdCount ->
add(
MetricDefinition(
MetricType.PRIVACY_ID_COUNT,
aggregation.budget?.toInternalBudgetPerOpSpec(),
)
)
is Count ->
add(MetricDefinition(MetricType.COUNT, aggregation.budget?.toInternalBudgetPerOpSpec()))
is ValueAggregations<*> -> {
for (valueAggregationSpec in aggregation.valueAggregationSpecs) {
add(
MetricDefinition(
valueAggregationSpec.metricType,
valueAggregationSpec.budget?.toInternalBudgetPerOpSpec(),
)
)
}
}
is VectorAggregations<*> -> {
for (vectorAggregationSpec in aggregation.vectorAggregationSpecs) {
add(
MetricDefinition(
vectorAggregationSpec.metricType,
vectorAggregationSpec.budget?.toInternalBudgetPerOpSpec(),
)
)
}
}
}
}
}

internal fun List<AggregationSpec>.outputColumnNamesWithMetricTypes():
List<Pair<String, MetricType>> = buildList {
for (aggregation in this@outputColumnNamesWithMetricTypes) {
when (aggregation) {
is PrivacyIdCount -> add(aggregation.outputColumnName to MetricType.PRIVACY_ID_COUNT)
is Count -> add(aggregation.outputColumnName to MetricType.COUNT)
is PrivacyIdCount -> add(Pair(aggregation.outputColumnName, MetricType.PRIVACY_ID_COUNT))
is Count -> add(Pair(aggregation.outputColumnName, MetricType.COUNT))
is ValueAggregations<*> -> {
for (valueAggregationSpec in aggregation.valueAggregationSpecs) {
add(valueAggregationSpec.outputColumnName to valueAggregationSpec.metricType)
add(Pair(valueAggregationSpec.outputColumnName, valueAggregationSpec.metricType))
}
}
is VectorAggregations<*> -> {
for (vectorAggregationSpec in aggregation.vectorAggregationSpecs) {
add(vectorAggregationSpec.outputColumnName to vectorAggregationSpec.metricType)
add(Pair(vectorAggregationSpec.outputColumnName, vectorAggregationSpec.metricType))
}
}
}
Expand Down Expand Up @@ -227,3 +189,22 @@ internal fun List<AggregationSpec>.outputColumnNameToFeatureIdMap(): Map<String,

internal fun List<AggregationSpec>.outputColumnNames(): List<String> =
outputColumnNamesWithMetricTypes().map { it.first }

internal fun AggregationSpec.toNonFeatureMetricDefinition(): MetricDefinition {
val (metricType, budget) =
when (this) {
is Count -> Pair(MetricType.COUNT, this.budget)
is PrivacyIdCount -> Pair(MetricType.PRIVACY_ID_COUNT, this.budget)
else ->
throw IllegalArgumentException("Unsupported AggregationSpec type for non feature metrics")
}
return MetricDefinition(metricType, budget?.toInternalBudgetPerOpSpec())
}

internal fun ValueAggregationSpec.toMetricDefinition(): MetricDefinition {
return MetricDefinition(this.metricType, this.budget?.toInternalBudgetPerOpSpec())
}

internal fun VectorAggregationSpec.toMetricDefinition(): MetricDefinition {
return MetricDefinition(this.metricType, this.budget?.toInternalBudgetPerOpSpec())
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ import com.google.privacy.differentialprivacy.pipelinedp4j.core.DpEngine
import com.google.privacy.differentialprivacy.pipelinedp4j.core.DpEngineBudgetSpec
import com.google.privacy.differentialprivacy.pipelinedp4j.core.Encoder
import com.google.privacy.differentialprivacy.pipelinedp4j.core.EncoderFactory
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FeatureSpec
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FeatureValuesExtractor
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkCollection
import com.google.privacy.differentialprivacy.pipelinedp4j.core.FrameworkTable
import com.google.privacy.differentialprivacy.pipelinedp4j.core.MetricType
import com.google.privacy.differentialprivacy.pipelinedp4j.core.ScalarFeatureSpec
import com.google.privacy.differentialprivacy.pipelinedp4j.core.SelectPartitionsParams
import com.google.privacy.differentialprivacy.pipelinedp4j.core.VectorFeatureSpec
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.DpAggregates
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.PerFeature
import com.google.privacy.differentialprivacy.pipelinedp4j.proto.copy
Expand Down Expand Up @@ -494,22 +497,53 @@ protected constructor(
valueAggregations: ValueAggregations<*>?,
vectorAggregations: VectorAggregations<*>?,
): AggregationParams {
val valueContributionBounds = valueAggregations?.contributionBounds
val vectorContributionBounds = vectorAggregations?.vectorContributionBounds
val nonFeatureMetrics =
aggregationSpecs
.filter { it is Count || it is PrivacyIdCount }
.map { it.toNonFeatureMetricDefinition() }
val features =
buildList<FeatureSpec> {
if (valueAggregations != null) {
val valueContributionBounds = valueAggregations.contributionBounds
add(
ScalarFeatureSpec(
featureId = valueAggregations.getFeatureId(),
metrics =
valueAggregations.valueAggregationSpecs
.map { it.toMetricDefinition() }
.toImmutableList(),
minValue = valueContributionBounds.valueBounds?.minValue,
maxValue = valueContributionBounds.valueBounds?.maxValue,
minTotalValue = valueContributionBounds.totalValueBounds?.minValue,
maxTotalValue = valueContributionBounds.totalValueBounds?.maxValue,
)
)
}
if (vectorAggregations != null) {
val vectorContributionBounds = vectorAggregations.vectorContributionBounds
add(
VectorFeatureSpec(
featureId = vectorAggregations.getFeatureId(),
metrics =
vectorAggregations.vectorAggregationSpecs
.map { it.toMetricDefinition() }
.toImmutableList(),
vectorSize = vectorAggregations.vectorSize,
normKind = vectorContributionBounds.maxVectorTotalNorm.normKind.toInternalNormKind(),
vectorMaxTotalNorm = vectorContributionBounds.maxVectorTotalNorm.value,
)
)
}
}

return AggregationParams(
metrics = ImmutableList.copyOf(aggregationSpecs.metrics()),
nonFeatureMetrics = nonFeatureMetrics.toImmutableList(),
features = features.toImmutableList(),
noiseKind =
checkNotNull(noiseKind) { "noiseKind cannot be null if there are aggregations." }
.toInternalNoiseKind(),
maxPartitionsContributed = contributionBoundingLevel.getMaxPartitionsContributed(),
maxContributionsPerPartition = contributionBoundingLevel.getMaxContributionsPerPartition(),
minValue = valueContributionBounds?.valueBounds?.minValue,
maxValue = valueContributionBounds?.valueBounds?.maxValue,
minTotalValue = valueContributionBounds?.totalValueBounds?.minValue,
maxTotalValue = valueContributionBounds?.totalValueBounds?.maxValue,
vectorNormKind = vectorContributionBounds?.maxVectorTotalNorm?.normKind?.toInternalNormKind(),
vectorMaxTotalNorm = vectorContributionBounds?.maxVectorTotalNorm?.value,
vectorSize = vectorAggregations?.vectorSize,
partitionSelectionBudget = groupsType.getBudget()?.toInternalBudgetPerOpSpec(),
preThreshold = groupsType.getPreThreshold(),
contributionBoundingLevel = contributionBoundingLevel.toInternalContributionBoundingLevel(),
Expand All @@ -534,3 +568,5 @@ protected constructor(
}
}
}

private fun <T : Any> Iterable<T>.toImmutableList(): ImmutableList<T> = ImmutableList.copyOf(this)
Loading
Loading