Skip to content

Supplier Mixin #9761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions src/backend/InvenTree/InvenTree/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,3 +529,52 @@ def validate_remote_image(self, url):
raise ValidationError(_('Failed to download image from remote URL'))

return url


class ListUniqueValidator:
"""List validator that validates unique fields for bulk create.

See: https://github.com/encode/django-rest-framework/issues/6395#issuecomment-452412653
"""

message = 'This field must be unique.'

def __init__(self, unique_field_names):
"""Initialize the validator with a list of unique field names."""
self.unique_field_names = unique_field_names

@staticmethod
def has_duplicates(counter):
"""Check if there are any duplicate values in the counter."""
return any(count for count in counter.values() if count > 1)

def __call__(self, value):
"""Validate that the specified fields are unique across the list of items."""
from collections import Counter

field_counters = {
field_name: Counter(
item[field_name] for item in value if field_name in item
)
for field_name in self.unique_field_names
}
has_duplicates = any(
ListUniqueValidator.has_duplicates(counter)
for counter in field_counters.values()
)
if has_duplicates:
errors = []
for item in value:
error = {}
for field_name in self.unique_field_names:
counter = field_counters[field_name]
if counter[item.get(field_name)] > 1:
error[field_name] = self.message
errors.append(error)
raise ValidationError(errors)

def __repr__(self):
"""Return a string representation of the validator."""
return (
f'<{self.__class__.__name__}(unique_field_names={self.unique_field_names})>'
)
23 changes: 23 additions & 0 deletions src/backend/InvenTree/part/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

import InvenTree.permissions
Expand Down Expand Up @@ -1702,6 +1703,23 @@ class PartParameterList(PartParameterAPIMixin, DataExportViewMixin, ListCreateAP
]


class PartParameterBulkCreate(CreateAPIView):
"""Bulk create part parameters.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wolflu05 it might be worth looking into a BulkCreateMixin class here, to make this a generic approach.

I recently added in a BulkUpdateMixin - #9313 - which employs a very similar approach. The intent here is to:

  1. Reduce API calls for better throughput and atomicity
  2. Utilize the already defined serializer classes for back-end validation

So, thoughts? A BulkCreateMixin would complement the BulkUpdateMixin and BulkDeleteMixin classess nicely!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I'll see what I can do and if I need any pointers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you submit that as a separate PR first? I'd like to be able to review that separately


- POST: Bulk create part parameters
"""

serializer_class = part_serializers.PartParameterBulkSerializer
queryset = PartParameter.objects.all()

def get_serializer(self, *args, **kwargs):
"""Return the serializer instance for this API endpoint."""
if isinstance(kwargs.get('data', {}), list):
kwargs['many'] = True

return super().get_serializer(*args, **kwargs)


class PartParameterDetail(PartParameterAPIMixin, RetrieveUpdateDestroyAPI):
"""API endpoint for detail view of a single PartParameter object."""

Expand Down Expand Up @@ -2184,6 +2202,11 @@ class BomItemSubstituteDetail(RetrieveUpdateDestroyAPI):
),
]),
),
path(
'bulk/',
PartParameterBulkCreate.as_view(),
name='api-part-parameter-bulk-create',
),
path('', PartParameterList.as_view(), name='api-part-parameter-list'),
]),
),
Expand Down
23 changes: 23 additions & 0 deletions src/backend/InvenTree/part/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,29 @@ def __init__(self, *args, **kwargs):
)


class TemplateUniquenessListSerializer(serializers.ListSerializer):
"""List serializer that validates unique fields for bulk create."""

validators = [
InvenTree.serializers.ListUniqueValidator(unique_field_names=['template'])
]


class PartParameterBulkSerializer(InvenTree.serializers.InvenTreeModelSerializer):
"""JSON serializers for the PartParameter model."""

class Meta:
"""Metaclass defining serializer fields."""

model = PartParameter
fields = ['pk', 'part', 'template', 'data', 'data_numeric']
list_serializer_class = TemplateUniquenessListSerializer

def __init__(self, *args, **kwargs):
"""Custom initialization method for the serializer."""
serializers.ModelSerializer.__init__(self, *args, **kwargs)


class DuplicatePartSerializer(serializers.Serializer):
"""Serializer for specifying options when duplicating a Part.

Expand Down
2 changes: 2 additions & 0 deletions src/backend/InvenTree/plugin/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from plugin.base.action.api import ActionPluginView
from plugin.base.barcodes.api import barcode_api_urls
from plugin.base.locate.api import LocatePluginView
from plugin.base.supplier.api import supplier_api_urls
from plugin.base.ui.api import ui_plugins_api_urls
from plugin.models import PluginConfig, PluginSetting
from plugin.plugin import InvenTreePlugin
Expand Down Expand Up @@ -525,4 +526,5 @@ class PluginMetadataView(MetadataView):
path('', PluginList.as_view(), name='api-plugin-list'),
]),
),
path('supplier/', include(supplier_api_urls)),
]
Empty file.
185 changes: 185 additions & 0 deletions src/backend/InvenTree/plugin/base/supplier/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
"""API views for supplier plugins in InvenTree."""

from django.db import transaction
from django.urls import path

from rest_framework import status
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
from rest_framework.views import APIView

from InvenTree import permissions
from part.models import PartCategoryParameterTemplate
from plugin import registry
from plugin.plugin import PluginMixinEnum

from .serializers import (
ImportRequestSerializer,
ImportResultSerializer,
SearchResultSerializer,
)

# from .supplier import ImportParameter, PartNotFoundError


class SearchPart(APIView):
"""Search parts by supplier.

- GET: Start part search
"""

role_required = 'part.add'
permission_classes = [
permissions.IsAuthenticatedOrReadScope,
permissions.RolePermission,
]

def get(self, request):
"""Search parts by supplier."""
supplier_slug = request.query_params.get('supplier', '')

supplier = None
for plugin in registry.with_mixin(PluginMixinEnum.SUPPLIER):
if plugin.slug == supplier_slug:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the "supplier" supposed to be the name of the supplier (e.g. digikey) or the slug for the plugin?

supplier = plugin
break

if not supplier:
raise NotFound(detail=f"Supplier '{supplier_slug}' not found")

term = request.query_params.get('term', '')
try:
results = supplier.get_search_results(term)
except Exception as e:
return Response(
{'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

response = SearchResultSerializer(results, many=True).data
return Response(response)


class ImportPart(APIView):
"""Import a part by supplier.

- POST: Attempt to import part by sku
"""

role_required = 'part.add'
permission_classes = [
permissions.IsAuthenticatedOrReadScope,
permissions.RolePermission,
]

def post(self, request):
"""Import a part by supplier."""
serializer = ImportRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

# Extract validated data
supplier_slug = serializer.validated_data.get('supplier', '')
part_import_id = serializer.validated_data.get('part_import_id', None)
category = serializer.validated_data.get('category_id', None)
part = serializer.validated_data.get('part_id', None)

# Find the supplier plugin
supplier = None
for plugin in registry.with_mixin(PluginMixinEnum.SUPPLIER):
if plugin.slug == supplier_slug:
supplier = plugin
break

# Validate supplier and part/category
if not supplier:
raise NotFound(detail=f"Supplier '{supplier_slug}' not found")
if not part and not category:
return Response(
{
'detail': "'category_id' is not provided, but required if no part_id is provided"
},
status=status.HTTP_400_BAD_REQUEST,
)

from plugin.base.supplier.mixins import SupplierMixin

# Import part data
try:
import_data = supplier.get_import_data(part_import_id)

with transaction.atomic():
# create part if it does not exist
if not part:
part = supplier.import_part(
import_data, category=category, creation_user=request.user
)

# create manufacturer part
manufacturer_part = supplier.import_manufacturer_part(
import_data, part=part
)

# create supplier part
supplier_part = supplier.import_supplier_part(
import_data, part=part, manufacturer_part=manufacturer_part
)

# set default supplier if not set
if not part.default_supplier:
part.default_supplier = supplier_part
part.save()

# get pricing
pricing = supplier.get_pricing_data(import_data)

# get parameters
parameters = supplier.get_parameters(import_data)
except SupplierMixin.PartNotFoundError:
return Response(
{'detail': f"Part with id: '{part_import_id}' not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Exception as e:
return Response(
{'detail': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)

# add default parameters for category
if category:
categories = category.get_ancestors(include_self=True)
category_parameters = PartCategoryParameterTemplate.objects.filter(
category__in=categories
)

for c in category_parameters:
for p in parameters:
if p.parameter_template == c.parameter_template:
p.on_category = True
p.value = p.value if p.value is not None else c.default_value
break
else:
parameters.append(
SupplierMixin.ImportParameter(
name=c.parameter_template.name,
value=c.default_value,
on_category=True,
parameter_template=c.parameter_template,
)
)
parameters.sort(key=lambda x: x.on_category, reverse=True)

response = ImportResultSerializer({
'part_id': part.pk,
'part_detail': part,
'supplier_part_id': supplier_part.pk,
'manufacturer_part_id': manufacturer_part.pk,
'pricing': pricing,
'parameters': parameters,
}).data
return Response(response)


supplier_api_urls = [
path('search/', SearchPart.as_view(), name='api-supplier-search'),
path('import/', ImportPart.as_view(), name='api-supplier-import'),
]
Loading
Loading