Skip to content

feat: Add slug to collections and use it in public collection URLs #2301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 112 additions & 17 deletions backend/btrixcloud/colls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import asyncio
import pymongo
from pymongo.collation import Collation
from fastapi import Depends, HTTPException, Response
from fastapi.responses import StreamingResponse
from starlette.requests import Request
Expand Down Expand Up @@ -51,7 +52,7 @@
MIN_UPLOAD_PART_SIZE,
PublicCollOut,
)
from .utils import dt_now
from .utils import dt_now, slug_from_name, get_duplicate_key_error_field

if TYPE_CHECKING:
from .orgs import OrgOps
Expand Down Expand Up @@ -98,8 +99,17 @@ def set_page_ops(self, ops):

async def init_index(self):
"""init lookup index"""
case_insensitive_collation = Collation(locale="en", strength=1)
await self.collections.create_index(
[("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], unique=True
[("oid", pymongo.ASCENDING), ("name", pymongo.ASCENDING)],
unique=True,
collation=case_insensitive_collation,
)

await self.collections.create_index(
[("oid", pymongo.ASCENDING), ("slug", pymongo.ASCENDING)],
unique=True,
collation=case_insensitive_collation,
)

await self.collections.create_index(
Expand All @@ -111,23 +121,27 @@ async def add_collection(self, oid: UUID, coll_in: CollIn):
crawl_ids = coll_in.crawlIds if coll_in.crawlIds else []
coll_id = uuid4()
created = dt_now()
modified = dt_now()

slug = coll_in.slug or slug_from_name(coll_in.name)

coll = Collection(
id=coll_id,
oid=oid,
name=coll_in.name,
slug=slug,
description=coll_in.description,
caption=coll_in.caption,
created=created,
modified=modified,
modified=created,
access=coll_in.access,
defaultThumbnailName=coll_in.defaultThumbnailName,
allowPublicDownload=coll_in.allowPublicDownload,
)
try:
await self.collections.insert_one(coll.to_dict())
org = await self.orgs.get_org_by_id(oid)
await self.clear_org_previous_slugs_matching_slug(slug, org)

if crawl_ids:
await self.crawl_ops.add_to_collection(crawl_ids, coll_id, org)
await self.update_collection_counts_and_tags(coll_id)
Expand All @@ -139,9 +153,10 @@ async def add_collection(self, oid: UUID, coll_in: CollIn):
)

return {"added": True, "id": coll_id, "name": coll.name}
except pymongo.errors.DuplicateKeyError:
except pymongo.errors.DuplicateKeyError as err:
# pylint: disable=raise-missing-from
raise HTTPException(status_code=400, detail="collection_name_taken")
field = get_duplicate_key_error_field(err)
raise HTTPException(status_code=400, detail=f"collection_{field}_taken")

async def update_collection(
self, coll_id: UUID, org: Organization, update: UpdateColl
Expand All @@ -152,23 +167,55 @@ async def update_collection(
if len(query) == 0:
raise HTTPException(status_code=400, detail="no_update_data")

name_update = query.get("name")
slug_update = query.get("slug")

previous_slug = None

if name_update or slug_update:
# If we're updating slug, save old one to previousSlugs to support redirects
coll = await self.get_collection(coll_id)
previous_slug = coll.slug

if name_update and not slug_update:
slug = slug_from_name(name_update)
query["slug"] = slug
slug_update = slug

query["modified"] = dt_now()

db_update = {"$set": query}
if previous_slug:
db_update["$push"] = {"previousSlugs": previous_slug}

try:
result = await self.collections.find_one_and_update(
{"_id": coll_id, "oid": org.id},
{"$set": query},
db_update,
return_document=pymongo.ReturnDocument.AFTER,
)
except pymongo.errors.DuplicateKeyError:
except pymongo.errors.DuplicateKeyError as err:
# pylint: disable=raise-missing-from
raise HTTPException(status_code=400, detail="collection_name_taken")
field = get_duplicate_key_error_field(err)
raise HTTPException(status_code=400, detail=f"collection_{field}_taken")

if not result:
raise HTTPException(status_code=404, detail="collection_not_found")

if slug_update:
await self.clear_org_previous_slugs_matching_slug(slug_update, org)

return {"updated": True}

async def clear_org_previous_slugs_matching_slug(
    self, slug: str, org: Organization
):
    """Remove `slug` from the previousSlugs arrays of collections in `org`.

    Runs one update_many on self.collections: every document whose `oid`
    equals org.id and whose `previousSlugs` array contains `slug` has that
    value pulled out of the array.
    """
    holds_stale_slug = {"oid": org.id, "previousSlugs": slug}
    drop_stale_slug = {"$pull": {"previousSlugs": slug}}
    await self.collections.update_many(holds_stale_slug, drop_stale_slug)

async def add_crawls_to_collection(
self, coll_id: UUID, crawl_ids: List[str], org: Organization
) -> CollOut:
Expand Down Expand Up @@ -234,13 +281,54 @@ async def get_collection_raw(

return result

async def get_collection_raw_by_slug(
    self,
    coll_slug: str,
    previous_slugs: bool = False,
    public_or_unlisted_only: bool = False,
) -> Dict[str, Any]:
    """Find a single collection document by slug and return it as a dict.

    With previous_slugs=False the `slug` field is matched; with
    previous_slugs=True the `previousSlugs` array is matched instead.
    When public_or_unlisted_only is set, the match is further limited to
    documents whose `access` is "public" or "unlisted".

    Raises HTTPException (404, "collection_not_found") if nothing matches.
    """
    slug_field = "previousSlugs" if previous_slugs else "slug"
    query: dict[str, object] = {slug_field: coll_slug}
    if public_or_unlisted_only:
        query["access"] = {"$in": ["public", "unlisted"]}

    # NOTE(review): the filter carries no `oid` term, so the match is not
    # limited to one org -- confirm callers verify the owning org themselves.
    found = await self.collections.find_one(query)
    if found:
        return found

    raise HTTPException(status_code=404, detail="collection_not_found")

async def get_collection(
    self, coll_id: UUID, public_or_unlisted_only: bool = False
) -> Collection:
    """Load a collection by id and return it as a Collection model.

    The lookup is delegated to get_collection_raw (public_or_unlisted_only
    is forwarded unchanged); the resulting dict is converted with
    Collection.from_dict.
    """
    return Collection.from_dict(
        await self.get_collection_raw(coll_id, public_or_unlisted_only)
    )

async def get_collection_by_slug(
    self, coll_slug: str, public_or_unlisted_only: bool = False
) -> Collection:
    """Get collection by slug, falling back to previously used slugs.

    The current `slug` field is tried first. Only when that lookup reports
    "not found" (HTTPException with status 404) is the `previousSlugs`
    array searched, which lets old URLs keep resolving after a rename.

    public_or_unlisted_only is forwarded to both lookups.

    Raises HTTPException (404, "collection_not_found") from the fallback
    lookup when neither the current nor a previous slug matches. Any other
    failure (database error, model validation error) propagates instead of
    being masked as a missing collection.
    """
    try:
        result = await self.get_collection_raw_by_slug(
            coll_slug, public_or_unlisted_only=public_or_unlisted_only
        )
    except HTTPException as err:
        # Only a genuine "not found" should trigger the fallback; anything
        # else is a real error and must not be swallowed.
        if err.status_code != 404:
            raise
        result = await self.get_collection_raw_by_slug(
            coll_slug,
            previous_slugs=True,
            public_or_unlisted_only=public_or_unlisted_only,
        )

    return Collection.from_dict(result)

async def get_collection_out(
self,
coll_id: UUID,
Expand All @@ -264,7 +352,10 @@ async def get_collection_out(
return CollOut.from_dict(result)

async def get_public_collection_out(
self, coll_id: UUID, org: Organization, allow_unlisted: bool = False
self,
coll_id: UUID,
org: Organization,
allow_unlisted: bool = False,
) -> PublicCollOut:
"""Get PublicCollOut by id"""
result = await self.get_collection_raw(coll_id)
Expand Down Expand Up @@ -1012,13 +1103,13 @@ async def get_org_public_collections(
)

@app.get(
"/public/orgs/{org_slug}/collections/{coll_id}",
"/public/orgs/{org_slug}/collections/{coll_slug}",
tags=["collections", "public"],
response_model=PublicCollOut,
)
async def get_public_collection(
org_slug: str,
coll_id: UUID,
coll_slug: str,
):
try:
org = await colls.orgs.get_org_by_slug(org_slug)
Expand All @@ -1027,16 +1118,18 @@ async def get_public_collection(
# pylint: disable=raise-missing-from
raise HTTPException(status_code=404, detail="collection_not_found")

return await colls.get_public_collection_out(coll_id, org, allow_unlisted=True)
coll = await colls.get_collection_by_slug(coll_slug)

return await colls.get_public_collection_out(coll.id, org, allow_unlisted=True)

@app.get(
"/public/orgs/{org_slug}/collections/{coll_id}/download",
"/public/orgs/{org_slug}/collections/{coll_slug}/download",
tags=["collections", "public"],
response_model=bytes,
)
async def download_public_collection(
org_slug: str,
coll_id: UUID,
coll_slug: str,
):
try:
org = await colls.orgs.get_org_by_slug(org_slug)
Expand All @@ -1046,12 +1139,14 @@ async def download_public_collection(
raise HTTPException(status_code=404, detail="collection_not_found")

# Make sure collection exists and is public/unlisted
coll = await colls.get_collection(coll_id, public_or_unlisted_only=True)
coll = await colls.get_collection_by_slug(
coll_slug, public_or_unlisted_only=True
)

if coll.allowPublicDownload is False:
raise HTTPException(status_code=403, detail="not_allowed")

return await colls.download_collection(coll_id, org)
return await colls.download_collection(coll.id, org)

@app.get(
"/orgs/{oid}/collections/{coll_id}/urls",
Expand Down
2 changes: 1 addition & 1 deletion backend/btrixcloud/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .migrations import BaseMigration


CURR_DB_VERSION = "0038"
CURR_DB_VERSION = "0039"


# ============================================================================
Expand Down
38 changes: 38 additions & 0 deletions backend/btrixcloud/migrations/migration_0039_coll_slugs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Migration 0039 -- collection slugs
"""

from btrixcloud.migrations import BaseMigration
from btrixcloud.utils import slug_from_name


MIGRATION_VERSION = "0039"


class Migration(BaseMigration):
    """Migration 0039: backfill the `slug` field on collections."""

    # pylint: disable=unused-argument
    def __init__(self, mdb, **kwargs):
        super().__init__(mdb, migration_version=MIGRATION_VERSION)

    async def migrate_up(self):
        """Perform migration up.

        For each document in the "collections" collection matched by the
        filter {"slug": None}, set `slug` to slug_from_name() of the
        document's name (an empty string is used when `name` is absent).

        A failure on one collection is printed and the loop moves on to
        the next document, so one bad record does not abort the migration.
        """
        collections = self.mdb["collections"]
        without_slug = {"slug": None}

        async for doc in collections.find(without_slug):
            coll_id = doc["_id"]
            try:
                new_slug = slug_from_name(doc.get("name", ""))
                await collections.find_one_and_update(
                    {"_id": coll_id}, {"$set": {"slug": new_slug}}
                )
            # pylint: disable=broad-exception-caught
            except Exception as err:
                message = f"Error saving slug for collection {coll_id}: {err}"
                print(message, flush=True)
10 changes: 10 additions & 0 deletions backend/btrixcloud/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,9 @@ class CollAccessType(str, Enum):
class Collection(BaseMongoModel):
"""Org collection structure"""

id: UUID
name: str = Field(..., min_length=1)
slug: str = Field(..., min_length=1)
oid: UUID
description: Optional[str] = None
caption: Optional[str] = None
Expand Down Expand Up @@ -1264,12 +1266,15 @@ class Collection(BaseMongoModel):

allowPublicDownload: Optional[bool] = True

previousSlugs: List[str] = []


# ============================================================================
class CollIn(BaseModel):
"""Collection Passed in By User"""

name: str = Field(..., min_length=1)
slug: Optional[str] = None
description: Optional[str] = None
caption: Optional[str] = None
crawlIds: Optional[List[str]] = []
Expand All @@ -1284,7 +1289,9 @@ class CollIn(BaseModel):
class CollOut(BaseMongoModel):
"""Collection output model with annotations."""

id: UUID
name: str
slug: str
oid: UUID
description: Optional[str] = None
caption: Optional[str] = None
Expand Down Expand Up @@ -1318,7 +1325,9 @@ class CollOut(BaseMongoModel):
class PublicCollOut(BaseMongoModel):
"""Collection output model with annotations."""

id: UUID
name: str
slug: str
oid: UUID
description: Optional[str] = None
caption: Optional[str] = None
Expand Down Expand Up @@ -1349,6 +1358,7 @@ class UpdateColl(BaseModel):
"""Update collection"""

name: Optional[str] = None
slug: Optional[str] = None
description: Optional[str] = None
caption: Optional[str] = None
access: Optional[CollAccessType] = None
Expand Down
2 changes: 2 additions & 0 deletions backend/btrixcloud/orgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,8 @@ async def import_org(
# collections
for collection in org_data.get("collections", []):
collection = json_stream.to_standard_types(collection)
if not collection.get("slug"):
collection["slug"] = slug_from_name(collection["name"])
await self.colls_db.insert_one(Collection.from_dict(collection).to_dict())

async def delete_org_and_data(
Expand Down
13 changes: 7 additions & 6 deletions backend/btrixcloud/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,16 @@ def stream_dict_list_as_csv(

def get_duplicate_key_error_field(err: DuplicateKeyError) -> str:
"""Get name of duplicate field from pymongo DuplicateKeyError"""
dupe_field = "name"
allowed_fields = ("name", "slug", "subscription.subId")

if err.details:
key_value = err.details.get("keyValue")
if key_value:
try:
dupe_field = list(key_value.keys())[0]
except IndexError:
pass
return dupe_field
for field in key_value.keys():
if field in allowed_fields:
return field

return "name"


def get_origin(headers) -> str:
Expand Down
Loading
Loading