From a0a10ba687f98820ec82ec41014afa8dab558600 Mon Sep 17 00:00:00 2001 From: Joseph Xu Date: Mon, 16 Sep 2024 12:58:08 -0700 Subject: [PATCH] Assessment colab improvements - Add SKAI library install code back into assessment colab. - Improve formatting of labeling stats cell PiperOrigin-RevId: 675259198 --- src/colab/skai_assessment_notebook.ipynb | 119 ++++++++++------------- src/colab/skai_assessment_notebook.py | 69 +++++-------- 2 files changed, 74 insertions(+), 114 deletions(-) diff --git a/src/colab/skai_assessment_notebook.ipynb b/src/colab/skai_assessment_notebook.ipynb index 6e7e4c2..b5ee301 100644 --- a/src/colab/skai_assessment_notebook.ipynb +++ b/src/colab/skai_assessment_notebook.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": null, - "id": "69ae1348", + "id": "39e23e24", "metadata": { "cellView": "form" }, @@ -123,7 +123,7 @@ }, { "cell_type": "markdown", - "id": "96485247", + "id": "74e17a01", "metadata": {}, "source": [ "#Initialization" @@ -132,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d21a6e70", + "id": "83bba7a7", "metadata": { "cellView": "form", "lines_to_next_cell": 1 @@ -145,6 +145,10 @@ "# @markdown proceed to run the next cell.\n", "def install_requirements():\n", " \"\"\"Installs necessary Python libraries.\"\"\"\n", + " !rm -rf {SKAI_CODE_DIR}\n", + " !git clone {SKAI_REPO} {SKAI_CODE_DIR}\n", + " !pip install {SKAI_CODE_DIR}/src/.\n", + "\n", " requirements = textwrap.dedent('''\n", " apache_beam[gcp]==2.54.0\n", " google-cloud-storage>=2.18.2 # https://github.com/apache/beam/issues/32169\n", @@ -170,7 +174,7 @@ { "cell_type": "code", "execution_count": null, - "id": "53984c45", + "id": "558d8bfb", "metadata": { "cellView": "form" }, @@ -188,7 +192,7 @@ { "cell_type": "code", "execution_count": null, - "id": "07dd975e", + "id": "84c1eb77", "metadata": { "cellView": "form" }, @@ -408,7 +412,7 @@ }, { "cell_type": "markdown", - "id": "5b054ec0", + "id": "321beae6", "metadata": {}, "source": [ "# Check Assessment Status\n", @@ -420,7 +424,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1e08f2c", + "id": "63bbef0e", "metadata": { "cellView": "form" }, @@ -478,7 +482,7 @@ }, { "cell_type": "markdown", - "id": "e3b0b2c3", + "id": "2808e15f", "metadata": {}, "source": [ "# Example Generation" @@ -487,7 +491,7 @@ { "cell_type": "code", "execution_count": null, - "id": "42ce7cd9", + "id": "484b76a9", "metadata": { "cellView": "form" }, @@ -525,7 +529,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0b9c6f7c", + "id": "f786e424", "metadata": { "cellView": "form" }, @@ -653,7 +657,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ee69e7e1", + "id": "10f428db", "metadata": { "cellView": "form" }, @@ -707,7 +711,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6bfbfe70", + "id": "7501f56a", "metadata": { "cellView": "form" }, @@ -741,7 +745,7 @@ { "cell_type": "code", "execution_count": null, - "id": "effab0a7", + "id": "6f29f530", "metadata": { "cellView": "form" }, @@ -767,7 +771,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f5975ab7", + "id": "5664daf3", "metadata": { "cellView": "form" }, @@ -802,7 +806,7 @@ { "cell_type": "code", "execution_count": null, - "id": "1dd2d0b7", + "id": "40122e86", "metadata": { "cellView": "form" }, @@ -823,7 +827,7 @@ }, { "cell_type": "markdown", - "id": "45e587f7", + "id": "a2777899", "metadata": {}, "source": [ "# Labeling" @@ -832,7 +836,7 @@ { "cell_type": "code", "execution_count": null, - "id": "63a55959", + "id": "1ab0cac0", "metadata": { "cellView": "form" }, @@ -934,7 +938,7 @@ }, { "cell_type": "markdown", - "id": "ca1d29b2", + "id": "a1cfe69c", "metadata": {}, "source": [ "When the labeling project is complete, download the CSV from the labeling tool\n", @@ -947,7 +951,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dcd367cf", + "id": "dac5741c", "metadata": { "cellView": "form" }, @@ -982,7 +986,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8b442ceb", + "id": "175ea805", "metadata": { "cellView": "form" }, @@ -1038,20 +1042,18 @@ { "cell_type": "code", "execution_count": null, - "id": "83f330cc", + "id": "3f0b589d", "metadata": { "cellView": "form" }, "outputs": [], "source": [ "# @title Show Label Stats\n", - "def _load_examples_into_gdf(\n", + "def _load_examples_into_df(\n", " train_tfrecords: str,\n", " test_tfrecords: str,\n", - " max_examples: int,\n", - " load_images: bool,\n", - ") -> tuple[list[tf.train.Example], gpd.GeoDataFrame]:\n", - " \"\"\"Loads examples from TFRecords into a GeoDataFrame.\n", + ") -> pd.DataFrame:\n", + " \"\"\"Loads examples from TFRecords into a DataFrame.\n", " \"\"\"\n", " feature_config = {\n", " 'example_id': tf.io.FixedLenFeature([], tf.string),\n", @@ -1060,12 +1062,6 @@ " 'label': tf.io.FixedLenFeature([], tf.float32),\n", " }\n", "\n", - " if load_images:\n", - " feature_config['pre_image_png_large'] = tf.io.FixedLenFeature([], tf.string)\n", - " feature_config['post_image_png_large'] = tf.io.FixedLenFeature(\n", - " [], tf.string\n", - " )\n", - "\n", " def _parse_examples(record_bytes):\n", " return tf.io.parse_single_example(record_bytes, feature_config)\n", "\n", @@ -1086,64 +1082,47 @@ " columns['string_label'].append(features['string_label'].decode())\n", " columns['label'].append(features['label'])\n", " columns['source_path'].append(path)\n", - " if load_images:\n", - " columns['pre_image_bytes'].append(features['pre_image_png_large'])\n", - " columns['post_image_bytes'].append(features['post_image_png_large'])\n", - " if max_examples > 0 and len(columns['example_id']) >= max_examples:\n", - " break\n", - " if max_examples > 0 and len(columns['example_id']) >= max_examples:\n", - " break\n", "\n", - " return gpd.GeoDataFrame(\n", - " columns,\n", - " geometry=gpd.points_from_xy(longitudes, latitudes))\n", + " return pd.DataFrame(columns)\n", "\n", + "def _format_counts_table(df: pd.DataFrame):\n", + " for column in df.columns:\n", + " if column != 'All':\n", + " df[column] = [f'{x} ({x/t * 100:0.2f}%)' for x, t in zip(df[column], df['All'])]\n", "\n", - "def label_counts(df: gpd.GeoDataFrame):\n", - " \"\"\"Creates tables showing various labeling stats.\"\"\"\n", + "def show_label_stats(train_tfrecord: str, test_tfrecord: str):\n", + " \"\"\"Displays tables showing label count stats.\"\"\"\n", + " df = _load_examples_into_df(train_tfrecord, test_tfrecord)\n", " counts = df.pivot_table(\n", " index='source_path',\n", " columns='string_label',\n", " aggfunc='count',\n", " values='example_id',\n", - " margins=True)\n", + " margins=True,\n", + " fill_value=0)\n", + " _format_counts_table(counts)\n", "\n", " print('String Label Counts')\n", - " display(counts)\n", - "\n", - " percents = counts.drop(columns=['All']).div(counts['All'], axis=0) * 100\n", - "\n", - " print('String Label Percentages')\n", - " display(percents)\n", + " display(data_table.DataTable(counts))\n", "\n", " float_counts = df.pivot_table(\n", " index='source_path',\n", " columns='label',\n", " aggfunc='count',\n", " values='example_id',\n", - " margins=True)\n", + " margins=True,\n", + " fill_value=0.0)\n", + " _format_counts_table(float_counts)\n", " print('Float Label Counts')\n", - " display(float_counts)\n", - "\n", - " float_percents = (\n", - " float_counts.drop(columns=['All']).div(float_counts['All'], axis=0) * 100\n", - " )\n", - " print('Float Label Percentages')\n", - " display(float_percents)\n", - "\n", + " display(data_table.DataTable(float_counts))\n", "\n", - "def show_label_stats(\n", - " train_tfrecord: str,\n", - " test_tfrecord: str):\n", - " gdf = _load_examples_into_gdf(train_tfrecord, test_tfrecord, 1000, False)\n", - " label_counts(gdf)\n", "\n", "show_label_stats(TRAIN_TFRECORD, TEST_TFRECORD)" ] }, { "cell_type": "markdown", - "id": "5f737038", + "id": "44fb0a05", "metadata": {}, "source": [ "# Fine Tuning" @@ -1152,7 +1131,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7dbcb5f4", + "id": "73e77d52", "metadata": { "cellView": "form" }, @@ -1226,7 +1205,7 @@ { "cell_type": "code", "execution_count": null, - "id": "65ef5c3f", + "id": "355937bb", "metadata": { "cellView": "form" }, @@ -1272,7 +1251,7 @@ { "cell_type": "code", "execution_count": null, - "id": "60d8de69", + "id": "750d21fd", "metadata": { "cellView": "form" }, @@ -1392,7 +1371,7 @@ { "cell_type": "code", "execution_count": null, - "id": "e4bdce5d", + "id": "149e47c0", "metadata": { "cellView": "form" }, diff --git a/src/colab/skai_assessment_notebook.py b/src/colab/skai_assessment_notebook.py index e743668..a68fcbb 100644 --- a/src/colab/skai_assessment_notebook.py +++ b/src/colab/skai_assessment_notebook.py @@ -135,6 +135,10 @@ def process_image_entries(entries: list[str]) -> list[str]: # @markdown proceed to run the next cell. def install_requirements(): """Installs necessary Python libraries.""" + # !rm -rf {SKAI_CODE_DIR} + # !git clone {SKAI_REPO} {SKAI_CODE_DIR} + # !pip install {SKAI_CODE_DIR}/src/. + requirements = textwrap.dedent(''' apache_beam[gcp]==2.54.0 google-cloud-storage>=2.18.2 # https://github.com/apache/beam/issues/32169 @@ -897,13 +901,11 @@ def create_labeled_examples( # %% cellView="form" # @title Show Label Stats -def _load_examples_into_gdf( +def _load_examples_into_df( train_tfrecords: str, test_tfrecords: str, - max_examples: int, - load_images: bool, -) -> tuple[list[tf.train.Example], gpd.GeoDataFrame]: - """Loads examples from TFRecords into a GeoDataFrame. +) -> pd.DataFrame: + """Loads examples from TFRecords into a DataFrame. """ feature_config = { 'example_id': tf.io.FixedLenFeature([], tf.string), @@ -912,12 +914,6 @@ def _load_examples_into_gdf( 'label': tf.io.FixedLenFeature([], tf.float32), } - if load_images: - feature_config['pre_image_png_large'] = tf.io.FixedLenFeature([], tf.string) - feature_config['post_image_png_large'] = tf.io.FixedLenFeature( - [], tf.string - ) - def _parse_examples(record_bytes): return tf.io.parse_single_example(record_bytes, feature_config) @@ -938,57 +934,42 @@ def _parse_examples(record_bytes): columns['string_label'].append(features['string_label'].decode()) columns['label'].append(features['label']) columns['source_path'].append(path) - if load_images: - columns['pre_image_bytes'].append(features['pre_image_png_large']) - columns['post_image_bytes'].append(features['post_image_png_large']) - if max_examples > 0 and len(columns['example_id']) >= max_examples: - break - if max_examples > 0 and len(columns['example_id']) >= max_examples: - break - return gpd.GeoDataFrame( - columns, - geometry=gpd.points_from_xy(longitudes, latitudes)) + return pd.DataFrame(columns) +def _format_counts_table(df: pd.DataFrame): + for column in df.columns: + if column != 'All': + df[column] = [ + f'{x} ({x/t * 100:0.2f}%)' for x, t in zip(df[column], df['All']) + ] -def label_counts(df: gpd.GeoDataFrame): - """Creates tables showing various labeling stats.""" +def show_label_stats(train_tfrecord: str, test_tfrecord: str): + """Displays tables showing label count stats.""" + df = _load_examples_into_df(train_tfrecord, test_tfrecord) counts = df.pivot_table( index='source_path', columns='string_label', aggfunc='count', values='example_id', - margins=True) + margins=True, + fill_value=0) + _format_counts_table(counts) print('String Label Counts') - display(counts) - - percents = counts.drop(columns=['All']).div(counts['All'], axis=0) * 100 - - print('String Label Percentages') - display(percents) + display(data_table.DataTable(counts)) float_counts = df.pivot_table( index='source_path', columns='label', aggfunc='count', values='example_id', - margins=True) + margins=True, + fill_value=0.0) + _format_counts_table(float_counts) print('Float Label Counts') - display(float_counts) - - float_percents = ( - float_counts.drop(columns=['All']).div(float_counts['All'], axis=0) * 100 - ) - print('Float Label Percentages') - display(float_percents) - + display(data_table.DataTable(float_counts)) -def show_label_stats( - train_tfrecord: str, - test_tfrecord: str): - gdf = _load_examples_into_gdf(train_tfrecord, test_tfrecord, 1000, False) - label_counts(gdf) show_label_stats(TRAIN_TFRECORD, TEST_TFRECORD)