Skip to content

Commit

Permalink
Fix pylint "too-many-positional-arguments" errors
Browse files Browse the repository at this point in the history
  • Loading branch information
akorgor committed Sep 28, 2024
1 parent c39d88d commit 154d9fd
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,12 @@
group_size = 32 # number of instances over which to evaluate the learning performance
n_iter = 50 # number of iterations

n_input_symbols = 4 # number of input populations, e.g. 4 = left, right, recall, noise
n_cues = 7 # number of cues given before decision
prob_group = 0.3 # probability with which one input group is present
input = {
"n_symbols": 4, # number of input populations, e.g. 4 = left, right, recall, noise
"n_cues": 7, # number of cues given before decision
"prob_group": 0.3, # probability with which one input group is present
"spike_prob": 0.04, # spike probability of frozen input noise
}

steps = {
"cue": 100, # time steps in one cue presentation
Expand All @@ -135,7 +138,7 @@
"recall": 150, # time steps of recall
}

steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues
steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"]) # time steps of all cues
steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence
steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals
steps["task"] = n_iter * group_size * steps["sequence"] # time steps of task
Expand Down Expand Up @@ -457,25 +460,23 @@ def calculate_glorot_dist(fan_in, fan_out):
# assigned randomly to the left or right.


def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input):
    """Generate one batch of input spikes and decision targets for the evidence accumulation task.

    Parameters
    ----------
    batch_size : int
        Number of task instances (sequences) in the batch.
    n_in : int
        Total number of input neurons; split evenly into ``input["n_symbols"]`` populations
        (e.g. left, right, recall, noise).
    steps : dict
        Timing in simulation steps; keys used here: "cue", "spacing", "recall", "sequence".
    input : dict
        Task parameters; keys used here: "n_symbols", "n_cues", "prob_group", "spike_prob".
        NOTE(review): the name ``input`` shadows the builtin ``input()``; kept unchanged
        for backward compatibility with existing callers.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``input_spike_bools`` of shape ``(batch_size, steps["sequence"], n_in)`` (bool) and
        ``target_cues`` of shape ``(batch_size,)`` (int, 0 or 1).
    """
    n_pop_nrn = n_in // input["n_symbols"]

    # Per-sequence cue probabilities: one side gets prob_group, the other 1 - prob_group,
    # with the favored side chosen uniformly at random per batch element.
    prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32)
    idx = np.random.choice([0, 1], batch_size)
    probs = np.zeros((batch_size, 2), dtype=np.float32)
    probs[:, 0] = prob_choices[idx]
    probs[:, 1] = prob_choices[1 - idx]

    # Draw the left(0)/right(1) identity of each cue for every sequence.
    batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int)
    for b_idx in range(batch_size):
        batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx])

    input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in))

    for b_idx in range(batch_size):
        for c_idx in range(input["n_cues"]):
            cue = batched_cues[b_idx, c_idx]

            step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"]
            # (line hidden in diff view) cue window spans steps["cue"] time steps —
            # TODO(review): confirm against full source
            step_stop = step_start + steps["cue"]

            pop_nrn_start = cue * n_pop_nrn
            pop_nrn_stop = pop_nrn_start + n_pop_nrn

            input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"]

    # Recall population (3rd group) fires during the recall window; noise population
    # (4th group) fires throughout at a quarter of the cue spike probability.
    input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"]
    input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0
    input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape)
    input_spike_bools[:, 0, :] = 0  # remove spikes in 0th time step of every sequence for technical reasons

    # Target is the majority side over all cues of a sequence.
    target_cues = np.zeros(batch_size, dtype=int)
    target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2)

    return input_spike_bools, target_cues


input_spike_prob = 0.04 # spike probability of frozen input noise
dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32

input_spike_bools_list = []
target_cues_list = []

for _ in range(n_iter):
input_spike_bools, target_cues = generate_evidence_accumulation_input_output(
group_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
)
input_spike_bools, target_cues = generate_evidence_accumulation_input_output(group_size, n_in, steps, input)
input_spike_bools_list.append(input_spike_bools)
target_cues_list.extend(target_cues)

Expand Down Expand Up @@ -768,7 +766,10 @@ def plot_spikes(ax, events, ylabel, xlims):
# the first time step and we add the initial weights manually.


def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
sender_label, target_label = label.split("_")
nrns_senders = nrns_weight_record[sender_label]
nrns_targets = nrns_weight_record[target_label]
for sender in nrns_senders.tolist():
for target in nrns_targets.tolist():
idc_syn = (events["senders"] == sender) & (events["targets"] == target)
Expand All @@ -787,11 +788,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

# Group the recorded neuron populations under the labels used by the weight-recorder
# plots; plot_weight_time_course() looks senders/targets up from this dict by splitting
# its "sender_target" label argument.
nrns_weight_record = {
    "in": nrns_in[:n_record_w],
    "rec": nrns_rec[:n_record_w],
    "out": nrns_out,
}

plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
# Bug fix: a stray extra positional argument ``nrns_out`` was passed before the label,
# making this call raise TypeError (the refactored function takes the record dict,
# a label, and a ylabel — not separate sender/target collections).
plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")

axs[-1].set_xlabel(r"$t$ (ms)")
axs[-1].set_xlim(0, steps["task"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,12 @@
batch_size = 32 # batch size, 64 in reference [2], 32 in the README to reference [2]
n_iter = 50 # number of iterations, 2000 in reference [2]

n_input_symbols = 4 # number of input populations, e.g. 4 = left, right, recall, noise
n_cues = 7 # number of cues given before decision
prob_group = 0.3 # probability with which one input group is present
input = {
"n_symbols": 4, # number of input populations, e.g. 4 = left, right, recall, noise
"n_cues": 7, # number of cues given before decision
"prob_group": 0.3, # probability with which one input group is present
"spike_prob": 0.04, # spike probability of frozen input noise
}

do_early_stopping = True # if True, stop training as soon as stop criterion fulfilled
n_validate_every = 10 # number of training iterations before validation
Expand All @@ -145,7 +148,7 @@
"recall": 150, # time steps of recall
}

steps["cues"] = n_cues * (steps["cue"] + steps["spacing"]) # time steps of all cues
steps["cues"] = input["n_cues"] * (steps["cue"] + steps["spacing"]) # time steps of all cues
steps["sequence"] = steps["cues"] + steps["bg_noise"] + steps["recall"] # time steps of one full sequence
steps["learning_window"] = steps["recall"] # time steps of window with non-zero learning signals

Expand Down Expand Up @@ -468,25 +471,23 @@ def calculate_glorot_dist(fan_in, fan_out):
# assigned randomly to the left or right.


def generate_evidence_accumulation_input_output(batch_size, n_in, steps, input):
    """Generate one batch of input spikes and decision targets for the evidence accumulation task.

    Parameters
    ----------
    batch_size : int
        Number of task instances (sequences) in the batch.
    n_in : int
        Total number of input neurons; split evenly into ``input["n_symbols"]`` populations
        (e.g. left, right, recall, noise).
    steps : dict
        Timing in simulation steps; keys used here: "cue", "spacing", "recall", "sequence".
    input : dict
        Task parameters; keys used here: "n_symbols", "n_cues", "prob_group", "spike_prob".
        NOTE(review): the name ``input`` shadows the builtin ``input()``; kept unchanged
        for backward compatibility with existing callers.

    Returns
    -------
    (np.ndarray, np.ndarray)
        ``input_spike_bools`` of shape ``(batch_size, steps["sequence"], n_in)`` (bool) and
        ``target_cues`` of shape ``(batch_size,)`` (int, 0 or 1).
    """
    n_pop_nrn = n_in // input["n_symbols"]

    # Per-sequence cue probabilities: one side gets prob_group, the other 1 - prob_group,
    # with the favored side chosen uniformly at random per batch element.
    prob_choices = np.array([input["prob_group"], 1 - input["prob_group"]], dtype=np.float32)
    idx = np.random.choice([0, 1], batch_size)
    probs = np.zeros((batch_size, 2), dtype=np.float32)
    probs[:, 0] = prob_choices[idx]
    probs[:, 1] = prob_choices[1 - idx]

    # Draw the left(0)/right(1) identity of each cue for every sequence.
    batched_cues = np.zeros((batch_size, input["n_cues"]), dtype=int)
    for b_idx in range(batch_size):
        batched_cues[b_idx, :] = np.random.choice([0, 1], input["n_cues"], p=probs[b_idx])

    input_spike_probs = np.zeros((batch_size, steps["sequence"], n_in))

    for b_idx in range(batch_size):
        for c_idx in range(input["n_cues"]):
            cue = batched_cues[b_idx, c_idx]

            step_start = c_idx * (steps["cue"] + steps["spacing"]) + steps["spacing"]
            # (line hidden in diff view) cue window spans steps["cue"] time steps —
            # TODO(review): confirm against full source
            step_stop = step_start + steps["cue"]

            pop_nrn_start = cue * n_pop_nrn
            pop_nrn_stop = pop_nrn_start + n_pop_nrn

            input_spike_probs[b_idx, step_start:step_stop, pop_nrn_start:pop_nrn_stop] = input["spike_prob"]

    # Recall population (3rd group) fires during the recall window; noise population
    # (4th group) fires throughout at a quarter of the cue spike probability.
    input_spike_probs[:, -steps["recall"] :, 2 * n_pop_nrn : 3 * n_pop_nrn] = input["spike_prob"]
    input_spike_probs[:, :, 3 * n_pop_nrn :] = input["spike_prob"] / 4.0
    input_spike_bools = input_spike_probs > np.random.rand(input_spike_probs.size).reshape(input_spike_probs.shape)
    input_spike_bools[:, 0, :] = 0  # remove spikes in 0th time step of every sequence for technical reasons

    # Target is the majority side over all cues of a sequence.
    target_cues = np.zeros(batch_size, dtype=int)
    target_cues[:] = np.sum(batched_cues, axis=1) > int(input["n_cues"] / 2)

    return input_spike_bools, target_cues


def get_params_task_input_output(n_iter_interval):
input_spike_prob = 0.04 # spike probability of frozen input noise
dtype_in_spks = np.float32 # data type of input spikes - for reproducing TF results set to np.float32

input_spike_bools_list = []
target_cues_list = []

for _ in range(n_iter_interval):
input_spike_bools, target_cues = generate_evidence_accumulation_input_output(
batch_size, n_in, prob_group, input_spike_prob, n_cues, n_input_symbols, steps
)
input_spike_bools, target_cues = generate_evidence_accumulation_input_output(batch_size, n_in, steps, input)
input_spike_bools_list.append(input_spike_bools)
target_cues_list.extend(target_cues)

Expand Down Expand Up @@ -876,7 +874,10 @@ def plot_spikes(ax, events, ylabel, xlims):
# the first time step and we add the initial weights manually.


def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
sender_label, target_label = label.split("_")
nrns_senders = nrns_weight_record[sender_label]
nrns_targets = nrns_weight_record[target_label]
for sender in nrns_senders.tolist():
for target in nrns_targets.tolist():
idc_syn = (events["senders"] == sender) & (events["targets"] == target)
Expand All @@ -895,11 +896,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

# Group the recorded neuron populations under the labels used by the weight-recorder
# plots; plot_weight_time_course() looks senders/targets up from this dict by splitting
# its "sender_target" label argument. (The superseded pre-refactor calls that passed
# sender/target collections positionally are removed.)
nrns_weight_record = {
    "in": nrns_in[:n_record_w],
    "rec": nrns_rec[:n_record_w],
    "out": nrns_out,
}

plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")

axs[-1].set_xlabel(r"$t$ (ms)")
axs[-1].set_xlim(0, duration["task"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,10 @@ def plot_spikes(ax, events, ylabel, xlims):
# the first time step and we add the initial weights manually.


def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
sender_label, target_label = label.split("_")
nrns_senders = nrns_weight_record[sender_label]
nrns_targets = nrns_weight_record[target_label]
for sender in nrns_senders.tolist():
for target in nrns_targets.tolist():
idc_syn = (events["senders"] == sender) & (events["targets"] == target)
Expand All @@ -843,11 +846,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

# Group the recorded neuron populations under the labels used by the weight-recorder
# plots; plot_weight_time_course() looks senders/targets up from this dict by splitting
# its "sender_target" label argument. (The superseded pre-refactor calls that passed
# sender/target collections positionally are removed.)
nrns_weight_record = {
    "in": nrns_in[:n_record_w],
    "rec": nrns_rec[:n_record_w],
    "out": nrns_out,
}

plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")

axs[-1].set_xlabel(r"$t$ (ms)")
axs[-1].set_xlim(0, steps["task"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,10 @@ def plot_spikes(ax, events, ylabel, xlims):
# the first time step and we add the initial weights manually.


def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
sender_label, target_label = label.split("_")
nrns_senders = nrns_weight_record[sender_label]
nrns_targets = nrns_weight_record[target_label]
for sender in nrns_senders.tolist():
for target in nrns_targets.tolist():
idc_syn = (events["senders"] == sender) & (events["targets"] == target)
Expand All @@ -691,11 +694,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

# Group the recorded neuron populations under the labels used by the weight-recorder
# plots; plot_weight_time_course() looks senders/targets up from this dict by splitting
# its "sender_target" label argument. (The superseded pre-refactor calls that passed
# sender/target collections positionally are removed.)
nrns_weight_record = {
    "in": nrns_in[:n_record_w],
    "rec": nrns_rec[:n_record_w],
    "out": nrns_out,
}

plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")

axs[-1].set_xlabel(r"$t$ (ms)")
axs[-1].set_xlim(0, steps["task"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,10 @@ def plot_spikes(ax, events, ylabel, xlims):
# the first time step and we add the initial weights manually.


def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
sender_label, target_label = label.split("_")
nrns_senders = nrns_weight_record[sender_label]
nrns_targets = nrns_weight_record[target_label]
for sender in nrns_senders.tolist():
for target in nrns_targets.tolist():
idc_syn = (events["senders"] == sender) & (events["targets"] == target)
Expand All @@ -672,11 +675,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

# Group the recorded neuron populations under the labels used by the weight-recorder
# plots; plot_weight_time_course() looks senders/targets up from this dict by splitting
# its "sender_target" label argument. (The superseded pre-refactor calls that passed
# sender/target collections positionally are removed.)
nrns_weight_record = {
    "in": nrns_in[:n_record_w],
    "rec": nrns_rec[:n_record_w],
    "out": nrns_out,
}

plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")

axs[-1].set_xlabel(r"$t$ (ms)")
axs[-1].set_xlim(0, steps["task"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,10 @@ def plot_spikes(ax, events, ylabel, xlims):
# the first time step and we add the initial weights manually.


def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabel):
def plot_weight_time_course(ax, events, nrns_weight_record, label, ylabel):
sender_label, target_label = label.split("_")
nrns_senders = nrns_weight_record[sender_label]
nrns_targets = nrns_weight_record[target_label]
for sender in nrns_senders.tolist():
for target in nrns_targets.tolist():
idc_syn = (events["senders"] == sender) & (events["targets"] == target)
Expand All @@ -647,11 +650,15 @@ def plot_weight_time_course(ax, events, nrns_senders, nrns_targets, label, ylabe
fig, axs = plt.subplots(3, 1, sharex=True, figsize=(3, 4))
fig.suptitle("Weight time courses")

# Group the recorded neuron populations under the labels used by the weight-recorder
# plots; plot_weight_time_course() looks senders/targets up from this dict by splitting
# its "sender_target" label argument. (The superseded pre-refactor calls that passed
# sender/target collections positionally are removed.)
nrns_weight_record = {
    "in": nrns_in[:n_record_w],
    "rec": nrns_rec[:n_record_w],
    "out": nrns_out,
}

plot_weight_time_course(axs[0], events_wr, nrns_weight_record, "in_rec", r"$W_\text{in}$ (pA)")
plot_weight_time_course(axs[1], events_wr, nrns_weight_record, "rec_rec", r"$W_\text{rec}$ (pA)")
plot_weight_time_course(axs[2], events_wr, nrns_weight_record, "rec_out", r"$W_\text{out}$ (pA)")

axs[-1].set_xlabel(r"$t$ (ms)")
axs[-1].set_xlim(0, steps["task"])
Expand Down
Loading

0 comments on commit 154d9fd

Please sign in to comment.