Fisher's flowers
In this tutorial, we will walk step by step through a more in-depth workflow for using SimSpread.jl, taking as our example the classic classification problem of R. A. Fisher's iris dataset.
Let's go ahead and load the dataset we will work with, that is, the flower classes and features:
using SimSpread
y = read_namedmatrix("data/iris.classes")   # one-hot encoded flower classes
X = read_namedmatrix("data/iris.features")  # flower measurements (features)
Classes are one-hot encoded due to how SimSpread works:
y[1:5, :]
5×3 Named Matrix{Float64}
Flower ╲ Class │      setosa  versicolor   virginica
───────────────┼────────────────────────────────────
001            │         1.0         0.0         0.0
002            │         1.0         0.0         0.0
003            │         1.0         0.0         0.0
004            │         1.0         0.0         0.0
005            │         1.0         0.0         0.0
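For reference, such a one-hot matrix can be built from a plain vector of class labels. Below is a minimal sketch (the `labels` vector and its values are hypothetical; the tutorial's data files already ship in this format):
using NamedArrays
# Minimal sketch: one-hot encode a vector of class labels into a NamedArray.
# `labels` is a hypothetical example; "data/iris.classes" already stores this format.
labels = ["setosa", "setosa", "versicolor", "virginica"]
classes = sort(unique(labels))
Y = NamedArray(zeros(length(labels), length(classes)),
               (string.(1:length(labels)), classes),
               ("Flower", "Class"))
for (i, label) in enumerate(labels)
    Y[i, label] = 1.0   # set the column matching the flower's class to 1
end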
And features can be of any type (e.g., continuous floats describing the plants):
X[1:5, :]
5×4 Named Matrix{Float64}
Flower ╲ Feature │  petallength   petalwidth  sepallength   sepalwidth
─────────────────┼────────────────────────────────────────────────────
001              │          1.4          0.2          5.1          3.5
002              │          1.4          0.2          4.9          3.0
003              │          1.3          0.2          4.7          3.2
004              │          1.5          0.2          4.6          3.1
005              │          1.4          0.2          5.0          3.6
Next, we will train a model using SimSpread to predict the classes of a subset of plants in the Iris dataset. For this, we will split our dataset into two groups: a training set, corresponding to 90% of the data, and a testing set, corresponding to the remaining 10%:
nflowers = size(y, 1)
train = rand(nflowers) .< 0.90   # ~90% of flowers drawn at random for training
test = .!train                   # the remaining ~10% for testing
ytrain = y[train, :]
ytest = y[test, :]
Meta-description preparation
As we previously mentioned, SimSpread uses an abstracted feature set where entities are described by their similarity to other entities. This adds the flexibility of freely choosing any type of feature and any similarity measurement to correctly describe the problem's entities.
To generate these meta-description features, the following steps are taken:
- A similarity matrix $S$ is obtained by calculating a similarity metric between all pairs of feature vectors of the entities in the studied dataset.
- From $S$ we can construct a similarity-based feature matrix $S^\prime$ by applying the similarity threshold $\alpha$ using the following equation: $S^\prime_{ij} = \begin{cases} w(i,j) & \text{if } S_{ij} \ge \alpha \\ 0 & \text{otherwise,} \end{cases}$ where $S$ corresponds to the entities' similarity matrix, $S^\prime$ to the final feature matrix, $i$ and $j$ to entities in the studied dataset, and $w(i,j)$ to the weighting scheme employed for feature matrix construction, which can be binary, $w(i,j) = S_{ij} > 0$, or continuous, $w(i,j) = (S_{ij} > 0) \times S_{ij}$.
This meta-description matrix encodes the question "Is plant $i$ similar to plant $j$?", which is later used by the resource-spreading algorithm for link prediction.
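Conceptually, this thresholding and weighting step amounts to the small sketch below (this is not SimSpread's internal implementation, and the function name `cutoff_sketch` is ours):
# Minimal sketch of the thresholding equation above: keep entries of S at or
# above α, either binarized (w = 1) or weighted by the similarity itself.
function cutoff_sketch(S::AbstractMatrix, α::Real; weighted::Bool=true)
    S′ = zero(S)
    for j in axes(S, 2), i in axes(S, 1)
        if S[i, j] ≥ α
            S′[i, j] = weighted ? S[i, j] : one(eltype(S))
        end
    end
    return S′
end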
Here, we will use the Jaccard index as our similarity measure, a measurement bounded between 0 and 1, with a cutoff of $\alpha = 0.9$ over $J(x,y)$, since this retains only comparisons between highly similar flowers:
using Distances, NamedArrays
# `pairwise` returns Jaccard distances between flowers (rows), so we subtract
# from 1 to obtain similarities
S = NamedArray(1 .- pairwise(Jaccard(), X, dims=1))
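Note that wrapping the distance matrix in a bare NamedArray drops the flower identifiers; if you want to index S by the same names as X, they can be re-attached (this step is optional and not required by the rest of the tutorial):
# Optional: carry the flower identifiers of X onto both dimensions of S,
# so the similarity matrix can be indexed by flower name.
setnames!(S, names(X, 1), 1)
setnames!(S, names(X, 1), 2)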
From this similarity matrix, we will prepare our meta-description for both training and testing sets:
α = 0.9
Xtrain = featurize(S[train, train], α, true)
Xtest = featurize(S[test, train], α, true)
- Training set meta-description matrix:
- Testing set meta-description matrix:
Predicting labels with SimSpread
Now that we have all the information SimSpread needs, we can construct the query graph used to predict links with the network-based inference resource-allocation algorithm.
First, we need to construct the query network for label prediction:
G = construct(ytrain, ytest, Xtrain, Xtest)
From this, we can predict the labels as follows:
ŷtrain = predict(G, ytrain)
ŷtest = predict(G, ytest)
Let's visualize the predictions obtained from our model:
- Training set:
- Testing set:
As we can see, the model predicts a score for each possible flower class. To evaluate the predictive performance as a multiclass problem, we will assign the label with the highest score as the predicted label.
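One way to do this is to take, for each flower, the column with the highest score. Below is a minimal sketch (the helper `predicted_labels` is ours, not part of SimSpread, and assumes the prediction matrix keeps the class names as column names):
# Minimal sketch: collapse a score matrix into hard labels by taking, for each
# row (flower), the class whose column holds the highest score.
function predicted_labels(ŷ)
    classes = names(ŷ, 2)
    scores = Array(ŷ)
    return [classes[argmax(scores[i, :])] for i in 1:size(scores, 1)]
end

predicted_labels(ŷtest)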
Assessing the predictive performance of the proposed model
First, let's visualize the predictions on a scatter plot to map where the incorrect predictions are:
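A sketch of such a plot with Plots.jl follows (the plotting backend is an assumption on our side; any backend works). Correctly and incorrectly classified test flowers are colored differently over two of the features, assuming the prediction matrix shares the class column order of `ytest`:
using Plots

classes  = names(ytest, 2)
true_lab = [classes[argmax(Array(ytest)[i, :])] for i in 1:size(ytest, 1)]
pred_lab = [classes[argmax(Array(ŷtest)[i, :])] for i in 1:size(ŷtest, 1)]
correct  = true_lab .== pred_lab

# Petal length vs. petal width for the test flowers;
# green = correctly classified, red = misclassified.
scatter(X[test, "petallength"], X[test, "petalwidth"];
        color=ifelse.(correct, :green, :red), label="",
        xlabel="petal length", ylabel="petal width")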