Skip to content

Commit

Permalink
Assessment colab improvements
Browse files Browse the repository at this point in the history
- Add SKAI library install code back into assessment colab.
- Improve formatting of labeling stats cell

PiperOrigin-RevId: 674338169
  • Loading branch information
jzxu authored and copybara-github committed Sep 16, 2024
1 parent 864fb9e commit 9c68fa0
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 114 deletions.
119 changes: 49 additions & 70 deletions src/colab/skai_assessment_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "69ae1348",
"id": "39e23e24",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -123,7 +123,7 @@
},
{
"cell_type": "markdown",
"id": "96485247",
"id": "74e17a01",
"metadata": {},
"source": [
"# Initialization"
Expand All @@ -132,7 +132,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "d21a6e70",
"id": "83bba7a7",
"metadata": {
"cellView": "form",
"lines_to_next_cell": 1
Expand All @@ -145,6 +145,10 @@
"# @markdown proceed to run the next cell.\n",
"def install_requirements():\n",
" \"\"\"Installs necessary Python libraries.\"\"\"\n",
" !rm -rf {SKAI_CODE_DIR}\n",
" !git clone {SKAI_REPO} {SKAI_CODE_DIR}\n",
" !pip install {SKAI_CODE_DIR}/src/.\n",
"\n",
" requirements = textwrap.dedent('''\n",
" apache_beam[gcp]==2.54.0\n",
" google-cloud-storage>=2.18.2 # https://github.com/apache/beam/issues/32169\n",
Expand All @@ -170,7 +174,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "53984c45",
"id": "558d8bfb",
"metadata": {
"cellView": "form"
},
Expand All @@ -188,7 +192,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "07dd975e",
"id": "84c1eb77",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -408,7 +412,7 @@
},
{
"cell_type": "markdown",
"id": "5b054ec0",
"id": "321beae6",
"metadata": {},
"source": [
"# Check Assessment Status\n",
Expand All @@ -420,7 +424,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f1e08f2c",
"id": "63bbef0e",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -478,7 +482,7 @@
},
{
"cell_type": "markdown",
"id": "e3b0b2c3",
"id": "2808e15f",
"metadata": {},
"source": [
"# Example Generation"
Expand All @@ -487,7 +491,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "42ce7cd9",
"id": "484b76a9",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -525,7 +529,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0b9c6f7c",
"id": "f786e424",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -653,7 +657,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "ee69e7e1",
"id": "10f428db",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -707,7 +711,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "6bfbfe70",
"id": "7501f56a",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -741,7 +745,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "effab0a7",
"id": "6f29f530",
"metadata": {
"cellView": "form"
},
Expand All @@ -767,7 +771,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f5975ab7",
"id": "5664daf3",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -802,7 +806,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "1dd2d0b7",
"id": "40122e86",
"metadata": {
"cellView": "form"
},
Expand All @@ -823,7 +827,7 @@
},
{
"cell_type": "markdown",
"id": "45e587f7",
"id": "a2777899",
"metadata": {},
"source": [
"# Labeling"
Expand All @@ -832,7 +836,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "63a55959",
"id": "1ab0cac0",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -934,7 +938,7 @@
},
{
"cell_type": "markdown",
"id": "ca1d29b2",
"id": "a1cfe69c",
"metadata": {},
"source": [
"When the labeling project is complete, download the CSV from the labeling tool\n",
Expand All @@ -947,7 +951,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "dcd367cf",
"id": "dac5741c",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -982,7 +986,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "8b442ceb",
"id": "175ea805",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1038,20 +1042,18 @@
{
"cell_type": "code",
"execution_count": null,
"id": "83f330cc",
"id": "3f0b589d",
"metadata": {
"cellView": "form"
},
"outputs": [],
"source": [
"# @title Show Label Stats\n",
"def _load_examples_into_gdf(\n",
"def _load_examples_into_df(\n",
" train_tfrecords: str,\n",
" test_tfrecords: str,\n",
" max_examples: int,\n",
" load_images: bool,\n",
") -> tuple[list[tf.train.Example], gpd.GeoDataFrame]:\n",
" \"\"\"Loads examples from TFRecords into a GeoDataFrame.\n",
") -> pd.DataFrame:\n",
" \"\"\"Loads examples from TFRecords into a DataFrame.\n",
" \"\"\"\n",
" feature_config = {\n",
" 'example_id': tf.io.FixedLenFeature([], tf.string),\n",
Expand All @@ -1060,12 +1062,6 @@
" 'label': tf.io.FixedLenFeature([], tf.float32),\n",
" }\n",
"\n",
" if load_images:\n",
" feature_config['pre_image_png_large'] = tf.io.FixedLenFeature([], tf.string)\n",
" feature_config['post_image_png_large'] = tf.io.FixedLenFeature(\n",
" [], tf.string\n",
" )\n",
"\n",
" def _parse_examples(record_bytes):\n",
" return tf.io.parse_single_example(record_bytes, feature_config)\n",
"\n",
Expand All @@ -1086,64 +1082,47 @@
" columns['string_label'].append(features['string_label'].decode())\n",
" columns['label'].append(features['label'])\n",
" columns['source_path'].append(path)\n",
" if load_images:\n",
" columns['pre_image_bytes'].append(features['pre_image_png_large'])\n",
" columns['post_image_bytes'].append(features['post_image_png_large'])\n",
" if max_examples > 0 and len(columns['example_id']) >= max_examples:\n",
" break\n",
" if max_examples > 0 and len(columns['example_id']) >= max_examples:\n",
" break\n",
"\n",
" return gpd.GeoDataFrame(\n",
" columns,\n",
" geometry=gpd.points_from_xy(longitudes, latitudes))\n",
" return pd.DataFrame(columns)\n",
"\n",
"def _format_counts_table(df: pd.DataFrame):\n",
"  \"\"\"Rewrites each non-'All' column in place as 'count (percent%)' strings.\n",
"\n",
"  Each percentage is the cell value divided by the same row's 'All' value.\n",
"  The 'All' column is left untouched and nothing is returned.\n",
"  \"\"\"\n",
"  columns_to_format = [name for name in df.columns if name != 'All']\n",
"  for name in columns_to_format:\n",
"    cells = []\n",
"    for x, t in zip(df[name], df['All']):\n",
"      cells.append(f'{x} ({x/t * 100:0.2f}%)')\n",
"    df[name] = cells\n",
"\n",
"def label_counts(df: gpd.GeoDataFrame):\n",
" \"\"\"Creates tables showing various labeling stats.\"\"\"\n",
"def show_label_stats(train_tfrecord: str, test_tfrecord: str):\n",
" \"\"\"Displays tables showing label count stats.\"\"\"\n",
" df = _load_examples_into_df(train_tfrecord, test_tfrecord)\n",
" counts = df.pivot_table(\n",
" index='source_path',\n",
" columns='string_label',\n",
" aggfunc='count',\n",
" values='example_id',\n",
" margins=True)\n",
" margins=True,\n",
" fill_value=0)\n",
" _format_counts_table(counts)\n",
"\n",
" print('String Label Counts')\n",
" display(counts)\n",
"\n",
" percents = counts.drop(columns=['All']).div(counts['All'], axis=0) * 100\n",
"\n",
" print('String Label Percentages')\n",
" display(percents)\n",
" display(data_table.DataTable(counts))\n",
"\n",
" float_counts = df.pivot_table(\n",
" index='source_path',\n",
" columns='label',\n",
" aggfunc='count',\n",
" values='example_id',\n",
" margins=True)\n",
" margins=True,\n",
" fill_value=0.0)\n",
" _format_counts_table(float_counts)\n",
" print('Float Label Counts')\n",
" display(float_counts)\n",
"\n",
" float_percents = (\n",
" float_counts.drop(columns=['All']).div(float_counts['All'], axis=0) * 100\n",
" )\n",
" print('Float Label Percentages')\n",
" display(float_percents)\n",
"\n",
" display(data_table.DataTable(float_counts))\n",
"\n",
"def show_label_stats(\n",
" train_tfrecord: str,\n",
" test_tfrecord: str):\n",
" gdf = _load_examples_into_gdf(train_tfrecord, test_tfrecord, 1000, False)\n",
" label_counts(gdf)\n",
"\n",
"show_label_stats(TRAIN_TFRECORD, TEST_TFRECORD)"
]
},
{
"cell_type": "markdown",
"id": "5f737038",
"id": "44fb0a05",
"metadata": {},
"source": [
"# Fine Tuning"
Expand All @@ -1152,7 +1131,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "7dbcb5f4",
"id": "73e77d52",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1226,7 +1205,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "65ef5c3f",
"id": "355937bb",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1272,7 +1251,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "60d8de69",
"id": "750d21fd",
"metadata": {
"cellView": "form"
},
Expand Down Expand Up @@ -1392,7 +1371,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "e4bdce5d",
"id": "149e47c0",
"metadata": {
"cellView": "form"
},
Expand Down
Loading

0 comments on commit 9c68fa0

Please sign in to comment.