Skip to content

Model filter #99

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build-macos-bindings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ jobs:
path: |
/usr/local/include/opencv4
/usr/local/lib/libopencv_*
/usr/local/lib/cmake/opencv4
key: ${{ runner.os }}-opencv-${{ hashFiles('openvino_bindings/scripts/setup_opencv.sh') }}

- name: Install OpenCV
Expand Down
248 changes: 132 additions & 116 deletions lib/pages/import/huggingface.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,142 +7,158 @@ import 'package:inference/importers/manifest_importer.dart';
import 'package:inference/pages/import/providers/import_provider.dart';
import 'package:inference/pages/import/widgets/badge.dart';
import 'package:inference/pages/import/widgets/model_card.dart';
import 'package:inference/pages/models/widgets/model_filter.dart';
import 'package:inference/providers/project_filter_provider.dart';
import 'package:inference/theme_fluent.dart';
import 'package:inference/widgets/controls/dropdown_multiple_select.dart';
import 'package:inference/widgets/controls/search_bar.dart';
import 'package:inference/widgets/empty_model_widget.dart';
import 'package:inference/widgets/fixed_grid.dart';
import 'package:inference/widgets/grid_container.dart';
import 'package:provider/provider.dart';

class Huggingface extends StatefulWidget {
class Huggingface extends StatelessWidget {
const Huggingface({super.key});

@override
State<Huggingface> createState() => _HuggingfaceState();
}

class _HuggingfaceState extends State<Huggingface> {
List<String> selectedOptimizations = [];
String? searchValue;
bool orderAsc = true;
static Map<String, List<Option>> get filterOptions {
var options = {
"Text Generation": [
const Option("Text generation", "text-generation"),
],
"Image Generation": [
const Option("Text to Image", "text-to-image")
],
"Audio": [
const Option("Speech to text", "speech")
]
};

List<Model> filterModels(List<Model> models) {
var filteredModels = models;
if (searchValue != null && searchValue!.isNotEmpty) {
filteredModels = filteredModels.where((model) => model.name.toLowerCase().contains(searchValue!.toLowerCase())).toList();
}
if (selectedOptimizations.isNotEmpty) {
filteredModels = filteredModels.where((model) => selectedOptimizations.contains(model.optimizationPrecision)).toList();
}

filteredModels.sort((a,b) => a.name.compareTo(b.name) * (orderAsc ? -1 : 1));
return filteredModels;
return options;
}

@override
Widget build(BuildContext context) {
final theme = FluentTheme.of(context);

return Consumer<ImportProvider>(builder: (context, importProvider, child) {
return ConstrainedBox(constraints: const BoxConstraints(maxWidth: 1228),
child: Padding(
padding: const EdgeInsets.only(left: 133, right: 80, top: 36, bottom: 50),
child: Column(
children: [
Row(
mainAxisAlignment: MainAxisAlignment.spaceBetween,
children: [
Row(
children: [
ConstrainedBox(
constraints: const BoxConstraints(maxWidth: 280),
child: Semantics(
label: 'Find a model',
child: SearchBar(onChange: (value) { setState(() {
searchValue = value;
}); }, placeholder: 'Find a model',),
return Consumer<ProjectFilterProvider>(
builder: (context, filter, child) {
return ConstrainedBox(constraints: const BoxConstraints(maxWidth: 1228),
child: Row(
children: [
GridContainer(
color: backgroundColor.of(theme),
padding: const EdgeInsets.all(13),
child: ModelFilter(filterOptions: filterOptions)
),
Expanded(
child: GridContainer(
color: backgroundColor.of(theme),
padding: const EdgeInsets.only(left: 33, right: 80, top: 36, bottom: 50),
child: Column(
children: [
Row(
mainAxisAlignment: MainAxisAlignment.spaceBetween,
children: [
Row(
children: [
ConstrainedBox(
constraints: const BoxConstraints(maxWidth: 280),
child: Semantics(
label: 'Find a model',
child: SearchBar(
placeholder: 'Find a model',
onChange: (value) {
filter.name = value;
},
),
),
),
Padding(
padding: const EdgeInsets.only(left: 8),
child: ConstrainedBox(
constraints: const BoxConstraints(maxWidth: 184),
child: DropdownMultipleSelect(
items: const ['int4', 'int8', 'fp16'],
selectedItems: filter.optimizations,
onChanged: (value) {
filter.optimizations = value;
},
placeholder: 'Select optimizations',
),
),
)
],
),
IconButton(icon: Icon(filter.order ? FluentIcons.descending : FluentIcons.ascending, size: 18,), onPressed: () => filter.order = !filter.order),
],
),
),
Padding(
padding: const EdgeInsets.only(left: 8),
child: ConstrainedBox(
constraints: const BoxConstraints(maxWidth: 184),
child: DropdownMultipleSelect(
items: const ['int4', 'int8', 'fp16'],
selectedItems: selectedOptimizations,
onChanged: (value) {
if (!value.contains(importProvider.selectedModel?.optimizationPrecision)) {
importProvider.selectedModel = null;
}
setState(() {
selectedOptimizations = value;
});
},
placeholder: 'Select optimizations',
Padding(
padding: const EdgeInsets.symmetric(vertical: 12, horizontal: 2),
child: SizedBox(
height: 28,
width: double.infinity,
child: Align(
alignment: Alignment.centerLeft,
child: Wrap(
spacing: 8,
children: [
...filter.optimizations.map((opt) {
return Badge(text: opt, onDelete: () {
if (opt == importProvider.selectedModel?.optimizationPrecision && filter.optimizations.length > 1) {
importProvider.selectedModel = null;
}
filter.removeOptimization(opt);
});
}),
if (filter.option != null)
Badge(text: filter.option!.name, onDelete: () {
filter.option = null;
})
]
),
),
),
),
)
],
),
IconButton(icon: Icon(orderAsc ? FluentIcons.descending : FluentIcons.ascending, size: 18,), onPressed: () => setState(() => orderAsc = !orderAsc),),
],
),
Padding(
padding: const EdgeInsets.symmetric(vertical: 12, horizontal: 2),
child: SizedBox(
height: 28,
width: double.infinity,
child: Align(
alignment: Alignment.centerLeft,
child: Wrap(
spacing: 8,
children: selectedOptimizations.map((opt) {
return Badge(text: opt, onDelete: () {
if (opt == importProvider.selectedModel?.optimizationPrecision && selectedOptimizations.length > 1) {
importProvider.selectedModel = null;
}
setState(() {
selectedOptimizations.remove(opt);
});
});
}).toList(),
),
),
),
),
Expanded(
child: SingleChildScrollView(
child: FutureBuilder<List<Model>>(
future: importProvider.allModelsFuture,
builder: (context, snapshot) {
if (snapshot.connectionState == ConnectionState.waiting) {
return const ProgressRing();
} else if (snapshot.hasError) {
return Text('Error: ${snapshot.error}');
} else if (!snapshot.hasData || snapshot.data!.isEmpty) {
return const Text('No models available');
} else {
var allModels = filterModels(snapshot.data!);
return FixedGrid(
tileWidth: 226,
spacing: 24,
itemCount: allModels.length,
emptyWidget: EmptyModelListWidget(searchQuery: searchValue),
itemBuilder: (context, index) => ModelCard(
model: allModels[index],
checked: importProvider.selectedModel == allModels[index],
onChecked: (value) {
setState(() {
importProvider.selectedModel = value ? allModels[index] : null;
});
},
Expanded(
child: SingleChildScrollView(
child: FutureBuilder<List<Model>>(
future: importProvider.allModelsFuture,
builder: (context, snapshot) {
if (snapshot.connectionState == ConnectionState.waiting) {
return const ProgressRing();
} else if (snapshot.hasError) {
return Text('Error: ${snapshot.error}');
} else if (!snapshot.hasData || snapshot.data!.isEmpty) {
return const Text('No models available');
} else {
var allModels = filter.applyFilterOnModel(snapshot.data!);
return FixedGrid(
tileWidth: 226,
spacing: 24,
itemCount: allModels.length,
emptyWidget: EmptyModelListWidget(searchQuery: filter.name),
itemBuilder: (context, index) => ModelCard(
model: allModels[index],
checked: importProvider.selectedModel == allModels[index],
onChecked: (value) {
importProvider.selectedModel = value ? allModels[index] : null;
},
),
);
}
},
),
),
);
}
},
),
],
),
),
),
),
],
),
),
],
),
);
}
);
});
}
Expand Down
9 changes: 7 additions & 2 deletions lib/pages/import/import.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import 'package:go_router/go_router.dart';
import 'package:inference/pages/import/huggingface.dart';
import 'package:inference/pages/import/providers/import_provider.dart';
import 'package:inference/pages/import/widgets/import_geti_model_dialog.dart';
import 'package:inference/providers/project_filter_provider.dart';
import 'package:inference/theme_fluent.dart';
import 'package:inference/widgets/controls/close_model_button.dart';
import 'package:provider/provider.dart';

Expand All @@ -26,7 +28,7 @@ class _ImportPageState extends State<ImportPage> {
final theme = FluentTheme.of(context);
final updatedTheme = theme.copyWith(
navigationPaneTheme: theme.navigationPaneTheme.merge(NavigationPaneThemeData(
backgroundColor: theme.scaffoldBackgroundColor,
backgroundColor: backgroundColor.of(theme),
))
);

Expand Down Expand Up @@ -56,7 +58,10 @@ class _ImportPageState extends State<ImportPage> {
PaneItem(
icon: SvgPicture.asset('images/huggingface_logo-noborder.svg', width: 15,),
title: const Text("Huggingface"),
body: const Huggingface(),
body: ChangeNotifierProvider<ProjectFilterProvider>(
create: (_) => ProjectFilterProvider(),
child: const Huggingface()
),
),
PaneItemAction(
icon: const Icon(FluentIcons.project_collection),
Expand Down
31 changes: 24 additions & 7 deletions lib/pages/models/models.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,31 @@
import 'package:inference/widgets/import_model_button.dart';
import 'package:provider/provider.dart';

class ModelsPage extends StatefulWidget {
class ModelsPage extends StatelessWidget {
const ModelsPage({super.key});

@override
State<ModelsPage> createState() => _ModelsPageState();
}
static Map<String, List<Option>> get filterOptions {
var options = {
"Image": [
const Option("Detection", "detection"),
const Option("Classification", "classification"),
const Option("Segmentation", "segmentation"),
const Option("Anomaly detection","anomaly")
],
"Text Generation": [
const Option("Text generation", "text"),
],
"Image Generation": [
const Option("Text to Image", "text-to-image")
],
"Audio": [
const Option("Speech to text", "speech")
]
};

return options;
}

class _ModelsPageState extends State<ModelsPage> {
@override
Widget build(BuildContext context) {
final theme = FluentTheme.of(context);
Expand Down Expand Up @@ -50,9 +67,9 @@
),
Expanded(
child: GridContainer(
color: backgroundColor.of(theme),
color: backgroundColor.of(theme),

Check warning on line 70 in lib/pages/models/models.dart

View check run for this annotation

Codecov / codecov/patch

lib/pages/models/models.dart#L70

Added line #L70 was not covered by tests
padding: const EdgeInsets.all(13),
child: const ModelFilter()
child: ModelFilter(filterOptions: filterOptions)

Check warning on line 72 in lib/pages/models/models.dart

View check run for this annotation

Codecov / codecov/patch

lib/pages/models/models.dart#L72

Added line #L72 was not covered by tests
),
),
],
Expand Down
7 changes: 4 additions & 3 deletions lib/pages/models/widgets/model_filter.dart
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import 'package:provider/provider.dart';


class ModelFilter extends StatelessWidget {
const ModelFilter({super.key});
final Map<String, List<Option>> filterOptions;
const ModelFilter({super.key, required this.filterOptions});

@override
Widget build(BuildContext context) {
Expand All @@ -19,8 +20,8 @@ class ModelFilter extends StatelessWidget {
mainAxisAlignment: MainAxisAlignment.start,
crossAxisAlignment: CrossAxisAlignment.stretch,
children: [
...Option.filterOptions.keys.map((key) {
return Group(key, Option.filterOptions[key]!);
...filterOptions.keys.map((key) {
return Group(key, filterOptions[key]!);
}),
]
),
Expand Down
Loading
Loading