Skip to content

Custom learning tasks tutorial gives error #285

@usiam

Description

@usiam
using Pkg;
Pkg.activate(".")
using FastAI, FastVision, Random, Images
import CairoMakie;
CairoMakie.activate!(type="png");

path = FastAI.load(datasets()["oxford-iiit-pet"])
im_path = joinpath(path, "images")
files = loadfolderdata(im_path; filterfn=FastVision.isimagefile)


function transform_image(image, sz=224)
    image_resized = imresize(convert.(RGB{N0f8}, image), (sz, sz))
    a = permuteddimsview(channelview(image_resized), (2, 3, 1))
end

p = getobs(files, 1)
image = loadfile(p)

label_func(path) = match(r"^(.*)_\d+\.jpg$", pathname(path))[1]
label_func(p)

labels = map(label_func, files)
length(unique(labels))


data = mapobs(files) do file
    return (loadfile(file), label_func(file))
end


idxs = shuffle(1:length(files))
cut = round(Int, 0.8 * length(idxs))
trainidxs, valididxs = idxs[1:cut], idxs[cut+1:end]
trainfiles, validfiles = files[trainidxs], files[valididxs]
summary.((trainfiles, validfiles))



import FastAI.MLUtils


struct SiamesePairs
    labels
    same
    other
    valid
end

function SiamesePairs(labels; valid=false)
    ulabels = unique(labels)
    same = Dict(
        label => [i for (i, l) in enumerate(labels) if l == label]
        for label in ulabels)
    other = Dict(
        label => [i for (i, l) in enumerate(labels) if l != label]
        for label in ulabels)

    return SiamesePairs(labels, same, other, valid)
end

function MLUtils.getobs(si::SiamesePairs, idx::Int)
    rng = si.valid ? MersenneTwister(idx) : Random.GLOBAL_RNG
    if rand(rng) > 0.5
        return ((idx, rand(rng, si.same[si.labels[idx]])), true)
    else
        return ((idx, rand(rng, si.other[si.labels[idx]])), false)
    end
end

MLUtils.numobs(si::SiamesePairs) = length(si.labels)

function siamesedata(files; valid = false, transformfn = identity)
    labels = map(label_func, files)
    si = SiamesePairs(labels; valid = valid)
    return mapobs(si) do obs
        (i, j), same = obs
        image1 = transformfn(loadfile(getobs(files, i)))
        image2 = transformfn(loadfile(getobs(files, j)))
        return ((image1, image2), same)
    end
end

traindata = siamesedata(trainfiles; transformfn=transform_image)
validdata = siamesedata(validfiles; transformfn=transform_image, valid=true);

traindl = FastAI.MLUtils.DataLoader(traindata, 16)

ERROR: MethodError: no method matching MLUtils.DataLoader(::MLUtils.MappedData{:auto, var"#75#76"{typeof(transform_image), ObsView{MLDatasets.FileDataset{typeof(identity), String}, Vector{Int64}}}, SiamesePairs}, ::Int64)

I was trying to recreate the Siamese example in the docs and could not figure out why I am getting this error? And how do I fix this?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions