Skip to content

Commit 37e234e

Browse files
ardaatahan and pcuenca authored
Add get_hf_file_metadata Functionality (#142)
* add getHfFileMetadata function to HubApi
* only allow huggingface endpoints in getHfFileMetadata
* add test case for getHfFileMetadata
* remove hardcoded string from location check in test case
* rename getHfFileMetadata to getFileMetadata and refactor
* add blob search for file metadata
* Update Tests/HubTests/HubApiTests.swift

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
1 parent 2f611bf commit 37e234e

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

Sources/Hub/HubApi.swift

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ public extension HubApi {
6060
return (data, response)
6161
}
6262

63+
/// Sends a `HEAD` request to `url` and returns the session data together with the HTTP response.
///
/// A bearer `Authorization` header is attached when `hfToken` is set, and
/// `Accept-Encoding: identity` is always sent.
///
/// - Parameter url: The URL to query.
/// - Returns: The data produced by the session and the response cast to `HTTPURLResponse`.
/// - Throws: `Hub.HubClientError.unexpectedError` when the response is not an `HTTPURLResponse`,
///   `Hub.HubClientError.authorizationRequired` for any 4xx status, and
///   `Hub.HubClientError.httpStatusCode` for every other status outside 200..<300.
func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
    var headRequest = URLRequest(url: url)
    headRequest.httpMethod = "HEAD"
    // NOTE(review): presumably requested so size headers describe the uncompressed payload — confirm.
    headRequest.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
    if let token = hfToken {
        headRequest.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
    }

    let (body, rawResponse) = try await URLSession.shared.data(for: headRequest)
    guard let httpResponse = rawResponse as? HTTPURLResponse else {
        throw Hub.HubClientError.unexpectedError
    }

    let status = httpResponse.statusCode
    if (400..<500).contains(status) {
        throw Hub.HubClientError.authorizationRequired
    }
    guard (200..<300).contains(status) else {
        throw Hub.HubClientError.httpStatusCode(status)
    }

    return (body, httpResponse)
}
81+
6382
func getFilenames(from repo: Repo, matching globs: [String] = []) async throws -> [String] {
6483
// Read repo info and only parse "siblings"
6584
let url = URL(string: "\(endpoint)/api/\(repo.type)/\(repo.id)")!
@@ -222,6 +241,65 @@ public extension HubApi {
222241
}
223242
}
224243

244+
/// Metadata
245+
public extension HubApi {
    /// Metadata describing a remote file, assembled from the headers of a `HEAD` response.
    struct FileMetadata {
        /// Value of the `X-Repo-Commit` response header, if the server sent one.
        public let commitHash: String?

        /// ETag with any `W/` prefix and surrounding double quotes removed, if the server sent one.
        public let etag: String?

        /// Value of the `Location` response header, or the requested URL string when that header is absent.
        public let location: String

        /// Size in bytes, read from `X-Linked-Size` when present and otherwise from `Content-Length`;
        /// `nil` when the selected header is missing or is not an integer.
        public let size: Int?
    }

    /// Strips a leading `W/` and any leading/trailing double quotes from an ETag header value.
    ///
    /// - Parameter etag: Raw header value, or `nil` when the header was absent.
    /// - Returns: The bare tag, or `nil` when `etag` is `nil`.
    private func normalizeEtag(_ etag: String?) -> String? {
        guard let etag else { return nil }
        return etag.trimmingPrefix("W/").trimmingCharacters(in: CharacterSet(charactersIn: "\""))
    }

    /// Issues a `HEAD` request for `url` and builds a `FileMetadata` from the response headers.
    ///
    /// - Parameter url: Location of the remote file.
    /// - Returns: Metadata populated from `X-Repo-Commit`, `X-Linked-Etag`/`Etag`, `Location`
    ///   and `X-Linked-Size`/`Content-Length`.
    /// - Throws: Whatever `httpHead(for:)` throws.
    func getFileMetadata(url: URL) async throws -> FileMetadata {
        let (_, response) = try await httpHead(for: url)

        // The `X-Linked-*` headers take precedence over the generic ones when both are present.
        let rawEtag = response.value(forHTTPHeaderField: "X-Linked-Etag")
            ?? response.value(forHTTPHeaderField: "Etag")
        let rawSize = response.value(forHTTPHeaderField: "X-Linked-Size")
            ?? response.value(forHTTPHeaderField: "Content-Length")
            ?? ""

        return FileMetadata(
            commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"),
            etag: normalizeEtag(rawEtag),
            location: response.value(forHTTPHeaderField: "Location") ?? url.absoluteString,
            size: Int(rawSize)
        )
    }

    /// Fetches metadata for every file in `repo` whose name matches `globs`.
    ///
    /// Files are queried one at a time, in the order returned by `getFilenames(from:matching:)`.
    ///
    /// - Parameters:
    ///   - repo: Repository to inspect.
    ///   - globs: Glob patterns forwarded to `getFilenames(from:matching:)`.
    /// - Returns: One `FileMetadata` per matching file.
    /// - Throws: `Hub.HubClientError.unexpectedError` when the resolve URL cannot be formed,
    ///   plus anything thrown while listing files or issuing the `HEAD` requests.
    func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [FileMetadata] {
        let files = try await getFilenames(from: repo, matching: globs)

        // Model repositories resolve directly under the endpoint; other repository types are
        // namespaced by their type (e.g. `datasets/<id>`). The type is stringified the same way
        // the API URL in `getFilenames` does it.
        let repoType = "\(repo.type)"
        let repoPath = repoType == "models" ? repo.id : "\(repoType)/\(repo.id)"
        guard let baseURL = URL(string: "\(endpoint)/\(repoPath)/resolve/main") else { // TODO: revisions
            throw Hub.HubClientError.unexpectedError
        }

        var selectedMetadata: [FileMetadata] = []
        selectedMetadata.reserveCapacity(files.count)
        for file in files {
            let fileURL = baseURL.appending(path: file)
            selectedMetadata.append(try await getFileMetadata(url: fileURL))
        }
        return selectedMetadata
    }

    /// Convenience overload that builds the `Repo` from `repoId`.
    func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [FileMetadata] {
        try await getFileMetadata(from: Repo(id: repoId), matching: globs)
    }

    /// Convenience overload that accepts a single glob pattern.
    func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [FileMetadata] {
        try await getFileMetadata(from: repo, matching: [glob])
    }

    /// Convenience overload that builds the `Repo` from `repoId` and accepts a single glob pattern.
    func getFileMetadata(from repoId: String, matching glob: String) async throws -> [FileMetadata] {
        try await getFileMetadata(from: Repo(id: repoId), matching: [glob])
    }
}
302+
225303
/// Stateless wrappers that use `HubApi` instances
226304
public extension Hub {
227305
static func getFilenames(from repo: Hub.Repo, matching globs: [String] = []) async throws -> [String] {
@@ -259,6 +337,26 @@ public extension Hub {
259337
/// Runs `whoami()` on a `HubApi` created with `token` as its `hfToken`.
///
/// - Parameter token: Token passed to `HubApi(hfToken:)`.
/// - Returns: The `Config` produced by `HubApi.whoami()`.
static func whoami(token: String) async throws -> Config {
    let api = HubApi(hfToken: token)
    return try await api.whoami()
}
340+
341+
/// Fetches metadata for the file at `fileURL` through the shared `HubApi` instance.
///
/// - Parameter fileURL: Location of the remote file.
/// - Returns: The metadata reported by `HubApi.getFileMetadata(url:)`.
static func getFileMetadata(fileURL: URL) async throws -> HubApi.FileMetadata {
    let api = HubApi.shared
    return try await api.getFileMetadata(url: fileURL)
}
344+
345+
/// Fetches metadata for the files in `repo` matching `globs` through the shared `HubApi` instance.
///
/// - Parameters:
///   - repo: Repository to inspect.
///   - globs: Glob patterns forwarded unchanged to the `HubApi` method.
/// - Returns: The metadata list produced by `HubApi.getFileMetadata(from:matching:)`.
static func getFileMetadata(from repo: Repo, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
    let api = HubApi.shared
    return try await api.getFileMetadata(from: repo, matching: globs)
}
348+
349+
/// Fetches metadata for the files matching `globs` in the repository identified by `repoId`,
/// using the shared `HubApi` instance.
///
/// - Parameters:
///   - repoId: Identifier used to construct the `Repo`.
///   - globs: Glob patterns forwarded unchanged to the `HubApi` method.
/// - Returns: The metadata list produced by `HubApi.getFileMetadata(from:matching:)`.
static func getFileMetadata(from repoId: String, matching globs: [String] = []) async throws -> [HubApi.FileMetadata] {
    let repo = Repo(id: repoId)
    return try await HubApi.shared.getFileMetadata(from: repo, matching: globs)
}
352+
353+
/// Single-pattern variant: fetches metadata for the files in `repo` matching `glob`
/// through the shared `HubApi` instance.
///
/// - Parameters:
///   - repo: Repository to inspect.
///   - glob: One glob pattern, wrapped in an array before being forwarded.
/// - Returns: The metadata list produced by `HubApi.getFileMetadata(from:matching:)`.
static func getFileMetadata(from repo: Repo, matching glob: String) async throws -> [HubApi.FileMetadata] {
    let patterns = [glob]
    return try await HubApi.shared.getFileMetadata(from: repo, matching: patterns)
}
356+
357+
/// Single-pattern variant keyed by repository identifier: fetches metadata for the files
/// matching `glob` through the shared `HubApi` instance.
///
/// - Parameters:
///   - repoId: Identifier used to construct the `Repo`.
///   - glob: One glob pattern, wrapped in an array before being forwarded.
/// - Returns: The metadata list produced by `HubApi.getFileMetadata(from:matching:)`.
static func getFileMetadata(from repoId: String, matching glob: String) async throws -> [HubApi.FileMetadata] {
    let repo = Repo(id: repoId)
    let patterns = [glob]
    return try await HubApi.shared.getFileMetadata(from: repo, matching: patterns)
}
262360
}
263361

264362
public extension [String] {

Tests/HubTests/HubApiTests.swift

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,63 @@ class HubApiTests: XCTestCase {
8787
XCTFail("\(error)")
8888
}
8989
}
90+
91+
/// Checks the metadata returned for a file addressed through a `resolve/main` URL.
func testGetFileMetadata() async throws {
    let fileURL = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!

    let fetched: HubApi.FileMetadata
    do {
        fetched = try await Hub.getFileMetadata(fileURL: fileURL)
    } catch {
        XCTFail("\(error)")
        return
    }

    XCTAssertNotNil(fetched.commitHash)
    XCTAssertNotNil(fetched.etag)
    XCTAssertEqual(fetched.location, fileURL.absoluteString)
    XCTAssertEqual(fetched.size, 163)
}
104+
105+
/// Checks the metadata returned for a file addressed through a `blob/main` URL.
func testGetFileMetadataBlobPath() async throws {
    let blobURL = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/blob/main/config.json")!

    let fetched: HubApi.FileMetadata
    do {
        fetched = try await Hub.getFileMetadata(fileURL: blobURL)
    } catch {
        XCTFail("\(error)")
        return
    }

    XCTAssertNil(fetched.commitHash)
    XCTAssertTrue(fetched.etag?.hasPrefix("10841-") ?? false)
    XCTAssertEqual(fetched.location, blobURL.absoluteString)
    XCTAssertEqual(fetched.size, 67649)
}
118+
119+
/// Checks the metadata returned for a file addressed through a URL pinned to a commit hash.
///
/// Errors propagate through `throws` so XCTest reports them as failures; optionals are
/// unwrapped with `XCTUnwrap` so a missing value fails the test instead of crashing the runner.
func testGetFileMetadataWithRevision() async throws {
    let revision = "f2c752cfc5c0ab6f4bdec59acea69eefbee381c2"
    let url = try XCTUnwrap(
        URL(string: "https://huggingface.co/julien-c/dummy-unknown/resolve/\(revision)/config.json")
    )
    let metadata = try await Hub.getFileMetadata(fileURL: url)

    XCTAssertEqual(metadata.commitHash, revision)
    XCTAssertEqual(metadata.location, url.absoluteString)
    XCTAssertEqual(metadata.size, 851)

    // Unwrapped last so the independent assertions above still run when the ETag is missing.
    let etag = try XCTUnwrap(metadata.etag)
    XCTAssertGreaterThan(etag.count, 0)
}
134+
135+
/// Checks that a glob-based metadata lookup yields one populated entry per matching file.
///
/// Both lists are sorted before pairing. Their sizes are compared first because `zip` stops at
/// the shorter sequence, which would otherwise let an empty or truncated result pass silently.
func testGetFileMetadataWithBlobSearch() async throws {
    let repo = "coreml-projects/Llama-2-7b-chat-coreml"
    let metadataFromBlob = try await Hub.getFileMetadata(from: repo, matching: "*.json").sorted { $0.location < $1.location }
    let files = try await Hub.getFilenames(from: repo, matching: "*.json").sorted()

    XCTAssertFalse(files.isEmpty)
    XCTAssertEqual(metadataFromBlob.count, files.count)

    for (metadata, file) in zip(metadataFromBlob, files) {
        XCTAssertNotNil(metadata.commitHash)
        XCTAssertTrue(metadata.location.contains(file))

        // XCTUnwrap fails the test on nil instead of trapping like a force unwrap would.
        let etag = try XCTUnwrap(metadata.etag)
        XCTAssertGreaterThan(etag.count, 0)
        let size = try XCTUnwrap(metadata.size)
        XCTAssertGreaterThan(size, 0)
    }
}
90147
}
91148

92149
class SnapshotDownloadTests: XCTestCase {

0 commit comments

Comments
 (0)