diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py
index 1c77975b78..2e33233b44 100644
--- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation.py
@@ -124,9 +124,12 @@
 group_size = 32  # number of instances over which to evaluate the learning performance
 n_iter = 50  # number of iterations
 
-n_input_symbols = 4  # number of input populations, e.g. 4 = left, right, recall, noise
-n_cues = 7  # number of cues given before decision
-prob_group = 0.3  # probability with which one input group is present
+input = {
+    "n_symbols": 4,  # number of input populations, e.g. 4 = left, right, recall, noise
+    "n_cues": 7,  # number of cues given before decision
+    "prob_group": 0.3,  # probability with which one input group is present
+    "spike_prob": 0.04,  # spike probability of frozen input noise
+}
 
 steps = {
     "cue": 100,  # time steps in one cue presentation
@@ -135,7 +138,7 @@
     "recall": 150,  # time steps of recall
 }
 
-steps["cues"] = n_cues * (steps["cue"] + steps["spacing"])  # time steps of all cues
+steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"])  # time steps of all cues
 steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"]  # time steps of one full sequence
 steps["learning_window"] = steps["recall"]  # time steps of window with non-zero learning signals
 steps["task"] = n_iter * group_size * steps["sequence"]  # time steps of task
@@ -457,25 +460,23 @@ def calculate_glorot_dist(fan_in, fan_out):
 # assigned randomly to the left or right.
 
 
-def generate_evidence_accumulation_input_output(
-    batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
-):
-    n_pop_nrn = n_in // n_input_symbols
+def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input):
+    n_pop_nrn = n_in // input["n_symbols"]
 
-    prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32)
+    prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32)
     idx = np.random.choice([0, 1], batch_size)
     probs = np.zeros((batch_size, 2), dtype=np.float32)
     probs[:, 0] = prob_choices[idx]
     probs[:, 1] = prob_choices[1 - idx]
 
-    batched_cues = np.zeros((batch_size, n_cues), dtype=int)
+    batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int)
     for b_idx in range(batch_size):
-        batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx])
+        batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx])
 
     input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in))
 
     for b_idx in range(batch_size):
-        for c_idx in range(n_cues):
+        for c_idx in range(input["n_cues"]):
             cue = batched_cues[b_idx, c_idx]
 
             step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"]
@@ -484,29 +485,26 @@ def generate_evidence_accumulation_input_output(
             pop_nrn_start = cue * n_pop_nrn
             pop_nrn_stop = pop_nrn_start + n_pop_nrn
 
-            input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob
+            input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"]
 
-    input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob
-    input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0
+    input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"]
+    input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0
     input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape)
     input_spike_bools[:, 0, :] = 0  # remove spikes in 0th time step of every sequence for technical reasons
 
     target_cues = np.zeros(batch_size, dtype=int)
-    target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2)
+    target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2)
 
     return input_spike_bools, target_cues
 
 
-input_spike_prob = 0.04  # spike probability of frozen input noise
 dtype_in_spks = np.float32  # data type of input spikes - for reproducing TF results set to np.float32
 
 input_spike_bools_list = []
 target_cues_list = []
 
 for _ in range(n_iter):
-    input_spike_bools, target_cues = generate_evidence_accumulation_input_output(
-        group_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
-    )
+    input_spike_bools, target_cues = generate_evidence_accumulation_input_output(group_size, n_in, steps, input)
     input_spike_bools_list.append(input_spike_bools)
     target_cues_list.extend(target_cues)
 
@@ -768,7 +766,10 @@ def plot_spikes(ax, events, ylabel, xlims):
 # the first time step and we add the initial weights manually.
 
 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
+def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
+    sender_label, target_label = label.split("_")
+    nrns_senders = nrns_weight_record[sender_label]
+    nrns_targets = nrns_weight_record[target_label]
     for sender in nrns_senders.tolist():
         for target in nrns_targets.tolist():
             idc_syn = (events["senders"] == sender) & (events["targets"] == target)
@@ -787,11 +788,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
 fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
 fig.suptitle("Weight time courses")
 
-plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
-plot_weight_time_course(
-    axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)"
-)
-plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)")
+nrns_weight_record = {
+    "in": nrns_in[:n_record_w],
+    "rec": nrns_rec[:n_record_w],
+    "out": nrns_out,
+}
+
+plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
+plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
+plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")
 
 axs[-1].set_xlabel(r"$t$ (ms)")
 axs[-1].set_xlim(0, steps["task"])
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py
index f95953250a..a1f44c3711 100644
--- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_evidence-accumulation_bsshslm_2020.py
@@ -124,9 +124,12 @@
 batch_size = 32  # batch size, 64 in reference [2], 32 in the README to reference [2]
 n_iter = 50  # number of iterations, 2000 in reference [2]
 
-n_input_symbols = 4  # number of input populations, e.g. 4 = left, right, recall, noise
-n_cues = 7  # number of cues given before decision
-prob_group = 0.3  # probability with which one input group is present
+input = {
+    "n_symbols": 4,  # number of input populations, e.g. 4 = left, right, recall, noise
+    "n_cues": 7,  # number of cues given before decision
+    "prob_group": 0.3,  # probability with which one input group is present
+    "spike_prob": 0.04,  # spike probability of frozen input noise
+}
 
 do_early_stopping = True  # if True, stop training as soon as stop criterion fulfilled
 n_validate_every = 10  # number of training iterations before validation
@@ -145,7 +148,7 @@
     "recall": 150,  # time steps of recall
 }
 
-steps["cues"] = n_cues * (steps["cue"] + steps["spacing"])  # time steps of all cues
+steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"])  # time steps of all cues
 steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"]  # time steps of one full sequence
 steps["learning_window"] = steps["recall"]  # time steps of window with non-zero learning signals
 
@@ -468,25 +471,23 @@ def calculate_glorot_dist(fan_in, fan_out):
 # assigned randomly to the left or right.
 
 
-def generate_evidence_accumulation_input_output(
-    batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
-):
-    n_pop_nrn = n_in // n_input_symbols
+def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input):
+    n_pop_nrn = n_in // input["n_symbols"]
 
-    prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32)
+    prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32)
     idx = np.random.choice([0, 1], batch_size)
    probs = np.zeros((batch_size, 2), dtype=np.float32)
     probs[:, 0] = prob_choices[idx]
     probs[:, 1] = prob_choices[1 - idx]
 
-    batched_cues = np.zeros((batch_size, n_cues), dtype=int)
+    batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int)
     for b_idx in range(batch_size):
-        batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx])
+        batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx])
 
     input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in))
 
     for b_idx in range(batch_size):
-        for c_idx in range(n_cues):
+        for c_idx in range(input["n_cues"]):
             cue = batched_cues[b_idx, c_idx]
 
             step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"]
@@ -495,30 +496,27 @@ def generate_evidence_accumulation_input_output(
             pop_nrn_start = cue * n_pop_nrn
             pop_nrn_stop = pop_nrn_start + n_pop_nrn
 
-            input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob
+            input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"]
 
-    input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob
-    input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0
+    input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"]
+    input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0
     input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape)
     input_spike_bools[:, 0, :] = 0  # remove spikes in 0th time step of every sequence for technical reasons
 
     target_cues = np.zeros(batch_size, dtype=int)
-    target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2)
+    target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2)
 
     return input_spike_bools, target_cues
 
 
 def get_params_task_input_output(n_iter_interval):
-    input_spike_prob = 0.04  # spike probability of frozen input noise
     dtype_in_spks = np.float32  # data type of input spikes - for reproducing TF results set to np.float32
 
     input_spike_bools_list = []
     target_cues_list = []
 
     for _ in range(n_iter_interval):
-        input_spike_bools, target_cues = generate_evidence_accumulation_input_output(
-            batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
-        )
+        input_spike_bools, target_cues = generate_evidence_accumulation_input_output(batch_size, n_in, steps, input)
         input_spike_bools_list.append(input_spike_bools)
         target_cues_list.extend(target_cues)
 
@@ -876,7 +874,10 @@ def plot_spikes(ax, events, ylabel, xlims):
 # the first time step and we add the initial weights manually.
 
 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
+def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
+    sender_label, target_label = label.split("_")
+    nrns_senders = nrns_weight_record[sender_label]
+    nrns_targets = nrns_weight_record[target_label]
     for sender in nrns_senders.tolist():
         for target in nrns_targets.tolist():
             idc_syn = (events["senders"] == sender) & (events["targets"] == target)
@@ -895,11 +896,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
 fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
 fig.suptitle("Weight time courses")
 
-plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)")
-plot_weight_time_course(
-    axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)"
-)
-plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)")
+nrns_weight_record = {
+    "in": nrns_in[:n_record_w],
+    "rec": nrns_rec[:n_record_w],
+    "out": nrns_out,
+}
+
+plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
+plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
+plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")
 
 axs[-1].set_xlabel(r"$t$ (ms)")
 axs[-1].set_xlim(0, duration["task"])
diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py
index d408f73739..e991f8a7f2 100644
--- a/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py
+++ b/pynest/examples/eprop_plasticity/eprop_supervised_classification_neuromorphic_mnist.py
@@ -824,7 +824,10 @@ def plot_spikes(ax, events, ylabel, xlims):
 # the first time step and we add the initial weights manually.
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -843,11 +846,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py index fc0cebfd78..a22ab3e822 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_handwriting_bsshslm_2020.py @@ -672,7 +672,10 @@ def plot_spikes(ax, events, ylabel, xlims): # the first time step and we add the initial weights manually. 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -691,11 +694,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py index 02cae830a0..90e6000e06 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_lemniscate_bsshslm_2020.py @@ -653,7 +653,10 @@ def plot_spikes(ax, events, ylabel, xlims): # the first time step and we add the initial weights manually. 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -672,11 +675,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py index 673fd9ed06..0659fff3c2 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves.py @@ -628,7 +628,10 @@ def plot_spikes(ax, events, ylabel, xlims): # the first time step and we add the initial weights manually. 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -647,11 +650,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) diff --git a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py index 50a8248dc8..a1375ee279 100644 --- a/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py +++ b/pynest/examples/eprop_plasticity/eprop_supervised_regression_sine-waves_bsshslm_2020.py @@ -602,7 +602,10 @@ def plot_spikes(ax, events, ylabel, xlims): # the first time step and we add the initial weights manually. 
-def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel): +def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel): + sender_label, target_label = label.split("_") + nrns_senders = nrns_weight_record[sender_label] + nrns_targets = nrns_weight_record[target_label] for sender in nrns_senders.tolist(): for target in nrns_targets.tolist(): idc_syn = (events["senders"] == sender) & (events["targets"] == target) @@ -621,11 +624,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4)) fig.suptitle("Weight time courses") -plot_weight_time_course(axs[0], events_wr, nrns_in[:n_record_w], nrns_rec[:n_record_w], "in_rec", r"$W_\text{in}$ (pA)") -plot_weight_time_course( - axs[1], events_wr, nrns_rec[:n_record_w], nrns_rec[:n_record_w], "rec_rec", r"$W_\text{rec}$ (pA)" -) -plot_weight_time_course(axs[2], events_wr, nrns_rec[:n_record_w], nrns_out, "rec_out", r"$W_\text{out}$ (pA)") +nrns_weight_record = { + "in": nrns_in[:n_record_w], + "rec": nrns_rec[:n_record_w], + "out": nrns_out, +} + +plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)") +plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)") +plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)") axs[-1].set_xlabel(r"$t$ (ms)") axs[-1].set_xlim(0, steps["task"]) diff --git a/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py b/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py index 38f0fcc67d..f8daaf7fe8 100644 --- a/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py +++ b/testsuite/pytests/test_eprop_bsshslm_2020_plasticity.py @@ -428,9 +428,12 @@ def test_eprop_classification(batch_size, loss_nest_reference): n_iter = 5 - n_input_symbols = 4 - n_cues = 7 - prob_group = 0.3 + input = { + "n_symbols": 4, + "n_cues": 7, + "prob_group": 0.3, + "spike_prob": 0.04, + } steps = { "cue": 100, @@ -439,7 +442,7 @@ def test_eprop_classification(batch_size, loss_nest_reference): "recall": 150, } - steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) + steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"]) steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] steps["learning_window"] = steps["recall"] steps["task"] = n_iter * batch_size * steps["sequence"] @@ -696,25 +699,23 @@ def calculate_glorot_dist(fan_in, fan_out): # Create input and output - def generate_evidence_accumulation_input_output( - batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps - ): - n_pop_nrn = n_in // n_input_symbols + def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input): + n_pop_nrn = n_in // input["n_symbols"] - prob_choices = np.array([prob_group, 1 - prob_group], dtype=np.float32) + prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32) idx = np.random.choice([0, 1], batch_size) probs = np.zeros((batch_size, 2), dtype=np.float32) probs[:, 0] = prob_choices[idx] probs[:, 1] = prob_choices[1 - idx] - batched_cues = np.zeros((batch_size, n_cues), dtype=int) + batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int) for b_idx in range(batch_size): - batched_cues[b_idx, :] = np.random.choice([0, 1], n_cues, p=probs[b_idx]) + batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx]) input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in)) 
for b_idx in range(batch_size): - for c_idx in range(n_cues): + for c_idx in range(input["n_cues"]): cue = batched_cues[b_idx, c_idx] step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"] @@ -723,28 +724,25 @@ def generate_evidence_accumulation_input_output( pop_nrn_start = cue * n_pop_nrn pop_nrn_stop = pop_nrn_start + n_pop_nrn - input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input_spike_prob + input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"] - input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input_spike_prob - input_spike_probs[:, :, 3 * n_pop_nrn :] = input_spike_prob / 4.0 + input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"] + input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0 input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape) input_spike_bools[:, 0, :] = 0 target_cues = np.zeros(batch_size, dtype=int) - target_cues[:] = np.sum(batched_cues, axis=1) > int(n_cues / 2) + target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2) return input_spike_bools, target_cues - input_spike_prob = 0.04 dtype_in_spks = np.float32 input_spike_bools_list = [] target_cues_list = [] for _ in range(n_iter): - input_spike_bools, target_cues = generate_evidence_accumulation_input_output( - batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps - ) + input_spike_bools, target_cues = generate_evidence_accumulation_input_output(batch_size, n_in, steps, input) input_spike_bools_list.append(input_spike_bools) target_cues_list.extend(target_cues)
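
For orientation only (not part of the patch): a minimal, self-contained sketch of the call
pattern that the repeated plot_weight_time_course refactor above converges on. The toy
senders/targets/times/weights arrays and neuron IDs below are hypothetical stand-ins; in the
examples, nrns_weight_record holds NEST NodeCollections and events_wr comes from a
weight_recorder's events.

import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the recorded input, recurrent, and output populations.
nrns_weight_record = {
    "in": np.array([1, 2]),
    "rec": np.array([3, 4]),
    "out": np.array([5]),
}

# Hypothetical weight-recorder events: one entry per recorded weight update.
events_wr = {
    "senders": np.array([1, 1, 3, 3]),
    "targets": np.array([3, 3, 5, 5]),
    "times": np.array([0.0, 10.0, 0.0, 10.0]),
    "weights": np.array([0.5, 0.6, -0.2, -0.1]),
}


def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
    # The label encodes "<sender>_<target>", so two dict lookups replace the two
    # explicit population arguments of the old signature.
    sender_label, target_label = label.split("_")
    nrns_senders = nrns_weight_record[sender_label]
    nrns_targets = nrns_weight_record[target_label]
    for sender in nrns_senders.tolist():
        for target in nrns_targets.tolist():
            # select all recorded events of this individual synapse
            idc_syn = (events["senders"] == sender) & (events["targets"] == target)
            ax.step(events["times"][idc_syn], events["weights"][idc_syn], where="post")
    ax.set_ylabel(ylabel)


fig, axs = plt.subplots(2, 1, sharex=True)
plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")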