How to use a prefitted model

It is possible to use an already fitted model to make predictions on new data. This is especially helpful for large data sets, whose fitting procedure can be very time-consuming. The fitting outputs can be saved with the SaveToJason method. In this tutorial we use the saved output to make predictions on a new data set with the already computed SVM hyperplane.
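The idea behind saving fitted outputs to a file can be illustrated with the standard json module and NumPy. This is a minimal sketch, not ORSVM's implementation of SaveToJason: save_params and load_params are hypothetical helpers, and storing arrays as nested lists is an assumption about the file layout.

```python
import json
import os
import tempfile

import numpy as np


def _to_serializable(obj):
    # json cannot encode NumPy arrays or NumPy integer scalars directly, so convert them
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_params(params, path):
    """Hypothetical helper: write fitted parameters to a JSON file."""
    with open(path, "w") as f:
        json.dump(params, f, default=_to_serializable)


def load_params(path):
    """Hypothetical helper: read the parameters back, turning lists into arrays."""
    with open(path) as f:
        params = json.load(f)
    return {k: np.asarray(v) if isinstance(v, list) else v for k, v in params.items()}


# Toy parameters (made up for illustration, not a real fitted model)
fitted = {
    "kernel": "legendre",
    "order": 4,
    "C": None,
    "bias": 0.28,
    "support vectors": np.array([[-1.0, 0.5], [0.5, 1.0]]),
    "support vector labels": np.array([1.0, -1.0]),
}
path = os.path.join(tempfile.mkdtemp(), "output.json")
save_params(fitted, path)
restored = load_params(path)
print(restored["support vectors"].shape)
```

Only the expensive fitting step is skipped on reuse: everything needed to evaluate the decision function travels in the file.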

[9]:
import orsvm
import pandas as pd
import numpy as np

Load Dataset

Suppose the Monks data set is a new data set, different from the one already used to fit the model and write the JSON file. First, load this data set.

[3]:
# As for fitting, the data set must be prepared as a binary classification problem.
# load training set
df = pd.read_csv(r'D:\IPM\ORSVM\DataSets\DataSets\Classification\monks-problems\monks1_train.csv')

y_train=df['label'].to_numpy()         # convert y_train to numpy array
df.drop('label', axis=1, inplace=True) # drop the class label
X_train=df.to_numpy()                  # convert X_train to numpy array

# load test set
df = pd.read_csv(r'D:\IPM\ORSVM\DataSets\DataSets\Classification\monks-problems\monks1_test.csv')

y_test=df['label'].to_numpy()
df.drop('label', axis=1, inplace=True)
X_test=df.to_numpy()
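The same three steps are repeated for the training and test files, so they can be wrapped in one function. load_xy is a hypothetical helper, not part of ORSVM; like the cell above, it assumes the class is stored in a column named 'label'.

```python
import io

import numpy as np
import pandas as pd


def load_xy(source, label_col="label"):
    """Hypothetical helper: split a CSV into a feature matrix X and a label vector y.

    `source` can be a file path or any file-like object accepted by pandas.read_csv.
    """
    df = pd.read_csv(source)
    y = df[label_col].to_numpy()               # class labels
    X = df.drop(columns=label_col).to_numpy()  # every remaining column is a feature
    return X, y


# Tiny in-memory CSV standing in for the Monks files (made-up rows, for illustration only)
csv_text = """a1,a2,a3,label
1,1,2,1
3,2,1,-1
"""
X_demo, y_demo = load_xy(io.StringIO(csv_text))
print(X_demo.shape)
```

Using `drop(columns=...)` instead of `inplace=True` leaves the loaded DataFrame untouched, which avoids surprises if it is reused later.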

Load JSON File

Load the JSON file containing all model parameters, such as the support vectors, bias, weights and kernel settings. These parameters are used to build the hyperplane and therefore to make the prediction.
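In the standard dual formulation, an SVM classifies a point x by the sign of Σ α_i y_i K(x_i, x) + b, summed over the support vectors x_i. The sketch below assumes that the saved 'support multipliers', 'support vector labels' and 'bias' play the roles of α_i, y_i and b. That mapping, the sign convention for the bias, and the placeholder linear kernel are assumptions for illustration; ORSVM's orthogonal kernels are not reproduced here.

```python
import numpy as np


def linear_kernel(a, b):
    # Placeholder kernel; ORSVM's orthogonal (e.g. Legendre) kernel is NOT reproduced here
    return float(np.dot(a, b))


def decision_function(X, support_vectors, multipliers, sv_labels, bias, kernel):
    """Dual-form SVM score: sum_i alpha_i * y_i * K(sv_i, x) + b for every row x of X."""
    scores = np.empty(len(X))
    for j, x in enumerate(X):
        k = np.array([kernel(sv, x) for sv in support_vectors])
        scores[j] = np.sum(multipliers * sv_labels * k) + bias
    return scores


def predict(X, support_vectors, multipliers, sv_labels, bias, kernel):
    # Map scores to the {-1, +1} labels; a score of exactly 0 is assigned to +1 here
    scores = decision_function(X, support_vectors, multipliers, sv_labels, bias, kernel)
    return np.where(scores >= 0, 1, -1)


# Toy example: two support vectors on the x-axis, one per class
sv = np.array([[1.0, 0.0], [-1.0, 0.0]])
alpha = np.array([0.5, 0.5])
y_sv = np.array([1.0, -1.0])
X_new = np.array([[2.0, 0.0], [-3.0, 1.0]])
print(predict(X_new, sv, alpha, y_sv, 0.0, linear_kernel))  # [ 1 -1]
```

Only the support vectors enter the sum, which is why the saved file needs to keep them (and their multipliers and labels) but not the full training set.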

[7]:
# load the JSON file saved earlier (here named 'output.json') into a dictionary to inspect its values
data = orsvm.orsvm.LoadJason('output.json')
[10]:
data
[10]:
{'kernel': 'legendre',
 'kernelParam1': None,
 'kernelParam2': None,
 'order': 4,
 'form': 'r',
 'noise': 0.1,
 'mode': 'fractional',
 'C': None,
 'transition': 0.3,
 'svd': 'a',
 'support multipliers': array([4.31616233e-05, 3.89142170e-05, 3.17406317e-04, 1.21298781e-03,
        2.26153630e-03, 9.46278728e-04, 4.48721425e-04, 1.55539798e-04,
        1.76994177e-03, 9.74950658e-04, 6.60982882e-05, 3.33188669e-03,
        1.17478909e-03, 1.36909791e-03, 1.91821866e-03, 9.56413563e-05,
        6.20364081e-05, 1.47440492e-03, 5.30630721e-04, 1.09339305e-04,
        5.52046728e-04, 2.80712857e-03, 2.65326106e-03, 6.43361900e-04,
        4.82953952e-04, 2.39357351e-03, 1.28777105e-06, 1.65828539e-04,
        1.51709459e-04, 9.23363273e-03, 2.05142062e-02, 1.10989496e-01,
        2.55213384e-02, 9.59451474e-03, 3.29399755e-03, 1.03252221e-03,
        1.99687458e-02, 9.61922898e-02, 6.46303906e-04, 5.39030946e-04,
        1.58653037e-02, 2.79692699e-02, 2.47487466e-04, 1.03700941e-03,
        4.04852964e-04, 1.31887063e-04, 3.99202585e-04, 2.52186718e-02,
        8.14330077e-04, 1.03097010e-01, 6.38327703e-04, 1.27344896e-02,
        6.66340171e-03, 8.89939033e-05, 1.92628687e-02, 4.75350555e-03,
        1.77760937e-03, 8.99825443e-02, 1.02012638e-03, 1.34319496e-02,
        1.06050364e-02, 5.91328073e-03]),
 'support vectors': array([[-1.        , -1.        , -1.        , -1.        ,  0.77093499,
         -1.        ],
        [-1.        , -1.        , -1.        , -1.        ,  0.77093499,
          0.43844619],
        [-1.        , -1.        , -1.        ,  0.77093499,  0.43844619,
         -1.        ],
        [-1.        , -1.        , -1.        ,  0.77093499,  0.77093499,
          0.43844619],
        [-1.        , -1.        ,  0.43844619, -1.        ,  0.43844619,
          0.43844619],
        [-1.        , -1.        ,  0.43844619,  0.43844619,  0.77093499,
         -1.        ],
        [-1.        , -1.        ,  0.43844619,  0.43844619,  1.        ,
         -1.        ],
        [-1.        ,  0.43844619, -1.        , -1.        ,  0.43844619,
         -1.        ],
        [-1.        ,  0.43844619, -1.        ,  0.43844619,  0.77093499,
          0.43844619],
        [-1.        ,  0.43844619, -1.        ,  0.77093499,  0.43844619,
         -1.        ],
        [-1.        ,  0.43844619, -1.        ,  0.77093499,  1.        ,
          0.43844619],
        [-1.        ,  0.43844619,  0.43844619, -1.        ,  0.43844619,
          0.43844619],
        [-1.        ,  0.43844619,  0.43844619,  0.43844619,  0.77093499,
          0.43844619],
        [-1.        ,  0.43844619,  0.43844619,  0.43844619,  1.        ,
         -1.        ],
        [-1.        ,  0.43844619,  0.43844619,  0.77093499,  0.43844619,
          0.43844619],
        [-1.        ,  0.43844619,  0.43844619,  0.77093499,  0.77093499,
         -1.        ],
        [-1.        ,  0.77093499, -1.        ,  0.77093499, -1.        ,
          0.43844619],
        [-1.        ,  0.77093499,  0.43844619,  0.43844619, -1.        ,
          0.43844619],
        [-1.        ,  0.77093499,  0.43844619,  0.77093499, -1.        ,
         -1.        ],
        [ 0.43844619, -1.        , -1.        , -1.        ,  0.77093499,
          0.43844619],
        [ 0.43844619, -1.        , -1.        ,  0.43844619, -1.        ,
          0.43844619],
        [ 0.43844619, -1.        , -1.        ,  0.43844619,  0.43844619,
          0.43844619],
        [ 0.43844619, -1.        ,  0.43844619, -1.        ,  0.43844619,
          0.43844619],
        [ 0.43844619, -1.        ,  0.43844619, -1.        ,  0.77093499,
         -1.        ],
        [ 0.43844619, -1.        ,  0.43844619, -1.        ,  1.        ,
          0.43844619],
        [ 0.43844619, -1.        ,  0.43844619,  0.43844619,  0.77093499,
         -1.        ],
        [ 0.43844619, -1.        ,  0.43844619,  0.43844619,  1.        ,
          0.43844619],
        [ 0.43844619, -1.        ,  0.43844619,  0.77093499,  0.43844619,
          0.43844619],
        [ 0.43844619, -1.        ,  0.43844619,  0.77093499,  1.        ,
         -1.        ],
        [ 0.43844619,  0.43844619, -1.        ,  0.43844619,  0.77093499,
          0.43844619],
        [ 0.43844619,  0.43844619, -1.        ,  0.77093499,  1.        ,
          0.43844619],
        [ 0.43844619,  0.43844619,  0.43844619, -1.        ,  0.77093499,
          0.43844619],
        [ 0.43844619,  0.43844619,  0.43844619,  0.43844619,  0.43844619,
         -1.        ],
        [ 0.43844619,  0.43844619,  0.43844619,  0.77093499,  1.        ,
         -1.        ],
        [ 0.43844619,  0.77093499, -1.        ,  0.43844619,  0.77093499,
         -1.        ],
        [ 0.43844619,  0.77093499, -1.        ,  0.77093499,  0.77093499,
         -1.        ],
        [ 0.43844619,  0.77093499, -1.        ,  0.77093499,  1.        ,
          0.43844619],
        [ 0.43844619,  0.77093499,  0.43844619, -1.        ,  0.77093499,
          0.43844619],
        [ 0.43844619,  0.77093499,  0.43844619,  0.43844619, -1.        ,
         -1.        ],
        [ 0.43844619,  0.77093499,  0.43844619,  0.43844619, -1.        ,
          0.43844619],
        [ 0.43844619,  0.77093499,  0.43844619,  0.43844619,  0.43844619,
         -1.        ],
        [ 0.43844619,  0.77093499,  0.43844619,  0.77093499,  0.77093499,
          0.43844619],
        [ 0.77093499, -1.        , -1.        , -1.        , -1.        ,
          0.43844619],
        [ 0.77093499, -1.        , -1.        ,  0.77093499,  0.43844619,
          0.43844619],
        [ 0.77093499, -1.        ,  0.43844619, -1.        , -1.        ,
         -1.        ],
        [ 0.77093499, -1.        ,  0.43844619,  0.43844619,  0.43844619,
          0.43844619],
        [ 0.77093499, -1.        ,  0.43844619,  0.77093499,  0.43844619,
          0.43844619],
        [ 0.77093499,  0.43844619, -1.        ,  0.43844619,  1.        ,
          0.43844619],
        [ 0.77093499,  0.43844619,  0.43844619, -1.        , -1.        ,
          0.43844619],
        [ 0.77093499,  0.43844619,  0.43844619, -1.        ,  0.77093499,
          0.43844619],
        [ 0.77093499,  0.43844619,  0.43844619,  0.77093499, -1.        ,
         -1.        ],
        [ 0.77093499,  0.43844619,  0.43844619,  0.77093499,  0.43844619,
         -1.        ],
        [ 0.77093499,  0.43844619,  0.43844619,  0.77093499,  1.        ,
         -1.        ],
        [ 0.77093499,  0.77093499, -1.        , -1.        ,  0.43844619,
         -1.        ],
        [ 0.77093499,  0.77093499, -1.        ,  0.43844619,  1.        ,
          0.43844619],
        [ 0.77093499,  0.77093499, -1.        ,  0.77093499,  0.43844619,
         -1.        ],
        [ 0.77093499,  0.77093499, -1.        ,  0.77093499,  1.        ,
          0.43844619],
        [ 0.77093499,  0.77093499,  0.43844619, -1.        ,  0.77093499,
          0.43844619],
        [ 0.77093499,  0.77093499,  0.43844619, -1.        ,  1.        ,
         -1.        ],
        [ 0.77093499,  0.77093499,  0.43844619,  0.77093499,  0.43844619,
          0.43844619],
        [ 0.77093499,  0.77093499,  0.43844619,  0.77093499,  0.77093499,
          0.43844619],
        [ 0.77093499,  0.77093499,  0.43844619,  0.77093499,  1.        ,
          0.43844619]]),
 'support vector labels': array([ 1.,  1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1., -1., -1.,
        -1., -1., -1.,  1.,  1.,  1., -1.,  1., -1., -1., -1., -1., -1.,
        -1., -1., -1.,  1.,  1.,  1.,  1.,  1., -1., -1., -1., -1.,  1.,
         1., -1., -1.,  1., -1.,  1., -1., -1., -1.,  1., -1.,  1., -1.,
        -1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]),
 'weights': array([4.31616233e-05, 3.89142170e-05, 3.17406317e-04, 1.21298781e-03,
        2.26153630e-03, 9.46278728e-04, 4.48721425e-04, 1.55539798e-04,
        1.76994177e-03, 9.74950658e-04, 6.60982882e-05, 3.33188669e-03,
        1.17478909e-03, 1.36909791e-03, 1.91821866e-03, 9.56413563e-05,
        6.20364081e-05, 1.47440492e-03, 5.30630721e-04, 1.09339305e-04,
        5.52046728e-04, 2.80712857e-03, 2.65326106e-03, 6.43361900e-04,
        4.82953952e-04, 2.39357351e-03, 1.28777105e-06, 1.65828539e-04,
        1.51709459e-04, 9.23363273e-03, 2.05142062e-02, 1.10989496e-01,
        2.55213384e-02, 9.59451474e-03, 3.29399755e-03, 1.03252221e-03,
        1.99687458e-02, 9.61922898e-02, 6.46303906e-04, 5.39030946e-04,
        1.58653037e-02, 2.79692699e-02, 2.47487466e-04, 1.03700941e-03,
        4.04852964e-04, 1.31887063e-04, 3.99202585e-04, 2.52186718e-02,
        8.14330077e-04, 1.03097010e-01, 6.38327703e-04, 1.27344896e-02,
        6.66340171e-03, 8.89939033e-05, 1.92628687e-02, 4.75350555e-03,
        1.77760937e-03, 8.99825443e-02, 1.02012638e-03, 1.34319496e-02,
        1.06050364e-02, 5.91328073e-03]),
 'kernel matrix': None,
 'bias': 0.2798437094039253,
 'accuracy': 0.9328703703703703,
 'status': 'optimal'}
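Before predicting, it can be worth checking that the loaded arrays agree with one another. In the output above there are 62 support vectors with 6 features each, together with 62 multipliers and 62 labels. check_params is a hypothetical helper that relies only on the key names shown in that dictionary.

```python
import numpy as np


def check_params(params):
    """Hypothetical sanity check on a loaded parameter dictionary.

    Verifies that multipliers, support vectors and labels describe the same
    number of support vectors, and returns (n_support_vectors, n_features).
    """
    sv = np.asarray(params["support vectors"])
    alpha = np.asarray(params["support multipliers"])
    labels = np.asarray(params["support vector labels"])
    if sv.ndim != 2:
        raise ValueError("support vectors must be a 2-D array")
    n_sv, n_features = sv.shape
    if len(alpha) != n_sv or len(labels) != n_sv:
        raise ValueError("multipliers, labels and support vectors disagree in length")
    if not set(np.unique(labels)) <= {-1.0, 1.0}:
        raise ValueError("support vector labels should be -1 or +1")
    return n_sv, n_features


# Toy dictionary with the same keys as the file above (values made up)
toy = {
    "support vectors": np.zeros((3, 6)),
    "support multipliers": np.array([0.1, 0.2, 0.3]),
    "support vector labels": np.array([1.0, -1.0, 1.0]),
}
print(check_params(toy))  # (3, 6)
```

For the dictionary loaded above, check_params(data) would be expected to return (62, 6).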

Predict using the JSON file

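The function PredictWithJson reads all required parameters directly from the saved file. The new data (X_test) are transformed into the same space (normal or fractional) as the fitted model, and all other settings, such as the orthogonal kernel function, are reused as well. As the output below shows, it logs the accuracy, a classification report and the confusion matrix, and returns the accuracy score.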
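To illustrate what transforming into the fractional space can mean, the sketch below rescales raw values to [0, 1], raises them to the power given by the saved 'transition' value (0.3 here) and maps the result back to [-1, 1]. This transform is an assumption, not taken from ORSVM's source. It was chosen because, for inputs 1/3 and 2/3 with transition 0.3, it yields 0.43844619 and 0.77093499, the values that appear in the saved support vectors. The bounds lo and hi are hypothetical parameters; the package may compute them differently.

```python
import numpy as np


def minmax_scale(X, lo, hi):
    # Rescale raw feature values from [lo, hi] to [0, 1]
    return (np.asarray(X, dtype=float) - lo) / (hi - lo)


def fractional_map(U, transition):
    # Assumed fractional transform: u -> 2 * u**transition - 1, taking [0, 1] onto [-1, 1]
    return 2.0 * np.power(U, transition) - 1.0


# Hypothetical raw values on a 1..4 scale
raw = np.array([[1, 2, 3, 4]])
U = minmax_scale(raw, lo=1, hi=4)          # [[0, 1/3, 2/3, 1]]
print(fractional_map(U, transition=0.3))   # approx. [[-1, 0.43844619, 0.77093499, 1]]
```

Whatever the exact transform, the key point is that new data must go through the same mapping, with the same parameters, as the training data; otherwise the saved support vectors and the new points live in different spaces.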

[8]:
orsvm.orsvm.PredictWithJson("output.json", X_test, y_test)
2022-11-22 21:24:30,276:INFO:** Accuracy score: 0.9328703703703703
2022-11-22 21:24:30,288:INFO:** Classification Report:
               precision    recall  f1-score   support

          -1       0.95      0.91      0.93       216
           1       0.92      0.95      0.93       216

    accuracy                           0.93       432
   macro avg       0.93      0.93      0.93       432
weighted avg       0.93      0.93      0.93       432

2022-11-22 21:24:30,294:INFO:** Confusion Matrix:
 [[197  19]
 [ 10 206]]
[8]:
0.9328703703703703
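The figures in the report follow directly from the confusion matrix. The sketch below assumes that rows are the true classes (-1, 1) and columns the predicted classes, which is consistent with each row summing to the support of 216, and recomputes the metrics with NumPy.

```python
import numpy as np

# Confusion matrix from the log above: rows are true labels (-1, 1), columns are predicted labels
cm = np.array([[197, 19],
               [10, 206]])

accuracy = np.trace(cm) / cm.sum()          # correct predictions / all predictions (403 / 432)
precision = np.diag(cm) / cm.sum(axis=0)    # per class: true positives / predicted as that class
recall = np.diag(cm) / cm.sum(axis=1)       # per class: true positives / actually in that class
f1 = 2 * precision * recall / (precision + recall)

print(accuracy)
print(np.round(precision, 2))  # [0.95 0.92]
print(np.round(recall, 2))     # [0.91 0.95]
print(np.round(f1, 2))         # [0.93 0.93]
```

Because both classes have the same support, the macro and weighted averages in the report coincide.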