Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assessment colab improvements #273

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 49 additions & 70 deletions src/colab/skai_assessment_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "69ae1348",
"id": "39e23e24",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -123,7 +123,7 @@
},
{
"cell_type": "markdown",
"id": "96485247",
"id": "74e17a01",
"metadata": {},
"source": [
"#Initialization"
Expand All @@ -132,7 +132,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d21a6e70",
"id": "83bba7a7",
"metadata": {
"cellView": "form",
"lines_to_next_cell": 1
Expand All @@ -145,6 +145,10 @@
"# @markdown proceed to run the next cell.\n",
"def install_requirements():\n",
" \"\"\"Installs necessary Python libraries.\"\"\"\n",
" !rm -rf {SKAI_CODE_DIR}\n",
" !git clone {SKAI_REPO} {SKAI_CODE_DIR}\n",
" !pip install {SKAI_CODE_DIR}/src/.\n",
"\n",
" requirements = textwrap.dedent('''\n",
" apache_beam[gcp]==2.54.0\n",
" google-cloud-storage>=2.18.2 # https://github.com/apache/beam/issues/32169\n",
Expand All @@ -170,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "53984c45",
"id": "558d8bfb",
"metadata": {
"cellView": "form"
},
Expand All @@ -188,7 +192,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "07dd975e",
"id": "84c1eb77",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -408,7 +412,7 @@
},
{
"cell_type": "markdown",
"id": "5b054ec0",
"id": "321beae6",
"metadata": {},
"source": [
"# Check Assessment Status\n",
Expand All @@ -420,7 +424,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f1e08f2c",
"id": "63bbef0e",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -478,7 +482,7 @@
},
{
"cell_type": "markdown",
"id": "e3b0b2c3",
"id": "2808e15f",
"metadata": {},
"source": [
"# Example Generation"
Expand All @@ -487,7 +491,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "42ce7cd9",
"id": "484b76a9",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -525,7 +529,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0b9c6f7c",
"id": "f786e424",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -653,7 +657,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ee69e7e1",
"id": "10f428db",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -707,7 +711,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6bfbfe70",
"id": "7501f56a",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -741,7 +745,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "effab0a7",
"id": "6f29f530",
"metadata": {
"cellView": "form"
},
Expand All @@ -767,7 +771,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f5975ab7",
"id": "5664daf3",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -802,7 +806,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1dd2d0b7",
"id": "40122e86",
"metadata": {
"cellView": "form"
},
Expand All @@ -823,7 +827,7 @@
},
{
"cell_type": "markdown",
"id": "45e587f7",
"id": "a2777899",
"metadata": {},
"source": [
"# Labeling"
Expand All @@ -832,7 +836,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "63a55959",
"id": "1ab0cac0",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -934,7 +938,7 @@
},
{
"cell_type": "markdown",
"id": "ca1d29b2",
"id": "a1cfe69c",
"metadata": {},
"source": [
"When the labeling project is complete, download the CSV from the labeling tool\n",
Expand All @@ -947,7 +951,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "dcd367cf",
"id": "dac5741c",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -982,7 +986,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8b442ceb",
"id": "175ea805",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1038,20 +1042,18 @@
{
"cell_type": "code",
"execution_count": null,
"id": "83f330cc",
"id": "3f0b589d",
"metadata": {
"cellView": "form"
},
"outputs": [],
"source": [
"# @title Show Label Stats\n",
"def _load_examples_into_gdf(\n",
"def _load_examples_into_df(\n",
" train_tfrecords: str,\n",
" test_tfrecords: str,\n",
" max_examples: int,\n",
" load_images: bool,\n",
") -> tuple[list[tf.train.Example], gpd.GeoDataFrame]:\n",
" \"\"\"Loads examples from TFRecords into a GeoDataFrame.\n",
") -> pd.DataFrame:\n",
" \"\"\"Loads examples from TFRecords into a DataFrame.\n",
" \"\"\"\n",
" feature_config = {\n",
" 'example_id': tf.io.FixedLenFeature([], tf.string),\n",
Expand All @@ -1060,12 +1062,6 @@
" 'label': tf.io.FixedLenFeature([], tf.float32),\n",
" }\n",
"\n",
" if load_images:\n",
" feature_config['pre_image_png_large'] = tf.io.FixedLenFeature([], tf.string)\n",
" feature_config['post_image_png_large'] = tf.io.FixedLenFeature(\n",
" [], tf.string\n",
" )\n",
"\n",
" def _parse_examples(record_bytes):\n",
" return tf.io.parse_single_example(record_bytes, feature_config)\n",
"\n",
Expand All @@ -1086,64 +1082,47 @@
" columns['string_label'].append(features['string_label'].decode())\n",
" columns['label'].append(features['label'])\n",
" columns['source_path'].append(path)\n",
" if load_images:\n",
" columns['pre_image_bytes'].append(features['pre_image_png_large'])\n",
" columns['post_image_bytes'].append(features['post_image_png_large'])\n",
" if max_examples > 0 and len(columns['example_id']) >= max_examples:\n",
" break\n",
" if max_examples > 0 and len(columns['example_id']) >= max_examples:\n",
" break\n",
"\n",
" return gpd.GeoDataFrame(\n",
" columns,\n",
" geometry=gpd.points_from_xy(longitudes, latitudes))\n",
" return pd.DataFrame(columns)\n",
"\n",
"def _format_counts_table(df: pd.DataFrame):\n",
" for column in df.columns:\n",
" if column != 'All':\n",
" df[column] = [f'{x} ({x/t * 100:0.2f}%)' for x, t in zip(df[column], df['All'])]\n",
"\n",
"def label_counts(df: gpd.GeoDataFrame):\n",
" \"\"\"Creates tables showing various labeling stats.\"\"\"\n",
"def show_label_stats(train_tfrecord: str, test_tfrecord: str):\n",
" \"\"\"Displays tables showing label count stats.\"\"\"\n",
" df = _load_examples_into_df(train_tfrecord, test_tfrecord)\n",
" counts = df.pivot_table(\n",
" index='source_path',\n",
" columns='string_label',\n",
" aggfunc='count',\n",
" values='example_id',\n",
" margins=True)\n",
" margins=True,\n",
" fill_value=0)\n",
" _format_counts_table(counts)\n",
"\n",
" print('String Label Counts')\n",
" display(counts)\n",
"\n",
" percents = counts.drop(columns=['All']).div(counts['All'], axis=0) * 100\n",
"\n",
" print('String Label Percentages')\n",
" display(percents)\n",
" display(data_table.DataTable(counts))\n",
"\n",
" float_counts = df.pivot_table(\n",
" index='source_path',\n",
" columns='label',\n",
" aggfunc='count',\n",
" values='example_id',\n",
" margins=True)\n",
" margins=True,\n",
" fill_value=0.0)\n",
" _format_counts_table(float_counts)\n",
" print('Float Label Counts')\n",
" display(float_counts)\n",
"\n",
" float_percents = (\n",
" float_counts.drop(columns=['All']).div(float_counts['All'], axis=0) * 100\n",
" )\n",
" print('Float Label Percentages')\n",
" display(float_percents)\n",
"\n",
" display(data_table.DataTable(float_counts))\n",
"\n",
"def show_label_stats(\n",
" train_tfrecord: str,\n",
" test_tfrecord: str):\n",
" gdf = _load_examples_into_gdf(train_tfrecord, test_tfrecord, 1000, False)\n",
" label_counts(gdf)\n",
"\n",
"show_label_stats(TRAIN_TFRECORD, TEST_TFRECORD)"
]
},
{
"cell_type": "markdown",
"id": "5f737038",
"id": "44fb0a05",
"metadata": {},
"source": [
"# Fine Tuning"
Expand All @@ -1152,7 +1131,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "7dbcb5f4",
"id": "73e77d52",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1226,7 +1205,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "65ef5c3f",
"id": "355937bb",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1272,7 +1251,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "60d8de69",
"id": "750d21fd",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1392,7 +1371,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e4bdce5d",
"id": "149e47c0",
"metadata": {
"cellView": "form"
},
Expand Down
Loading
Loading