# Step 1:  Importing libraries

In [3]:
#
#   Deep Knowledge Tracing (DKT) Implementation (https://github.com/mmkhajah/dkt)
#
#   Script saves 3 files:
#       dataset.txt.model_weights trained model weights
#       dataset.txt.history training history (training LL, test AUC)
#       dataset.txt.preds predictions for test trials
#
import os
import sys
import numpy as np
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from keras.preprocessing import sequence
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM, TimeDistributed
from keras.layers.core import Masking
from keras import backend as K
from sklearn.metrics import roc_auc_score
import random
import math
import argparse

# Step 2: Data preprocessing

In [4]:
# The dataset file contains student, skill and performance information. It is a 3-column space-delimited file.
#   Each row in the file indicates whether a particular student answered a specific problem correctly or not. 
#   The first column is the student id, the second column is the skill id associated with the problem and the last column is whether the student got the problem correctly (1) or not (0).

#The split file is a space-delimited file where each column indicates whether the corresponding student id should be the training set (1) or not (0). For example, the split file:
def load_dataset(dataset, split_file):
    # 
    seqs, num_skills = read_file(dataset)
    
    with open(split_file, 'r') as f:
        student_assignment = f.read().split(' ')
    
    training_seqs = [seqs[i] for i in range(0, len(seqs)) if student_assignment[i] == '1']
    testing_seqs = [seqs[i] for i in range(0, len(seqs)) if student_assignment[i] == '0']
    
    return training_seqs, testing_seqs, num_skills

def read_file(dataset_path):
    seqs_by_student = {}
    problem_ids = {}
    next_problem_id = 0
    with open(dataset_path, 'r') as f:
        for line in f:
            student, problem, is_correct = line.strip().split(' ')
            student = int(student)
            if student not in seqs_by_student:
                seqs_by_student[student] = []
            if problem not in problem_ids:
                problem_ids[problem] = next_problem_id
                next_problem_id += 1
            seqs_by_student[student].append((problem_ids[problem], int(is_correct == '1')))
    
    sorted_keys = sorted(seqs_by_student.keys())
    return [seqs_by_student[k] for k in sorted_keys], next_problem_id


## Auxilliaries functions

In [5]:
def run_func(seqs, num_skills, f, batch_size, time_window, batch_done = None):

    assert(min([len(s) for s in seqs]) > 0)
    
    # randomize samples
    seqs = seqs[:]
    random.shuffle(seqs)
    
    processed = 0
    for start_from in range(0, len(seqs), batch_size):
       end_before = min(len(seqs), start_from + batch_size)
       x = []
       y = []
       for seq in seqs[start_from:end_before]:
           x_seq = []
           y_seq = []
           xt_zeros = [0 for i in range(0, num_skills*2)]
           ct_zeros = [0 for i in range(0, num_skills+1)]
           xt = xt_zeros[:]
           for skill, is_correct in seq:
               x_seq.append(xt)
               
               ct = ct_zeros[:]
               ct[skill] = 1
               ct[num_skills] = is_correct
               y_seq.append(ct)
               
               # one hot encoding of (last_skill, is_correct)
               pos = skill * 2 + is_correct
               xt = xt_zeros[:]
               xt[pos] = 1
               
           x.append(x_seq)
           y.append(y_seq)
       
       maxlen = max([len(s) for s in x])
       maxlen = round_to_multiple(maxlen, time_window)
       # fill up the batch if necessary
       if len(x) < batch_size:
            for e in range(0, batch_size - len(x)):
                x_seq = []
                y_seq = []
                for t in range(0, time_window):
                    x_seq.append([-1.0 for i in range(0, num_skills*2)])
                    y_seq.append([0.0 for i in range(0, num_skills+1)])
                x.append(x_seq)
                y.append(y_seq)
        
       X = pad_sequences(x, padding='post', maxlen = maxlen, dim=num_skills*2, value=-1.0)
       Y = pad_sequences(y, padding='post', maxlen = maxlen, dim=num_skills+1, value=-1.0)
        
       for t in range(0, maxlen, time_window):
           f(X[:,t:(t+time_window),:], Y[:,t:(t+time_window),:])
           
       processed += end_before - start_from
       
       # reset the states for the next batch of sequences
       if batch_done:
           batch_done((processed * 100.0) / len(seqs))
   
def round_to_multiple(x, base):
    return int(base * math.ceil(float(x)/base))

# https://groups.google.com/forum/#!msg/keras-users/7sw0kvhDqCw/QmDMX952tq8J
def pad_sequences(sequences, maxlen=None, dim=1, dtype='int32',
    padding='pre', truncating='pre', value=0.):
    '''
        Override keras method to allow multiple feature dimensions.
        @dim: input feature dimension (number of features per timestep)
    '''
    lengths = [len(s) for s in sequences]

    nb_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    x = (np.ones((nb_samples, maxlen, dim)) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError("Truncating type '%s' not understood" % padding)

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError("Padding type '%s' not understood" % padding)
    return x

# Step 3: Defining the main function 

In [6]:
def testDKT():
    
    #defining arguments
    dataset = "assistments.txt"
    split_file = "assistments_split.txt"
    hidden_units = 200
    batch_size = 5
    time_window = 100
    #epochs = 50  for best results but takes a lot of time !
    epochs = 2
    
    model_file = dataset + '.model_weights'
    history_file = dataset + '.history'
    preds_file = dataset + '.preds'
    
    overall_loss = [0.0]
    preds = []
    history = []
    
    # load dataset
    training_seqs, testing_seqs, num_skills = load_dataset(dataset, split_file)
    print ("Training Sequences: %d" % len(training_seqs))
    print ("Testing Sequences: %d" % len(testing_seqs))
    print ("Number of skills: %d" % num_skills)
    
    # Our loss function
    # The model gives predictions for all skills so we need to get the 
    # prediction for the skill at time t. We do that by taking the column-wise
    # dot product between the predictions at each time slice and a
    # one-hot encoding of the skill at time t.
    # y_true: (nsamples x nsteps x nskills+1)
    # y_pred: (nsamples x nsteps x nskills)
    def loss_function(y_true, y_pred):
        skill = y_true[:,:,0:num_skills]
        obs = y_true[:,:,num_skills]
        temp = y_pred * skill
        rel_pred = K.sum(temp, axis=2)
        
        # keras implementation does a mean on the last dimension (axis=-1) which
        # it assumes is a singleton dimension. But in our context that would
        # be wrong.
        return K.binary_crossentropy(rel_pred, obs)
    
    
    # build model
    model = Sequential()
    
    # ignore padding
    model.add(Masking(-1.0, batch_input_shape=(batch_size, time_window, num_skills*2)))
    
    # lstm configured to keep states between batches
    model.add(LSTM(input_dim = num_skills*2, 
                   output_dim = hidden_units, 
                   return_sequences=True,
                   batch_input_shape=(batch_size, time_window, num_skills*2),
                   stateful = True
    ))
    
    # readout layer. TimeDistributedDense uses the same weights for all
    # time steps.
    model.add(TimeDistributed(Dense(input_dim = hidden_units, 
        output_dim = num_skills, activation='sigmoid')))
    
    # optimize with rmsprop which dynamically adapts the learning
    # rate of each weight.
    model.compile(loss=loss_function,
                optimizer='rmsprop')
    model.summary()

    # training function
    def trainer(X, Y):
        overall_loss[0] += model.train_on_batch(X, Y)
    
    # prediction
    def predictor(X, Y):
        batch_activations = model.predict_on_batch(X)
        skill = Y[:,:,0:num_skills]
        obs = Y[:,:,num_skills]
        y_pred = np.squeeze(np.array(batch_activations))
        
        rel_pred = np.sum(y_pred * skill, axis=2)
        
        for b in range(0, X.shape[0]):
            for t in range(0, X.shape[1]):
                if X[b, t, 0] == -1.0:
                    continue
                preds.append((rel_pred[b][t], obs[b][t]))
        
    # call when prediction batch is finished
    # resets LSTM state because we are done with all sequences in the batch
    def finished_prediction_batch(percent_done):
        model.reset_states()
        
    # similiar to the above
    def finished_batch(percent_done):
        print ("(%4.3f %%) %f" % (percent_done, overall_loss[0]))
        model.reset_states()
        
    # run the model
    for e in range(0, epochs):
        model.reset_states()
        
        # train
        run_func(training_seqs, num_skills, trainer, batch_size, time_window, finished_batch)
        
        model.reset_states()
        
        # test
        run_func(testing_seqs, num_skills, predictor, batch_size, time_window, finished_prediction_batch)
        
        # compute AUC
        auc = roc_auc_score([p[1] for p in preds], [p[0] for p in preds])
        
        # log
        history.append((overall_loss[0], auc))
        
        # save model
        model.save_weights(model_file, overwrite=True)
        print ("==== Epoch: %d, Test AUC: %f" % (e, auc))
        
        # reset loss
        overall_loss[0] = 0.0
        
        # save predictions
        with open(preds_file, 'w') as f:
            f.write('was_heldout\tprob_recall\tstudent_recalled\n')
            for pred in preds:
                f.write('1\t%f\t%d\n' % (pred[0], pred[1]))
        
        with open(history_file, 'w') as f:
            for h in history:
                f.write('\t'.join([str(he) for he in h]))
                f.write('\n')
                
        # clear preds
        preds = []

# Step 4: Run the DKT model

In [7]:
testDKT()  # 2 epochs only

Training Sequences: 3361
Testing Sequences: 856
Number of skills: 124




Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
masking_1 (Masking)          (5, 100, 248)             0         
_________________________________________________________________
lstm_1 (LSTM)                (5, 100, 200)             359200    
_________________________________________________________________
time_distributed_1 (TimeDist (5, 100, 124)             24924     
Total params: 384,124
Trainable params: 384,124
Non-trainable params: 0
_________________________________________________________________

(0.149 %) 7.990277
(0.298 %) 9.858946
(0.446 %) 15.272219
(0.595 %) 21.100657
(0.744 %) 24.546907
(0.893 %) 52.693967
(1.041 %) 57.521748
(1.190 %) 58.666865
(1.339 %) 64.405935
(1.488 %) 65.340869
(1.636 %) 77.828749
(1.785 %) 78.860412
(1.934 %) 89.069992
(2.083 %) 91.247758
(2.231 %) 93.911128
(2.380 %) 97.670487
(2.529 %) 101.633261
(2.678 %) 113.604591
(2.827 %) 117.665

(48.646 %) 1938.445088
(48.795 %) 1941.887673
(48.944 %) 1954.287162
(49.093 %) 1957.371875
(49.241 %) 1965.045232
(49.390 %) 1965.431593
(49.539 %) 1966.820359
(49.688 %) 1968.384536
(49.836 %) 1971.254876
(49.985 %) 1972.117650
(50.134 %) 1976.231483
(50.283 %) 1977.637459
(50.431 %) 1982.011797
(50.580 %) 1984.475245
(50.729 %) 1987.420472
(50.878 %) 1988.722064
(51.026 %) 1989.201145
(51.175 %) 1997.678427
(51.324 %) 2004.169332
(51.473 %) 2005.725280
(51.622 %) 2011.919374
(51.770 %) 2013.353544
(51.919 %) 2016.471383
(52.068 %) 2030.188138
(52.217 %) 2033.728415
(52.365 %) 2036.541252
(52.514 %) 2039.912468
(52.663 %) 2040.466220
(52.812 %) 2059.414574
(52.960 %) 2059.804264
(53.109 %) 2064.729184
(53.258 %) 2065.911713
(53.407 %) 2066.880031
(53.555 %) 2073.622962
(53.704 %) 2080.982305
(53.853 %) 2085.791890
(54.002 %) 2088.002308
(54.151 %) 2089.570265
(54.299 %) 2091.240647
(54.448 %) 2093.223325
(54.597 %) 2120.076494
(54.746 %) 2123.105215
(54.894 %) 2123.763547
(55.043 %) 

(1.636 %) 54.851522
(1.785 %) 68.737259
(1.934 %) 71.264208
(2.083 %) 73.195656
(2.231 %) 84.636616
(2.380 %) 85.993682
(2.529 %) 87.321469
(2.678 %) 94.616025
(2.827 %) 106.146920
(2.975 %) 107.787046
(3.124 %) 119.097347
(3.273 %) 121.469958
(3.422 %) 124.344860
(3.570 %) 154.945557
(3.719 %) 161.632253
(3.868 %) 165.039282
(4.017 %) 166.903182
(4.165 %) 167.440124
(4.314 %) 168.276734
(4.463 %) 173.430887
(4.612 %) 180.547081
(4.760 %) 189.918040
(4.909 %) 203.978207
(5.058 %) 213.720070
(5.207 %) 214.944142
(5.356 %) 223.300238
(5.504 %) 227.198641
(5.653 %) 233.914144
(5.802 %) 235.989534
(5.951 %) 239.064320
(6.099 %) 246.504454
(6.248 %) 248.690053
(6.397 %) 250.654941
(6.546 %) 253.845440
(6.694 %) 255.046627
(6.843 %) 255.757466
(6.992 %) 261.814903
(7.141 %) 265.527400
(7.289 %) 279.138139
(7.438 %) 290.802614
(7.587 %) 297.716816
(7.736 %) 299.819932
(7.885 %) 303.138452
(8.033 %) 304.282432
(8.182 %) 304.836154
(8.331 %) 306.315983
(8.480 %) 310.334086
(8.628 %) 320.115844


(56.382 %) 2242.164802
(56.531 %) 2247.479439
(56.680 %) 2252.205288
(56.828 %) 2255.025635
(56.977 %) 2256.086211
(57.126 %) 2261.172313
(57.275 %) 2269.634483
(57.423 %) 2275.157717
(57.572 %) 2278.044791
(57.721 %) 2279.803831
(57.870 %) 2285.564433
(58.018 %) 2287.403771
(58.167 %) 2299.732378
(58.316 %) 2301.040702
(58.465 %) 2314.089831
(58.614 %) 2317.627775
(58.762 %) 2318.126788
(58.911 %) 2324.628697
(59.060 %) 2334.715186
(59.209 %) 2345.125776
(59.357 %) 2348.003587
(59.506 %) 2350.952187
(59.655 %) 2352.855008
(59.804 %) 2359.455680
(59.952 %) 2366.470760
(60.101 %) 2368.979310
(60.250 %) 2374.752312
(60.399 %) 2379.234311
(60.547 %) 2387.432730
(60.696 %) 2388.526234
(60.845 %) 2399.626967
(60.994 %) 2405.566012
(61.143 %) 2406.493181
(61.291 %) 2415.598138
(61.440 %) 2419.911012
(61.589 %) 2421.830340
(61.738 %) 2429.839306
(61.886 %) 2431.667895
(62.035 %) 2434.119883
(62.184 %) 2453.199101
(62.333 %) 2461.764082
(62.481 %) 2470.851761
(62.630 %) 2472.115774
(62.779 %) 